diff --git a/.clang-tidy b/.clang-tidy index caecc0cb295..f4262fd3c13 100755 --- a/.clang-tidy +++ b/.clang-tidy @@ -115,3 +115,5 @@ CheckOptions: value: UPPER_CASE - key: readability-identifier-naming.MacroDefinitionPrefix value: MIGRAPHX_ + - key: readability-identifier-naming.ConstexprMethodIgnoredRegexp + value: 'quiet_NaN|signaling_NaN' diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 460d3fdff7d..ca7a722b9c6 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -595,4 +595,16 @@ jobs: # chmod +x codecov # ./codecov -t ${CODECOV_TOKEN} # echo "Uploaded" + misspell: + name: misspell + runs-on: ubuntu-20.04 + steps: + - name: Check out code. + uses: actions/checkout@v4 + - name: misspell + uses: reviewdog/action-misspell@v1 + with: + locale: "US" + reporter: github-pr-check + level: warning diff --git a/CHANGELOG.md b/CHANGELOG.md index 98f4f775feb..8461e3a372d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,83 @@ Full documentation for MIGraphX is available at [https://rocmdocs.amd.com/projects/AMDMIGraphX/en/latest/](https://rocmdocs.amd.com/projects/AMDMIGraphX/en/latest/). +## MIGraphX 2.11 for ROCm 6.3.0 + +### Added + +* Initial code to run on Windows +* Support for gfx120x GPU +* Support for FP8 and INT4 +* Support for the Log2 internal operator +* Support for the GCC 14 compiler +* The BitwiseAnd, Scan, SoftmaxCrossEntropyLoss, GridSample, and NegativeLogLikelihoodLoss ONNX operators +* The MatMulNBits, QuantizeLinear/DequantizeLinear, GroupQueryAttention, SkipSimplifiedLayerNormalization, and SimplifiedLayerNormalization Microsoft Contrib operators +* Dynamic batch parameter support for the OneHot operator +* Split-K as an optional performance improvement +* Scripts to validate ONNX models from the ONNX Model Zoo +* GPU Pooling Kernel +* --mlir flag to the migraphx-driver program to offload the entire module to MLIR +* Fusing split-reduce with MLIR +* Multiple outputs for the MLIR + Pointwise fusions +* Pointwise fusions with MLIR across reshape operations +* MIGRAPHX_MLIR_DUMP environment variable to dump MLIR modules to MXRs +* The 3 option to MIGRAPHX_TRACE_BENCHMARKING to print the MLIR program for improved debug output +* MIGRAPHX_ENABLE_HIPBLASLT_GEMM environment variable to call hipBLASLt libraries +* MIGRAPHX_VERIFY_DUMP_DIFF to improve the debugging of accuracy issues +* reduce_any and reduce_all options to the Reduce operation via Torch MIGraphX +* Examples for RNNT and ControlNet + + +### Changed + +* Switched to MLIR's 3D Convolution operator. +* MLIR is now used for Attention operations by default on gfx942 and newer ASICs. +* Names and locations for VRM-specific libraries have changed. +* Random mode is now used for benchmarking GEMMs and convolutions. +* Python version is now printed with an actual version number. + + +### Removed + +* Disabled requirements for MIOpen and rocBLAS when running on Windows. +* Removed inaccurate warning messages when using exhaustive-tune. +* Removed the hard-coded path in MIGRAPHX_CXX_COMPILER, allowing the compiler to be installed in different locations.
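Several of the additions above are controlled through environment variables. As a rough illustration, here is a minimal Python sketch; the variable names come from the "Added" list above, but the values shown are arbitrary examples, not documented defaults:

```python
import os

# Dump MLIR modules to MXRs (the directory path is a hypothetical example).
os.environ["MIGRAPHX_MLIR_DUMP"] = "/tmp/mlir_mxr_dumps"

# Option 3 prints the MLIR program for improved debug output.
os.environ["MIGRAPHX_TRACE_BENCHMARKING"] = "3"

# Call hipBLASLt libraries for GEMMs.
os.environ["MIGRAPHX_ENABLE_HIPBLASLT_GEMM"] = "1"

# Set these in the environment before the MIGraphX-backed process
# compiles and runs its program, so the library picks them up.
```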
+ + +### Optimized + +* Improved: + * Infrastructure code to enable better Kernel fusions with all supported data types + * Subsequent model compile time by creating a cache for already performant kernels + * Use of Attention fusion with models + * Performance of the Softmax JIT kernel and of the Pooling opterator + * Tuning operations through a new 50ms delay before running the next kernel + * Performance of several convolution based models through an optimized NHWC layout + * Performance for the FP8 datatype + * GPU utilization + * Verification tools + * Debug prints + * Documentation, including gpu-driver utility documentation + * Summary section of the migrahx-driver perf command +* Reduced model compilation time +* Reordered some compiler passes to allow for more fusions +* Preloaded tiles into LDS to improve performance of pointwise transposes +* Exposed the external_data_path property in onnx_options to set the path from onnxruntime + + +### Resolved Issues + +* Fixed a bug with gfx1030 that overwrote dpp_reduce. +* Fixed a bug in 1arg dynamic reshape that created a failure. +* Fixed a bug with dot_broadcast and inner_broadcast that caused compile failures. +* Fixed a bug where some configs were failing when using exhaustive-tune. +* Fixed the ROCM Install Guide URL. +* Fixed an issue while building a whl package due to an apostrophe. +* Fixed the BERT Squad example requirements file to support different versions of Python. +* Fixed a bug that stopped the Vicuna model from compiling. +* Fixed failures with the verify option of migraphx-driver that would cause the application to exit early. + + ## MIGraphX 2.10 for ROCm 6.2.0 ### Additions diff --git a/CMakeLists.txt b/CMakeLists.txt index e4a884ffbd0..e00ad7010e1 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -106,7 +106,7 @@ include(ROCMSetupVersion) option(BUILD_DEV "Build for development purpose only" OFF) -rocm_setup_version(VERSION 2.11.0) +rocm_setup_version(VERSION 2.12.0) math(EXPR MIGRAPHX_SO_MAJOR_VERSION "(${PROJECT_VERSION_MAJOR} * 1000 * 1000) + (${PROJECT_VERSION_MINOR} * 1000) + ${PROJECT_VERSION_PATCH}") set(MIGRAPHX_SO_VERSION ${MIGRAPHX_SO_MAJOR_VERSION}.0) @@ -178,10 +178,8 @@ rocm_enable_clang_tidy( -bugprone-easily-swappable-parameters -bugprone-implicit-widening-of-multiplication-result -bugprone-macro-parentheses - -bugprone-multi-level-implicit-pointer-conversion -bugprone-signed-char-misuse -bugprone-unchecked-optional-access - -bugprone-unused-local-non-trivial-variable # Disable the aliased reserved identifiers -cert-dcl37-c -cert-dcl51-cpp diff --git a/Dockerfile b/Dockerfile index 53cf679bba7..5051b7b5cd8 100644 --- a/Dockerfile +++ b/Dockerfile @@ -10,7 +10,7 @@ RUN apt-get update && apt-get install -y software-properties-common gnupg2 --no- curl -sL http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - # Add rocm repository -RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/6.2/ jammy main > /etc/apt/sources.list.d/rocm.list' +RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/6.3/ jammy main > /etc/apt/sources.list.d/rocm.list' # From docs.amd.com for installing rocm. 
Needed to install properly RUN sh -c "echo 'Package: *\nPin: release o=repo.radeon.com\nPin-priority: 600' > /etc/apt/preferences.d/rocm-pin-600" diff --git a/Jenkinsfile b/Jenkinsfile index 0f6725d29eb..0ce0026d454 100755 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -71,7 +71,7 @@ def rocmtestnode(Map conf) { pre() sh "docker pull ${DOCKER_IMAGE}:${env.IMAGE_TAG}" withDockerContainer(image: "${DOCKER_IMAGE}:${env.IMAGE_TAG}", args: "--device=/dev/kfd --device=/dev/dri --group-add video --cap-add SYS_PTRACE -v=/home/jenkins/:/home/jenkins ${docker_args}") { - timeout(time: 2, unit: 'HOURS') { + timeout(time: 4, unit: 'HOURS') { body(cmake_build) } } diff --git a/codecov.yml b/codecov.yml index 03abe2daeb2..9f2569b0669 100644 --- a/codecov.yml +++ b/codecov.yml @@ -2,3 +2,4 @@ ignore: - "test/" - "src/driver" - "build/" + - "src/netron_output.cpp" diff --git a/docs/dev/data.rst b/docs/dev/data.rst index ce3e77e04b8..217a6c6b81b 100755 --- a/docs/dev/data.rst +++ b/docs/dev/data.rst @@ -1,3 +1,7 @@ +.. meta:: + :description: MIGraphX internal data types + :keywords: MIGraphX, code base, contribution, developing, data types + Data types ========== diff --git a/docs/dev/dev_intro.rst b/docs/dev/dev_intro.rst index f22f1d36f17..7454821f5e0 100644 --- a/docs/dev/dev_intro.rst +++ b/docs/dev/dev_intro.rst @@ -1,3 +1,8 @@ +.. meta:: + :description: MIGraphX introduction to developing for the code base + :keywords: MIGraphX, code base, contribution, developing, introduction, developers + + Developer Introduction ====================== diff --git a/docs/dev/env_vars.rst b/docs/dev/env_vars.rst index 06e9624741c..a05b99cef64 100644 --- a/docs/dev/env_vars.rst +++ b/docs/dev/env_vars.rst @@ -1,3 +1,7 @@ +.. meta:: + :description: MIGraphX internal environment variables + :keywords: MIGraphX, code base, contribution, developing, env vars, environment variables + Environment Variables ===================== @@ -112,9 +116,9 @@ Disables the ``schedule`` pass. Set to "1", "enable", "enabled", "yes", or "true" to use. Disables the ``fuse_reduce`` pass. -.. envvar:: MIGRAPHX_ENABLE_SPLIT_REDUCE -Set to "1", "enable", "enabled", "yes", or "true" to use. -Enable split_reduce. +.. envvar:: MIGRAPHX_SPLIT_REDUCE_SIZE +Set to the minimum size of a reduction to do a split reduce. Overrides what +is set in the backend. Set to -1 to disable split reduce completely. .. envvar:: MIGRAPHX_ENABLE_NHWC diff --git a/docs/dev/matchers.rst b/docs/dev/matchers.rst index 32c5b075d84..01d4ae6e35d 100644 --- a/docs/dev/matchers.rst +++ b/docs/dev/matchers.rst @@ -1,3 +1,7 @@ +.. meta:: + :description: MIGraphX internal matchers + :keywords: MIGraphX, code base, contribution, developing, matchers + Matchers ======== diff --git a/docs/dev/onnx_operators.rst b/docs/dev/onnx_operators.rst index fc621b4f894..3e21f1172bb 100644 --- a/docs/dev/onnx_operators.rst +++ b/docs/dev/onnx_operators.rst @@ -1,22 +1,24 @@ +.. meta:: + :description: MIGraphX supported ONNX operators + :keywords: MIGraphX, code base, contribution, developing, ONNX operators + Supported ONNX Operators ======================== MIGraphX supports operators up to Opset 19. Latest information of ONNX -operators can be found -`here `__ +operators can be found in `the ONNX GitHub repository `_. 
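As a concrete usage sketch for the `MIGRAPHX_SPLIT_REDUCE_SIZE` variable documented above (the numeric value here is an arbitrary example, not a recommendation):

```python
import os

# Require reductions of at least this many elements before split_reduce
# applies; this overrides the size chosen by the backend.
os.environ["MIGRAPHX_SPLIT_REDUCE_SIZE"] = "8192"

# Alternatively, disable split reduce completely:
os.environ["MIGRAPHX_SPLIT_REDUCE_SIZE"] = "-1"
```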
-MIGraphX supports the following ONNX data types: BOOL, UINT8, UINT16, -UINT32, UINT64, INT8, INT16, INT32, INT64, FLOAT8, FLOAT16, FLOAT32, -DOUBLE +MIGraphX supports the following ONNX data types: BOOL, UINT8, UINT16, UINT32, UINT64, INT8, INT16, INT32, INT64, FLOAT8, FLOAT16, FLOAT32, and DOUBLE. - NOTE: FP8 support is only for E4M3FNUZ, see - `here `__ + .. Note:: + + FP8 support is only for E4M3FNUZ, see `Float stored in 8 bits `_ in the ONNX documentation. See below for the support matrix of ONNX operators in MIGraphX. - NOTE: Supported Types are from ONNX specification. An operator might - support more datatypes (e.g. integer type for float operator) than - listed. + .. Note:: + + The listed supported types are taken from the ONNX specification. An operator might support additional datatypes. Operator Support Matrix ----------------------- diff --git a/docs/dev/operators.rst b/docs/dev/operators.rst index 15691feb92f..8cf641f5767 100755 --- a/docs/dev/operators.rst +++ b/docs/dev/operators.rst @@ -1,3 +1,7 @@ +.. meta:: + :description: MIGraphX internal operators + :keywords: MIGraphX, code base, contribution, developing, operators + Operators ========= diff --git a/docs/dev/pass.rst b/docs/dev/pass.rst index 4c27b706252..feada6df969 100755 --- a/docs/dev/pass.rst +++ b/docs/dev/pass.rst @@ -1,3 +1,7 @@ +.. meta:: + :description: MIGraphX internal passes + :keywords: MIGraphX, code base, contribution, developing, passes + Passes ====== diff --git a/docs/dev/program.rst b/docs/dev/program.rst index 65b99343a9b..fe1ab3cfa38 100755 --- a/docs/dev/program.rst +++ b/docs/dev/program.rst @@ -1,3 +1,7 @@ +.. meta:: + :description: MIGraphX program + :keywords: MIGraphX, code base, contribution, developing, program + Program ======= diff --git a/docs/dev/quantization.rst b/docs/dev/quantization.rst index aecbd63188f..16e79c8a93d 100755 --- a/docs/dev/quantization.rst +++ b/docs/dev/quantization.rst @@ -1,3 +1,7 @@ +.. meta:: + :description: MIGraphX internal quantization + :keywords: MIGraphX, code base, contribution, developing, quantization + Quantization ============ diff --git a/docs/dev/targets.rst b/docs/dev/targets.rst index eb1ee223ca6..3f95688ea84 100755 --- a/docs/dev/targets.rst +++ b/docs/dev/targets.rst @@ -1,3 +1,7 @@ +.. meta:: + :description: MIGraphX targets + :keywords: MIGraphX, code base, contribution, developing, targets + Targets ======= diff --git a/docs/dev/tools.rst b/docs/dev/tools.rst index 077eb5b9208..43847e399e4 100644 --- a/docs/dev/tools.rst +++ b/docs/dev/tools.rst @@ -1,3 +1,7 @@ +.. meta:: + :description: MIGraphX tools + :keywords: MIGraphX, code base, contribution, developing, tools, knobs + .. _tools: Tools diff --git a/docs/dev/triage-rocmlir.rst b/docs/dev/triage-rocmlir.rst index 63bed90455f..7f21c0a167a 100644 --- a/docs/dev/triage-rocmlir.rst +++ b/docs/dev/triage-rocmlir.rst @@ -1,3 +1,7 @@ +.. meta:: + :description: Issue Triaging Guide for suspected rocMLIR issues + :keywords: MIGraphX, rocMLIR, issues, pipeline, compilation, bug, code base, kernel, contribution, developing + Issue Triaging Guide for suspected rocMLIR issue ================================================ diff --git a/docs/index.rst b/docs/index.rst index 8ec6956816a..e9ba8f40e6d 100755 --- a/docs/index.rst +++ b/docs/index.rst @@ -25,21 +25,22 @@ The MIGraphX public repository is located at `https://github.com/ROCm/AMDMIGraph ..
grid-item-card:: Install - * :doc:`Installing MIGraphX <./install/installing_with_package>` + * :doc:`Installing MIGraphX with the package installer <./install/installing_with_package>` + * :doc:`Building and installing MIGraphX from source code <./install/building_migraphx>` - .. grid-item-card:: Using the MIGraphX API + .. grid-item-card:: Reference * :ref:`cpp-api-reference` * :ref:`python-api-reference` * :ref:`migraphx-driver` + * :doc:`Supported ONNX Operators <./dev/onnx_operators>` .. grid-item-card:: Contributing to the MIGraphX code base - * :doc:`Building MIGraphX <./install/building_migraphx>` * :doc:`Developing for MIGraphX <./dev/contributing-to-migraphx>` To contribute to the documentation refer to -`Contribute to ROCm documentation `_. +`Contributing to ROCm `_. Licensing information can be found on the `Licensing `_ page. diff --git a/docs/install/build_and_install_with_cmake.rst b/docs/install/build_and_install_with_cmake.rst deleted file mode 100644 index 0c34f31f642..00000000000 --- a/docs/install/build_and_install_with_cmake.rst +++ /dev/null @@ -1,62 +0,0 @@ -.. meta:: - :description: Build and install MIGraphX using CMake - :keywords: build, install, MIGraphX, AMD, ROCm, CMake - -******************************************************************** -Build and install MIGraphX using CMake -******************************************************************** - -ROCm must be installed before installing MIGraphX. See `ROCm installation for Linux `_ for information on how to install ROCm on Linux. - -.. note:: - - This method for building MIGraphX requires using ``sudo``. - - -1. Install the dependencies: - - .. code:: shell - - sudo rbuild build -d depend -B build -DGPU_TARGETS=$(/opt/rocm/bin/rocminfo | grep -o -m1 'gfx.*') - - .. note:: - - If ``rbuild`` is not installed on your system, install it with: - - .. code:: shell - - pip3 install --prefix /usr/local https://github.com/RadeonOpenCompute/rbuild/archive/master.tar.gz - -2. Create a build directory and change directory to it: - - .. code:: shell - - mkdir build - cd build - -3. Configure CMake: - - .. code:: shell - - CXX=/opt/rocm/llvm/bin/clang++ cmake .. -DGPU_TARGETS=$(/opt/rocm/bin/rocminfo | grep -o -m1 'gfx.*') - -4. Build MIGraphX source code: - - .. code:: shell - - make -j$(nproc) - - - You can verify this using: - - .. code:: shell - - make -j$(nproc) check - - -5. Install MIGraphX libraries: - - .. code:: shell - - make install - \ No newline at end of file diff --git a/docs/install/build_and_install_with_docker.rst b/docs/install/build_and_install_with_docker.rst deleted file mode 100644 index 95057d3dfe0..00000000000 --- a/docs/install/build_and_install_with_docker.rst +++ /dev/null @@ -1,52 +0,0 @@ -.. meta:: - :description: Installing MIGraphX using Docker - :keywords: install, MIGraphX, AMD, ROCm, Docker - -******************************************************************** -Installing MIGraphX using Docker -******************************************************************** - -ROCm must be installed before installing MIGraphX. See `ROCm installation for Linux `_ for information on how to install ROCm on Linux. - -.. note:: - - Docker commands are run using ``sudo``. - -1. Build the Docker image. This command will install all the prerequisites required to install MIGraphX. Ensure that you are running this in the same directory as ``Dockerfile``. - - .. code:: shell - - sudo docker build -t migraphx . - - -2. Create and run the container. 
Once this command is run, you will be in the ``/code/AMDMIGraphX`` directory of a pseudo-tty. - - .. code:: shell - - sudo docker run --device='/dev/kfd' --device='/dev/dri' -v=`pwd`:/code/AMDMIGraphX -w /code/AMDMIGraphX --group-add video -it migraphx - -3. In the ``/code/AMDMIGraphX``, create a ``build`` directory, then change directory to ``/code/AMDMIGraphX/build``: - - .. code:: shell - - mkdir build - cd build - - -4. Configure CMake: - - .. code:: shell - - CXX=/opt/rocm/llvm/bin/clang++ cmake .. -DGPU_TARGETS=$(/opt/rocm/bin/rocminfo | grep -o -m1 'gfx.*') - -4. Build the MIGraphX libraries: - - .. code:: shell - - make -j$(nproc) - -5. Install the libraries: - - .. code:: shell - - make install \ No newline at end of file diff --git a/docs/install/build_and_install_with_rbuild.rst b/docs/install/build_and_install_with_rbuild.rst deleted file mode 100644 index 3e962244940..00000000000 --- a/docs/install/build_and_install_with_rbuild.rst +++ /dev/null @@ -1,32 +0,0 @@ -.. meta:: - :description: Build and install MIGraphX using rbuild - :keywords: build, install, MIGraphX, AMD, ROCm, rbuild - -******************************************************************** -Build and install MIGraphX using rbuild -******************************************************************** - -ROCm must be installed before installing MIGraphX. See `ROCm installation for Linux `_ for information on how to install ROCm on Linux. - -.. note:: - - This method for building MIGraphX requires using ``sudo``. - -1. Install `rocm-cmake`, `pip3`, `rocblas`, and `miopen-hip`: - - .. code:: shell - - sudo apt install -y rocm-cmake python3-pip rocblas miopen-hip - -2. Install `rbuild `_: - - .. code:: shell - - pip3 install --prefix /usr/local https://github.com/RadeonOpenCompute/rbuild/archive/master.tar.gz - - -3. Build MIGraphX source code: - - .. code:: shell - - sudo rbuild build -d depend -B build -DGPU_TARGETS=$(/opt/rocm/bin/rocminfo | grep -o -m1 'gfx.*') diff --git a/docs/install/installing_with_package.rst b/docs/install/installing_with_package.rst index 0e9b67ff2a9..b2aa21ca741 100644 --- a/docs/install/installing_with_package.rst +++ b/docs/install/installing_with_package.rst @@ -10,7 +10,7 @@ ROCm must be installed before installing MIGraphX. See `ROCm installation for Li Installing MIGraphX using the package installer is sufficient for users who want to use the MIGraphX API. -If you want to develop for MIGraphX and contribute to the source code, see `Building MIGraphX `_ and `Developing for MIGraphX `_ +If you want to develop for MIGraphX and contribute to the source code, see :doc:`Building MIGraphX ` and :doc:`Developing for MIGraphX <../dev/contributing-to-migraphx>`. The package installer will install all the prerequisites needed for MIGraphX. diff --git a/docs/reference/cpp.rst b/docs/reference/cpp.rst index 57baef8bda1..1982328f7f6 100755 --- a/docs/reference/cpp.rst +++ b/docs/reference/cpp.rst @@ -1,3 +1,7 @@ +.. meta:: + :description: MIGraphX C++ API reference + :keywords: MIGraphX, ROCm, C++, API, reference, development, developer + .. _cpp-api-reference: C++ API reference diff --git a/docs/reference/driver-options.rst b/docs/reference/driver-options.rst index 55012aa0fb1..e58a3752135 100644 --- a/docs/reference/driver-options.rst +++ b/docs/reference/driver-options.rst @@ -1,6 +1,6 @@ .. 
meta:: - :description: MIGraphX provides an optimized execution engine for deep learning neural networks - :keywords: MIGraphX, ROCm, library, API, tool + :description: MIGraphX driver options + :keywords: MIGraphX, ROCm, driver, options .. _driver-options: diff --git a/docs/reference/py.rst b/docs/reference/py.rst index c68a2df0e54..17077dc120a 100755 --- a/docs/reference/py.rst +++ b/docs/reference/py.rst @@ -1,3 +1,7 @@ +.. meta:: + :description: MIGraphX Python API reference + :keywords: MIGraphX, ROCm, Python, API, reference, development, developer + .. py:module:: migraphx .. _python-api-reference: diff --git a/docs/sphinx/_toc.yml.in b/docs/sphinx/_toc.yml.in index 88b1383424f..528ef7cc371 100644 --- a/docs/sphinx/_toc.yml.in +++ b/docs/sphinx/_toc.yml.in @@ -6,6 +6,7 @@ subtrees: - caption: Installation entries: - file: install/installing_with_package + - file: install/building_migraphx - caption: Reference entries: @@ -15,10 +16,10 @@ subtrees: subtrees: - entries: - file: reference/driver-options + - file: dev/onnx_operators - caption: Developing for MIGraphX entries: - - file: install/building_migraphx - file: dev/contributing-to-migraphx subtrees: - entries: diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index 3c9c8d3b29b..0c6188f2426 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==1.8.3 +rocm-docs-core==1.12.0 sphinx-collapse diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 536baca0d0a..9b183b5f6e6 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -116,7 +116,7 @@ requests==2.32.3 # via # pygithub # sphinx -rocm-docs-core==1.8.3 +rocm-docs-core==1.12.0 # via -r requirements.in smmap==5.0.1 # via gitdb diff --git a/examples/diffusion/python_stable_diffusion_3/README.md b/examples/diffusion/python_stable_diffusion_3/README.md new file mode 100644 index 00000000000..62219692bc8 --- /dev/null +++ b/examples/diffusion/python_stable_diffusion_3/README.md @@ -0,0 +1,67 @@ +# Stable Diffusion 3 + +This version was tested with the [rocm 6.2](https://github.com/ROCmSoftwarePlatform/AMDMIGraphX/tree/rocm-6.2.0) revision. + +## Console application + +To run the console application, follow the steps below. + +Set up the Python environment + +```bash +# this requires the python venv module to be installed (e.g. apt install python3.8-venv) +python3 -m venv sd_venv +. sd_venv/bin/activate +``` + +Install dependencies + +```bash +pip install -r torch_requirements.txt +pip install -r requirements.txt +``` + +Use the MIGraphX Python module + +```bash +export PYTHONPATH=/opt/rocm/lib:$PYTHONPATH +``` + +Get models: + +Make sure you have permission to download and use stabilityai/stable-diffusion-3. +```bash +huggingface-cli login +``` + +Export the models to ONNX. Currently, optimum does not include the required changes in its latest release. Please install it from the development branch instead.
+```bash +python -m pip install optimum[onnxruntime]@git+https://github.com/huggingface/optimum.git +``` + +Once optimum is built, use the following command to export the models: +```bash +optimum-cli export onnx --model stabilityai/stable-diffusion-3-medium-diffusers models/sd3 +``` + +Run the text-to-image script with the following example prompt and seed (optionally, you can change the batch size / number of images generated for that prompt): + +```bash +python txt2img.py --prompt "a photograph of an astronaut riding a horse" --steps 50 --output astro_horse.jpg +``` +> [!NOTE] +> The first run will compile the models and cache them to make subsequent runs faster. New batch sizes will result in the models re-compiling. + +The result should look like this: + +![example_output.jpg](./example_output.jpg) + +## Lower Memory Usage Pipeline +The entire pipeline is memory intensive, even when quantizing to fp16. The T5XXL encoder can be disabled alongside fp16 quantization to reduce total GPU memory usage to under 16 GB. + +There will be a slight accuracy penalty when disabling T5XXL. +```bash +python txt2img.py --prompt "a photograph of an astronaut riding a horse" --steps 50 --skip-t5 --fp16=all --output astro_horse.jpg +``` + diff --git a/examples/diffusion/python_stable_diffusion_3/example_output.jpg b/examples/diffusion/python_stable_diffusion_3/example_output.jpg new file mode 100644 index 00000000000..ec1eb12f1b1 Binary files /dev/null and b/examples/diffusion/python_stable_diffusion_3/example_output.jpg differ diff --git a/examples/diffusion/python_stable_diffusion_3/other_impls.py b/examples/diffusion/python_stable_diffusion_3/other_impls.py new file mode 100644 index 00000000000..5322c4b4231 --- /dev/null +++ b/examples/diffusion/python_stable_diffusion_3/other_impls.py @@ -0,0 +1,516 @@ +# MIT License + +# Copyright (c) 2024 Stability AI + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE.
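The `attention()` helper defined at the top of `other_impls.py` below wraps `torch.nn.functional.scaled_dot_product_attention` with a head-split/merge reshape. A self-contained sketch of that same pattern (the shapes are illustrative only):

```python
import torch

# Illustrative sizes: batch=2, tokens=77, heads=8, per-head dim=64.
heads, dim_head = 8, 64
q = k = v = torch.randn(2, 77, heads * dim_head)
b = q.shape[0]

# Split channels into heads, run fused SDPA, then merge the heads back,
# mirroring the reshapes done by attention() in other_impls.py.
qh, kh, vh = (t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v))
out = torch.nn.functional.scaled_dot_product_attention(qh, kh, vh, dropout_p=0.0)
out = out.transpose(1, 2).reshape(b, -1, heads * dim_head)
print(out.shape)  # torch.Size([2, 77, 512])
```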
+ +# Some code in `other_impls` originates from HuggingFace and is subject to [the HuggingFace Transformers Apache2 License](https://github.com/huggingface/transformers/blob/main/LICENSE) +### This file contains impls for underlying related models (CLIP, T5, etc) + +import torch, math +from torch import nn +from transformers import CLIPTokenizer, T5TokenizerFast + + +################################################################################################# +### Core/Utility +################################################################################################# + + +def attention(q, k, v, heads, mask=None): + """Convenience wrapper around a basic attention operation""" + b, _, dim_head = q.shape + dim_head //= heads + q, k, v = map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v)) + out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + return out.transpose(1, 2).reshape(b, -1, heads * dim_head) + + +class Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks""" + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, dtype=None, device=None): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, dtype=dtype, device=device) + self.act = act_layer + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, dtype=dtype, device=device) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + return x + + +################################################################################################# +### CLIP +################################################################################################# + + +class CLIPAttention(torch.nn.Module): + def __init__(self, embed_dim, heads, dtype, device): + super().__init__() + self.heads = heads + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + + def forward(self, x, mask=None): + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + out = attention(q, k, v, self.heads, mask) + return self.out_proj(out) + + +ACTIVATIONS = { + "quick_gelu": lambda a: a * torch.sigmoid(1.702 * a), + "gelu": torch.nn.functional.gelu, +} + +class CLIPLayer(torch.nn.Module): + def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device): + super().__init__() + self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) + self.self_attn = CLIPAttention(embed_dim, heads, dtype, device) + self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) + #self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device) + self.mlp = Mlp(embed_dim, intermediate_size, embed_dim, act_layer=ACTIVATIONS[intermediate_activation], dtype=dtype, device=device) + + def forward(self, x, mask=None): + x += self.self_attn(self.layer_norm1(x), mask) + x += self.mlp(self.layer_norm2(x)) + return x + + +class CLIPEncoder(torch.nn.Module): + def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device): + super().__init__() + 
self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device) for i in range(num_layers)]) + + def forward(self, x, mask=None, intermediate_output=None): + if intermediate_output is not None: + if intermediate_output < 0: + intermediate_output = len(self.layers) + intermediate_output + intermediate = None + for i, l in enumerate(self.layers): + x = l(x, mask) + if i == intermediate_output: + intermediate = x.clone() + return x, intermediate + + +class CLIPEmbeddings(torch.nn.Module): + def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None): + super().__init__() + self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device) + self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device) + + def forward(self, input_tokens): + return self.token_embedding(input_tokens) + self.position_embedding.weight + + +class CLIPTextModel_(torch.nn.Module): + def __init__(self, config_dict, dtype, device): + num_layers = config_dict["num_hidden_layers"] + embed_dim = config_dict["hidden_size"] + heads = config_dict["num_attention_heads"] + intermediate_size = config_dict["intermediate_size"] + intermediate_activation = config_dict["hidden_act"] + super().__init__() + self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device) + self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device) + self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device) + + def forward(self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True): + x = self.embeddings(input_tokens) + causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1) + x, i = self.encoder(x, mask=causal_mask, intermediate_output=intermediate_output) + x = self.final_layer_norm(x) + if i is not None and final_layer_norm_intermediate: + i = self.final_layer_norm(i) + pooled_output = x[torch.arange(x.shape[0], device=x.device), input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),] + return x, i, pooled_output + + +class CLIPTextModel(torch.nn.Module): + def __init__(self, config_dict, dtype, device): + super().__init__() + self.num_layers = config_dict["num_hidden_layers"] + self.text_model = CLIPTextModel_(config_dict, dtype, device) + embed_dim = config_dict["hidden_size"] + self.text_projection = nn.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device) + self.text_projection.weight.copy_(torch.eye(embed_dim)) + self.dtype = dtype + + def get_input_embeddings(self): + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, embeddings): + self.text_model.embeddings.token_embedding = embeddings + + def forward(self, *args, **kwargs): + x = self.text_model(*args, **kwargs) + out = self.text_projection(x[2]) + return (x[0], x[1], out, x[2]) + + +class SDTokenizer: + def __init__(self, max_length=77, pad_with_end=True, tokenizer=None, has_start_token=True, pad_to_max_length=True, min_length=None): + self.tokenizer = tokenizer + self.max_length = max_length + self.min_length = min_length + empty = self.tokenizer('')["input_ids"] + if has_start_token: + self.tokens_start = 1 + self.start_token = empty[0] + self.end_token = empty[1] + else: + self.tokens_start = 0 + self.start_token = None + self.end_token = empty[0] + self.pad_with_end = pad_with_end + self.pad_to_max_length = 
pad_to_max_length + vocab = self.tokenizer.get_vocab() + self.inv_vocab = {v: k for k, v in vocab.items()} + self.max_word_length = 8 + + + def tokenize_with_weights(self, text:str): + """Tokenize the text, with weight values - presume 1.0 for all and ignore other features here. The details aren't relevant for a reference impl, and weights themselves have a weak effect on SD3.""" + if self.pad_with_end: + pad_token = self.end_token + else: + pad_token = 0 + batch = [] + if self.start_token is not None: + batch.append((self.start_token, 1.0)) + to_tokenize = text.replace("\n", " ").split(' ') + to_tokenize = [x for x in to_tokenize if x != ""] + for word in to_tokenize: + batch.extend([(t, 1) for t in self.tokenizer(word)["input_ids"][self.tokens_start:-1]]) + batch.append((self.end_token, 1.0)) + if self.pad_to_max_length: + batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch))) + if self.min_length is not None and len(batch) < self.min_length: + batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch))) + return [batch] + + +class SDXLClipGTokenizer(SDTokenizer): + def __init__(self, tokenizer): + super().__init__(pad_with_end=False, tokenizer=tokenizer) + + +class SD3Tokenizer: + def __init__(self): + clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + self.clip_l = SDTokenizer(tokenizer=clip_tokenizer) + self.clip_g = SDXLClipGTokenizer(clip_tokenizer) + self.t5xxl = T5XXLTokenizer() + + def tokenize_with_weights(self, text:str): + out = {} + out["g"] = self.clip_g.tokenize_with_weights(text) + out["l"] = self.clip_l.tokenize_with_weights(text) + out["t5xxl"] = self.t5xxl.tokenize_with_weights(text) + return out + + +class ClipTokenWeightEncoder: + def encode_token_weights(self, token_weight_pairs): + tokens = list(map(lambda a: a[0], token_weight_pairs[0])) + out, pooled = self([tokens]) + if pooled is not None: + first_pooled = pooled[0:1].cpu() + else: + first_pooled = pooled + output = [out[0:1]] + return torch.cat(output, dim=-2).cpu(), first_pooled + + +class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): + """Uses the CLIP transformer encoder for text (from huggingface)""" + LAYERS = ["last", "pooled", "hidden"] + def __init__(self, device="cpu", max_length=77, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=CLIPTextModel, + special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, return_projected_pooled=True): + super().__init__() + assert layer in self.LAYERS + self.transformer = model_class(textmodel_json_config, dtype, device) + self.num_layers = self.transformer.num_layers + self.max_length = max_length + self.transformer = self.transformer.eval() + for param in self.parameters(): + param.requires_grad = False + self.layer = layer + self.layer_idx = None + self.special_tokens = special_tokens + self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) + self.layer_norm_hidden_state = layer_norm_hidden_state + self.return_projected_pooled = return_projected_pooled + if layer == "hidden": + assert layer_idx is not None + assert abs(layer_idx) < self.num_layers + self.set_clip_options({"layer": layer_idx}) + self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled) + + def set_clip_options(self, options): + layer_idx = options.get("layer", self.layer_idx) + self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled) + if layer_idx is None or abs(layer_idx) > self.num_layers: + self.layer = "last" +
else: + self.layer = "hidden" + self.layer_idx = layer_idx + + def forward(self, tokens): + backup_embeds = self.transformer.get_input_embeddings() + device = backup_embeds.weight.device + tokens = torch.LongTensor(tokens).to(device) + outputs = self.transformer(tokens, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state) + self.transformer.set_input_embeddings(backup_embeds) + if self.layer == "last": + z = outputs[0] + else: + z = outputs[1] + pooled_output = None + if len(outputs) >= 3: + if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None: + pooled_output = outputs[3].float() + elif outputs[2] is not None: + pooled_output = outputs[2].float() + return z.float(), pooled_output + + +class SDXLClipG(SDClipModel): + """Wraps the CLIP-G model into the SD-CLIP-Model interface""" + def __init__(self, config, device="cpu", layer="penultimate", layer_idx=None, dtype=None): + if layer == "penultimate": + layer="hidden" + layer_idx=-2 + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False) + + +class T5XXLModel(SDClipModel): + """Wraps the T5-XXL model into the SD-CLIP-Model interface for convenience""" + def __init__(self, config, device="cpu", layer="last", layer_idx=None, dtype=None): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5) + + +################################################################################################# +### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl +################################################################################################# + + +class T5XXLTokenizer(SDTokenizer): + """Wraps the T5 Tokenizer from HF into the SDTokenizer interface""" + def __init__(self): + super().__init__(pad_with_end=False, tokenizer=T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl"), has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77) + + +class T5LayerNorm(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device)) + self.variance_epsilon = eps + + def forward(self, x): + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + self.variance_epsilon) + return self.weight.to(device=x.device, dtype=x.dtype) * x + + +class T5DenseGatedActDense(torch.nn.Module): + def __init__(self, model_dim, ff_dim, dtype, device): + super().__init__() + self.wi_0 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device) + self.wi_1 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device) + self.wo = nn.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device) + + def forward(self, x): + hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh") + hidden_linear = self.wi_1(x) + x = hidden_gelu * hidden_linear + x = self.wo(x) + return x + + +class T5LayerFF(torch.nn.Module): + def __init__(self, model_dim, ff_dim, dtype, device): + super().__init__() + self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device) + self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) + + def forward(self, x): + forwarded_states = self.layer_norm(x) + 
forwarded_states = self.DenseReluDense(forwarded_states) + x += forwarded_states + return x + + +class T5Attention(torch.nn.Module): + def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device): + super().__init__() + # Mesh TensorFlow initialization to avoid scaling before softmax + self.q = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.k = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.v = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.o = nn.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device) + self.num_heads = num_heads + self.relative_attention_bias = None + if relative_attention_bias: + self.relative_attention_num_buckets = 32 + self.relative_attention_max_distance = 128 + self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device) + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 
+ This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min(relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)) + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length, device): + """Compute binned relative position bias""" + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=True, + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + return values + + def forward(self, x, past_bias=None): + q = self.q(x) + k = self.k(x) + v = self.v(x) + if self.relative_attention_bias is not None: + past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device) + if past_bias is not None: + mask = past_bias + out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask) + return self.o(out), past_bias + + +class T5LayerSelfAttention(torch.nn.Module): + def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device): + super().__init__() + self.SelfAttention = T5Attention(model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device) + self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) + + def forward(self, x, past_bias=None): + output, past_bias = self.SelfAttention(self.layer_norm(x), past_bias=past_bias) + x += output + return x, past_bias + + +class T5Block(torch.nn.Module): + def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device): + super().__init__() + self.layer = torch.nn.ModuleList() + self.layer.append(T5LayerSelfAttention(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device)) + self.layer.append(T5LayerFF(model_dim, ff_dim, dtype, device)) + + def forward(self, 
x, past_bias=None): + x, past_bias = self.layer[0](x, past_bias) + x = self.layer[-1](x) + return x, past_bias + + +class T5Stack(torch.nn.Module): + def __init__(self, num_layers, model_dim, inner_dim, ff_dim, num_heads, vocab_size, dtype, device): + super().__init__() + self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device) + self.block = torch.nn.ModuleList([T5Block(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device) for i in range(num_layers)]) + self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) + + def forward(self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True): + intermediate = None + x = self.embed_tokens(input_ids) + past_bias = None + for i, l in enumerate(self.block): + x, past_bias = l(x, past_bias) + if i == intermediate_output: + intermediate = x.clone() + x = self.final_layer_norm(x) + if intermediate is not None and final_layer_norm_intermediate: + intermediate = self.final_layer_norm(intermediate) + return x, intermediate + + +class T5(torch.nn.Module): + def __init__(self, config_dict, dtype, device): + super().__init__() + self.num_layers = config_dict["num_layers"] + self.encoder = T5Stack(self.num_layers, config_dict["d_model"], config_dict["d_model"], config_dict["d_ff"], config_dict["num_heads"], config_dict["vocab_size"], dtype, device) + self.dtype = dtype + + def get_input_embeddings(self): + return self.encoder.embed_tokens + + def set_input_embeddings(self, embeddings): + self.encoder.embed_tokens = embeddings + + def forward(self, *args, **kwargs): + return self.encoder(*args, **kwargs) + diff --git a/examples/diffusion/python_stable_diffusion_3/requirements.txt b/examples/diffusion/python_stable_diffusion_3/requirements.txt new file mode 100644 index 00000000000..678bd7d749e --- /dev/null +++ b/examples/diffusion/python_stable_diffusion_3/requirements.txt @@ -0,0 +1,33 @@ +##################################################################################### +# The MIT License (MIT) +# +# Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+##################################################################################### + +diffusers==0.30.3 +einops==0.8.0 +onnx==1.17.0 +protobuf==5.28.3 +transformers==4.46.0 +tiktoken==0.8.0 +sentencepiece==0.2.0 +--extra-index-url https://test.pypi.org/simple +hip-python-as-cuda \ No newline at end of file diff --git a/examples/diffusion/python_stable_diffusion_3/torch_requirements.txt b/examples/diffusion/python_stable_diffusion_3/torch_requirements.txt new file mode 100644 index 00000000000..bbd3939d39d --- /dev/null +++ b/examples/diffusion/python_stable_diffusion_3/torch_requirements.txt @@ -0,0 +1,25 @@ +##################################################################################### +# The MIT License (MIT) +# +# Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +##################################################################################### +--index-url https://download.pytorch.org/whl/rocm6.2/ +torch diff --git a/examples/diffusion/python_stable_diffusion_3/txt2img.py b/examples/diffusion/python_stable_diffusion_3/txt2img.py new file mode 100644 index 00000000000..995f68f5d22 --- /dev/null +++ b/examples/diffusion/python_stable_diffusion_3/txt2img.py @@ -0,0 +1,618 @@ +# The MIT License (MIT) +# +# Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024 Stability AI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the 'Software'), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +from argparse import ArgumentParser +from diffusers import FlowMatchEulerDiscreteScheduler + +from other_impls import SD3Tokenizer + +from PIL import Image + +import migraphx as mgx +import os +import sys +import torch +import time +from functools import wraps + +from hip import hip +from collections import namedtuple +HipEventPair = namedtuple('HipEventPair', ['start', 'end']) + + +# measurement helper +def measure(fn): + @wraps(fn) + def measure_ms(*args, **kwargs): + start_time = time.perf_counter_ns() + result = fn(*args, **kwargs) + end_time = time.perf_counter_ns() + print( + f"Elapsed time for {fn.__name__}: {(end_time - start_time) * 1e-6:.4f} ms\n" + ) + return result + + return measure_ms + + +def get_args(): + parser = ArgumentParser() + # Model compile + parser.add_argument( + "--onnx-model-path", + type=str, + default="models/sd3", + help="Path to onnx model files.", + ) + + parser.add_argument( + "--compiled-model-path", + type=str, + default=None, + help= + "Path to compiled mxr model files. If not set, it will be saved next to the onnx model.", + ) + + parser.add_argument( + "--fp16", + choices=["all", "vae", "clip", "mmdit"], + nargs="+", + help="Quantize models with fp16 precision.", + ) + + parser.add_argument( + "--force-compile", + action="store_true", + default=False, + help="Ignore existing .mxr files and override them", + ) + + parser.add_argument( + "--exhaustive-tune", + action="store_true", + default=False, + help="Perform exhaustive tuning when compiling onnx models", + ) + + parser.add_argument( + "--skip-t5", + action="store_true", + default=False, + help= + "Skip the third text encoder. Small accuracy penalty but large memory savings." + ) + + # Runtime + parser.add_argument( + "-s", + "--seed", + type=int, + default=42, + help="Random seed", + ) + + parser.add_argument( + "-t", + "--steps", + type=int, + default=50, + help="Number of steps", + ) + + parser.add_argument("-b", + "--batch", + type=int, + default=1, + help="Batch count or number of images to produce") + + parser.add_argument( + "-p", + "--prompt", + type=str, + required=True, + help="Prompt", + ) + + parser.add_argument( + "-n", + "--negative-prompt", + type=str, + default="", + help="Negative prompt", + ) + + parser.add_argument( + "--scale", + type=float, + default=5.0, + help="Guidance scale", + ) + + parser.add_argument( + "-o", + "--output", + type=str, + default=None, + help="Output name", + ) + return parser.parse_args() + + +mgx_to_torch_dtype_dict = { + "bool_type": torch.bool, + "uint8_type": torch.uint8, + "int8_type": torch.int8, + "int16_type": torch.int16, + "int32_type": torch.int32, + "int64_type": torch.int64, + "float_type": torch.float32, + "double_type": torch.float64, + "half_type": torch.float16, +} + +torch_to_mgx_dtype_dict = { + value: key + for (key, value) in mgx_to_torch_dtype_dict.items() +} + + +def tensor_to_arg(tensor): + return mgx.argument_from_pointer( + mgx.shape( + **{ + "type": torch_to_mgx_dtype_dict[tensor.dtype], + "lens": list(tensor.size()), + "strides": list(tensor.stride()) + }), tensor.data_ptr()) + + +def tensors_to_args(tensors): + return {name: tensor_to_arg(tensor) for name, tensor in tensors.items()} + + +def get_output_name(idx): + return f"main:#output_{idx}" + + +def copy_tensor_sync(tensor, data): + tensor.copy_(data) + torch.cuda.synchronize() + + +def run_model_sync(model, args): + model.run(args) + mgx.gpu_sync() + + +def allocate_torch_tensors(model): + input_shapes = model.get_parameter_shapes() + data_mapping = { + name: 
torch.zeros(shape.lens()).to( + mgx_to_torch_dtype_dict[shape.type_string()]).to(device="cuda") + for name, shape in input_shapes.items() + } + return data_mapping + + +class StableDiffusionMGX(): + def __init__(self, + onnx_model_path, + compiled_model_path, + fp16, + batch, + force_compile=False, + exhaustive_tune=False, + skip_t5=False): + + self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + "stabilityai/stable-diffusion-3-medium-diffusers", + subfolder="scheduler") + + self.tokenizer = SD3Tokenizer() + self.device = "cuda" + self.skip_t5 = skip_t5 + + if fp16 is None: + fp16 = [] + elif "all" in fp16: + fp16 = ["vae", "clip", "mmdit"] + + self.batch = batch + + print("Load models...") + self.models = { + "vae": + StableDiffusionMGX.load_mgx_model( + "vae_decoder", {"latent_sample": [self.batch, 16, 128, 128]}, + onnx_model_path, + compiled_model_path=compiled_model_path, + use_fp16="vae" in fp16, + force_compile=force_compile, + exhaustive_tune=exhaustive_tune, + offload_copy=False, + batch=self.batch), + "clip-l": + StableDiffusionMGX.load_mgx_model( + "text_encoder", {"input_ids": [1, 77]}, + onnx_model_path, + compiled_model_path=compiled_model_path, + use_fp16="clip" in fp16, + force_compile=force_compile, + exhaustive_tune=exhaustive_tune, + offload_copy=False), + "clip-g": + StableDiffusionMGX.load_mgx_model( + "text_encoder_2", {"input_ids": [1, 77]}, + onnx_model_path, + compiled_model_path=compiled_model_path, + use_fp16="clip" in fp16, + force_compile=force_compile, + exhaustive_tune=exhaustive_tune, + offload_copy=False), + "mmdit": + StableDiffusionMGX.load_mgx_model( + "transformer", { + "hidden_states": [2 * self.batch, 16, 128, 128], + "timestep": [2 * self.batch], + "encoder_hidden_states": [2 * self.batch, 154, 4096], + "pooled_projections": [2 * self.batch, 2048], + }, + onnx_model_path, + compiled_model_path=compiled_model_path, + use_fp16="mmdit" in fp16, + force_compile=force_compile, + exhaustive_tune=exhaustive_tune, + offload_copy=False, + batch=self.batch) + } + + self.tensors = { + "clip-g": allocate_torch_tensors(self.models["clip-g"]), + "clip-l": allocate_torch_tensors(self.models["clip-l"]), + # "t5xxl": allocate_torch_tensors(self.models["t5xxl"]), + "mmdit": allocate_torch_tensors(self.models["mmdit"]), + "vae": allocate_torch_tensors(self.models["vae"]), + } + + self.model_args = { + "clip-g": tensors_to_args(self.tensors['clip-g']), + "clip-l": tensors_to_args(self.tensors['clip-l']), + # "t5xxl": tensors_to_args(self.tensors['t5xxl']), + "mmdit": tensors_to_args(self.tensors['mmdit']), + "vae": tensors_to_args(self.tensors['vae']), + } + + if not self.skip_t5: + self.models["t5xxl"] = StableDiffusionMGX.load_mgx_model( + "text_encoder_3", {"input_ids": [1, 77]}, + onnx_model_path, + compiled_model_path=compiled_model_path, + use_fp16="clip" in fp16, + force_compile=force_compile, + exhaustive_tune=exhaustive_tune, + offload_copy=False) + self.tensors["t5xxl"] = allocate_torch_tensors( + self.models["t5xxl"]) + self.model_args["t5xxl"] = tensors_to_args(self.tensors['t5xxl']) + + self.events = { + "warmup": + HipEventPair(start=hip.hipEventCreate()[1], + end=hip.hipEventCreate()[1]), + "run": + HipEventPair(start=hip.hipEventCreate()[1], + end=hip.hipEventCreate()[1]), + "clip": + HipEventPair(start=hip.hipEventCreate()[1], + end=hip.hipEventCreate()[1]), + "denoise": + HipEventPair(start=hip.hipEventCreate()[1], + end=hip.hipEventCreate()[1]), + "decode": + HipEventPair(start=hip.hipEventCreate()[1], + end=hip.hipEventCreate()[1]), + 
} + + self.stream = hip.hipStreamCreate()[1] + + def cleanup(self): + for event in self.events.values(): + hip.hipEventDestroy(event.start) + hip.hipEventDestroy(event.end) + hip.hipStreamDestroy(self.stream) + + def profile_start(self, name): + if name in self.events: + hip.hipEventRecord(self.events[name].start, None) + + def profile_end(self, name): + if name in self.events: + hip.hipEventRecord(self.events[name].end, None) + + @measure + @torch.no_grad() + def run(self, prompt, negative_prompt, steps, seed, scale): + torch.cuda.synchronize() + self.profile_start("run") + + print("Tokenizing prompts...") + prompt_tokens = self.tokenize(prompt) + neg_prompt_tokens = self.tokenize(negative_prompt) + + print("Creating text embeddings...") + self.profile_start("clip") + prompt_embeddings = self.get_embeddings(prompt_tokens) + neg_prompt_embeddings = self.get_embeddings(neg_prompt_tokens) + self.profile_end("clip") + + # fix height and width for now + # TODO: check for valid height/width combinations + # and make them member variables + height = 1024 + width = 1024 + latent = torch.empty(1, 16, height // 8, width // 8, device="cpu") + + generator = torch.manual_seed(seed) + latent = torch.randn(latent.size(), + dtype=torch.float32, + layout=latent.layout, + generator=generator).to(latent.dtype) + + self.scheduler.set_timesteps(steps) + timesteps = self.scheduler.timesteps + + print("Running denoising loop...") + self.profile_start("denoise") + for step in timesteps: + latent = self.denoise(latent, prompt_embeddings, + neg_prompt_embeddings, step, scale) + + self.profile_end("denoise") + + latent = (latent / 1.5305) + 0.0609 + + self.profile_start("decode") + print("Decode denoised result...") + image = self.decode(latent) + self.profile_end("decode") + + torch.cuda.synchronize() + self.profile_end("run") + return image + + def print_summary(self, denoise_steps): + print('WARMUP\t{:>9.2f} ms'.format( + hip.hipEventElapsedTime(self.events['warmup'].start, + self.events['warmup'].end)[1])) + print('CLIP\t{:>9.2f} ms'.format( + hip.hipEventElapsedTime(self.events['clip'].start, + self.events['clip'].end)[1])) + print('mmditx{}\t{:>9.2f} ms'.format( + str(denoise_steps), + hip.hipEventElapsedTime(self.events['denoise'].start, + self.events['denoise'].end)[1])) + print('VAE-Dec\t{:>9.2f} ms'.format( + hip.hipEventElapsedTime(self.events['decode'].start, + self.events['decode'].end)[1])) + print('RUN\t{:>9.2f} ms'.format( + hip.hipEventElapsedTime(self.events['run'].start, + self.events['run'].end)[1])) + + @staticmethod + @measure + def load_mgx_model(name, + shapes, + onnx_model_path, + compiled_model_path=None, + use_fp16=False, + force_compile=False, + exhaustive_tune=False, + offload_copy=True, + batch=1): + print(f"Loading {name} model...") + if compiled_model_path is None: + compiled_model_path = onnx_model_path + onnx_file = f"{onnx_model_path}/{name}/model.onnx" + mxr_file = f"{compiled_model_path}/{name}/model_{'fp16' if use_fp16 else 'fp32'}_b{batch}_{'gpu' if not offload_copy else 'oc'}.mxr" + if not force_compile and os.path.isfile(mxr_file): + print(f"Found mxr, loading it from {mxr_file}") + model = mgx.load(mxr_file, format="msgpack") + elif os.path.isfile(onnx_file): + print(f"No mxr found at {mxr_file}") + print(f"Parsing from {onnx_file}") + model = mgx.parse_onnx(onnx_file, map_input_dims=shapes) + if use_fp16: + mgx.quantize_fp16(model) + model.compile(mgx.get_target("gpu"), + exhaustive_tune=exhaustive_tune, + offload_copy=offload_copy) + print(f"Saving {name} model to 
{mxr_file}") + os.makedirs(os.path.dirname(mxr_file), exist_ok=True) + mgx.save(model, mxr_file, format="msgpack") + else: + print( + f"No {name} model found at {onnx_file} or {mxr_file}. Please download it and re-try." + ) + sys.exit(1) + return model + + @measure + def tokenize(self, prompt): + return self.tokenizer.tokenize_with_weights(prompt) + + def encode_token_weights(self, model_name, token_weight_pairs): + tokens = list(map(lambda a: a[0], token_weight_pairs[0])) + tokens = torch.tensor([tokens], dtype=torch.int64, device=self.device) + copy_tensor_sync(self.tensors[model_name]["input_ids"], + tokens.to(torch.int32)) + run_model_sync(self.models[model_name], self.model_args[model_name]) + encoder_out = self.tensors[model_name][get_output_name(0)] + encoder_out2 = None + if model_name != 't5xxl': + # flipped outputs for clip text encoders... + encoder_out2 = encoder_out + encoder_out = self.tensors[model_name][get_output_name(1)] + + if encoder_out2 is not None: + first_pooled = encoder_out2[0:1] + else: + first_pooled = encoder_out2 + output = [encoder_out[0:1]] + + return torch.cat(output, dim=-2), first_pooled + + @measure + def get_embeddings(self, prompt_tokens): + l_out, l_pooled = self.encode_token_weights("clip-l", + prompt_tokens["l"]) + # stable-diffusion-3-lite-onnx has swapped outputs for clip-l text encoder + if l_out.shape != (1, 77, 768): + l_out, l_pooled = l_pooled, l_out + + g_out, g_pooled = self.encode_token_weights("clip-g", + prompt_tokens["g"]) + if not self.skip_t5: + t5_out, _ = self.encode_token_weights("t5xxl", + prompt_tokens["t5xxl"]) + else: + t5_out = torch.zeros((1, 77, 4096)).cuda() + lg_out = torch.cat([l_out, g_out], dim=-1) + lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1])) + + return torch.cat([lg_out, t5_out], dim=-2), torch.cat( + (l_pooled, g_pooled), dim=-1) + + @staticmethod + def convert_to_rgb_image(image): + image = (image / 2 + 0.5).clamp(0, 1) + image = image.detach().cpu().permute(0, 2, 3, 1).numpy() + images = (image * 255).round().astype("uint8") + return [Image.fromarray(images[i]) for i in range(images.shape[0])] + + @staticmethod + def save_image(pil_image, filename="output.png"): + pil_image.save(filename) + + def CFGDenoiser(self, x, timestep, cond, uncond, cond_scale): + # Run cond and uncond in a batch together + x_concat = torch.cat([x, x]) + timestep_concat = timestep.expand([2]) + c_crossattn = torch.cat([cond["c_crossattn"], uncond["c_crossattn"]]) + y = torch.cat([cond["y"], uncond["y"]]) + + copy_tensor_sync(self.tensors["mmdit"]["hidden_states"], x_concat) + copy_tensor_sync(self.tensors["mmdit"]["timestep"], timestep_concat) + copy_tensor_sync(self.tensors["mmdit"]["encoder_hidden_states"], + c_crossattn) + copy_tensor_sync(self.tensors["mmdit"]["pooled_projections"], y) + + run_model_sync(self.models["mmdit"], self.model_args['mmdit']) + + mmdit_out = self.tensors["mmdit"][get_output_name(0)] + + # Then split and apply CFG Scaling + pos_out, neg_out = torch.tensor_split(mmdit_out, 2) + + scaled = neg_out + (pos_out - neg_out) * cond_scale + + # scheduler step function requies all tensors be on the CPU + scaled = scaled.detach().clone().cpu() + scheduler_out = self.scheduler.step(model_output=scaled, + timestep=timestep, + sample=x, + return_dict=False)[0] + return scheduler_out + + def fix_cond(self, cond): + cond, pooled = (cond[0].cuda(), cond[1].cuda()) + return {"c_crossattn": cond, "y": pooled} + + def denoise(self, latent, conditioning, neg_cond, step, cfg_scale): + conditioning = 
self.fix_cond(conditioning) + neg_cond = self.fix_cond(neg_cond) + return self.CFGDenoiser(latent, step, conditioning, neg_cond, + cfg_scale) + + @measure + def decode(self, latents): + copy_tensor_sync(self.tensors["vae"]["latent_sample"], latents) + run_model_sync(self.models["vae"], self.model_args["vae"]) + return self.tensors["vae"][get_output_name(0)] + + @measure + def warmup(self, num_runs): + self.profile_start("warmup") + copy_tensor_sync(self.tensors["clip-l"]["input_ids"], + torch.ones((1, 77)).to(torch.int32)) + copy_tensor_sync(self.tensors["clip-g"]["input_ids"], + torch.ones((1, 77)).to(torch.int32)) + if not self.skip_t5: + copy_tensor_sync(self.tensors["t5xxl"]["input_ids"], + torch.ones((1, 77)).to(torch.int32)) + copy_tensor_sync( + self.tensors["mmdit"]["hidden_states"], + torch.randn((2 * self.batch, 16, 128, 128)).to(torch.float)) + copy_tensor_sync(self.tensors["mmdit"]["timestep"], + torch.randn((2 * self.batch)).to(torch.float)) + copy_tensor_sync( + self.tensors["mmdit"]["encoder_hidden_states"], + torch.randn((2 * self.batch, 154, 4096)).to(torch.float)) + copy_tensor_sync(self.tensors["mmdit"]["pooled_projections"], + torch.randn((2 * self.batch, 2048)).to(torch.float)) + copy_tensor_sync( + self.tensors["vae"]["latent_sample"], + torch.randn((self.batch, 16, 128, 128)).to(torch.float)) + + for _ in range(num_runs): + run_model_sync(self.models["clip-l"], self.model_args["clip-l"]) + run_model_sync(self.models["clip-g"], self.model_args["clip-g"]) + if not self.skip_t5: + run_model_sync(self.models["t5xxl"], self.model_args["t5xxl"]) + run_model_sync(self.models["mmdit"], self.model_args["mmdit"]) + run_model_sync(self.models["vae"], self.model_args["vae"]) + self.profile_end("warmup") + + +if __name__ == "__main__": + args = get_args() + + sd = StableDiffusionMGX(args.onnx_model_path, args.compiled_model_path, + args.fp16, args.batch, args.force_compile, + args.exhaustive_tune, args.skip_t5) + print("Warmup") + sd.warmup(5) + print("Run") + result = sd.run(args.prompt, args.negative_prompt, args.steps, args.seed, + args.scale) + + print("Summary") + sd.print_summary(args.steps) + print("Cleanup") + sd.cleanup() + + print("Convert result to rgb image...") + images = StableDiffusionMGX.convert_to_rgb_image(result) + for i, image in enumerate(images): + filename = f"{args.batch}_{args.output}" if args.output else f"output_s{args.seed}_t{args.steps}_{i}.png" + StableDiffusionMGX.save_image(image, filename) + print(f"Image saved to {filename}") diff --git a/hip-clang.docker b/hip-clang.docker index 8e3f9a9af28..6a2d57243c1 100755 --- a/hip-clang.docker +++ b/hip-clang.docker @@ -6,7 +6,7 @@ ARG PREFIX=/usr/local RUN dpkg --add-architecture i386 # Add rocm repository -RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/6.2/ focal main > /etc/apt/sources.list.d/rocm.list' +RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/6.3/ jammy main > /etc/apt/sources.list.d/rocm.list' # From docs.amd.com for installing rocm. 
Needed to install properly RUN sh -c "echo 'Package: *\nPin: release o=repo.radeon.com\nPin-priority: 600' > /etc/apt/preferences.d/rocm-pin-600" diff --git a/requirements.txt b/requirements.txt index 48ed2212bde..d3edaa80070 100755 --- a/requirements.txt +++ b/requirements.txt @@ -27,5 +27,5 @@ ROCm/half@rocm-5.6.0 pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On -ROCm/composable_kernel@57cdd70b7cb14e5e3b60cd9a5f96ba8dc343763e -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On -ROCm/rocMLIR@e454b5d06fc2f099f7de3ee43450e7a6b1efe015 -DBUILD_FAT_LIBROCKCOMPILER=On +ROCm/composable_kernel@b7775add2d28251674d81e220cd4a857b90b997a -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On +ROCm/rocMLIR@13065c4b3a216e1b13dfb8f746b8a0d421f124e8 -DBUILD_FAT_LIBROCKCOMPILER=On diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 909c0f6bc26..550bb30bd42 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -37,6 +37,7 @@ add_library(migraphx argument.cpp autocast_fp8.cpp auto_contiguous.cpp + base64.cpp common.cpp common_dims.cpp compile_src.cpp @@ -66,13 +67,14 @@ add_library(migraphx insert_pad.cpp instruction.cpp json.cpp - layout_nhwc.cpp + layout_convolution.cpp lexing.cpp load_save.cpp make_op.cpp memory_coloring.cpp module.cpp msgpack.cpp + netron_output.cpp normalize_attributes.cpp normalize_ops.cpp op_enums.cpp @@ -89,7 +91,6 @@ add_library(migraphx propagate_constant.cpp promote_literals.cpp quantization.cpp - quantize_fp16.cpp quantize_int4.cpp quantize_8bits.cpp reduce_dims.cpp @@ -115,6 +116,7 @@ add_library(migraphx split_single_dyn_dim.cpp target.cpp tmp_dir.cpp + truncate_float.cpp value.cpp verify_args.cpp ) @@ -122,6 +124,7 @@ add_library(migraphx if(WIN32) # Due to compilation crashing, we need to use type-erased matchers on Windows. target_compile_definitions(migraphx PUBLIC MIGRAPHX_USE_TYPE_ERASED_MATCHERS=1) + target_compile_options(migraphx PUBLIC "-mno-ms-bitfields") endif() configure_file(version.h.in include/migraphx/version.h) @@ -144,6 +147,7 @@ register_migraphx_ops( as_shape atanh atan + bit_cast bitwise_and broadcast broadcast_for_dot diff --git a/src/api/include/migraphx/migraphx.h b/src/api/include/migraphx/migraphx.h index 90ba7c3e017..1f1c05bb215 100644 --- a/src/api/include/migraphx/migraphx.h +++ b/src/api/include/migraphx/migraphx.h @@ -47,7 +47,9 @@ m(uint64_type, uint64_t) \ m(fp8e4m3fnuz_type, migraphx::fp8::fp8e4m3fnuz) \ m(fp8e4m3fn_type, migraphx::fp8::fp8e4m3fn) \ - m(fp8e5m2_type, migraphx::fp8::fp8e5m2) + m(fp8e5m2_type, migraphx::fp8::fp8e5m2) \ + m(bf16_type, bf16) \ + m(fp8e5m2fnuz_type, migraphx::fp8::fp8e5m2fnuz) // clang-format on #ifdef __cplusplus diff --git a/src/base64.cpp b/src/base64.cpp new file mode 100644 index 00000000000..0d08b1f6220 --- /dev/null +++ b/src/base64.cpp @@ -0,0 +1,81 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. 
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+#include <migraphx/base64.hpp>
+#include <array>
+#include <cstdint>
+#include <vector>
+
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {
+
+namespace {
+using byte = unsigned char;
+
+std::array<char, 64> constexpr b64_chars{
+    'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P',
+    'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f',
+    'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v',
+    'w', 'x', 'y', 'z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '/'};
+
+/// base64 encoder snippet altered from https://stackoverflow.com/a/37109258
+std::string encode(const std::vector<byte>& buf)
+{
+    std::size_t len = buf.size();
+    std::vector<char> res_vec((len + 2) / 3 * 4, '=');
+    std::size_t j         = 0;
+    std::size_t remaining = len % 3;
+    const size_t last     = len - remaining;
+
+    for(size_t i = 0; i < last; i += 3)
+    {
+        std::size_t n = static_cast<std::size_t>(buf.at(i)) << 16u |
+                        static_cast<std::size_t>(buf.at(i + 1)) << 8u |
+                        static_cast<std::size_t>(buf.at(i + 2));
+        res_vec.at(j++) = b64_chars.at(n >> 18u);
+        res_vec.at(j++) = b64_chars.at(n >> 12u & 0x3Fu);
+        res_vec.at(j++) = b64_chars.at(n >> 6u & 0x3Fu);
+        res_vec.at(j++) = b64_chars.at(n & 0x3Fu);
+    }
+    // Set padding
+    if(remaining != 0)
+    {
+        std::size_t n = --remaining == 0 ? static_cast<std::size_t>(buf.at(last))
+                                         : static_cast<std::size_t>(buf.at(last)) << 8u |
+                                               static_cast<std::size_t>(buf.at(last + 1));
+        res_vec.at(j++) = b64_chars.at(remaining == 0 ? n >> 2u : n >> 10u & 0x3Fu);
+        res_vec.at(j++) = b64_chars.at(remaining == 0 ? n << 4u & 0x3Fu : n >> 4u & 0x03Fu);
+        res_vec.at(j++) = remaining == 0 ? '=' : b64_chars.at(n << 2u & 0x3Fu);
+    }
+    return {res_vec.begin(), res_vec.end()};
+}
+
+} // namespace
+
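+// For example, encode({'A', 'B', 'C'}) yields "QUJD", and encode({'A'}) yields
+// "QQ==" (two trailing '=' padding characters for the short final group).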
+std::string base64_encode(const std::string& str)
+{
+    return encode(std::vector<byte>(str.begin(), str.end()));
+}
+
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
diff --git a/src/common_dims.cpp b/src/common_dims.cpp
index 3c3e8db1c3a..1afe92087fd 100644
--- a/src/common_dims.cpp
+++ b/src/common_dims.cpp
@@ -94,7 +94,7 @@ static bool compute_common_dim(std::vector<std::size_t>& cd_dims,
                                common_dim_state& state1,
                                common_dim_state& state2)
 {
-    assert(state1.get() <= state2.get());
+    assert(state1.get() < state2.get());
     auto d2   = state2.get();
     auto dims = state1.dims_for(d2);
     auto n    = elements(dims);
@@ -131,7 +131,17 @@ common_dims common_dims::compute(const std::vector<std::size_t>& dims1,
     {
         auto d1 = state1.get();
         auto d2 = state2.get();
-        if(d1 <= d2)
+        if(d1 == d2)
+        {
+            state1.add_axes(1, cd.dims.size());
+            state2.add_axes(1, cd.dims.size());
+            state1.rem = 1;
+            state2.rem = 1;
+            cd.dims.push_back(d1);
+            state1.next();
+            state2.next();
+        }
+        else if(d1 < d2)
         {
             if(not compute_common_dim(cd.dims, state1, state2))
                 return {};
diff --git a/src/driver/main.cpp b/src/driver/main.cpp
index 04fa0cfe3bc..01d59804e2d 100644
--- a/src/driver/main.cpp
+++ b/src/driver/main.cpp
@@ -56,6 +56,8 @@
 #include
 #include
+#include
+
+#include
 
 namespace migraphx {
@@ -166,6 +168,10 @@ struct loader
            {"--binary"},
            ap.help("Print out program in binary format."),
            ap.set_value("binary"));
+        ap(output_type,
+           {"--netron"},
+           ap.help("Print out program as Netron readable json."),
+           ap.set_value("netron"));
         ap(output, {"--output", "-o"}, ap.help("Output to file."));
     }
@@ -418,6 +424,8 @@ struct loader
             *os << to_json_string(p.to_value()) << std::endl;
         else if(type == "binary")
             write(*os, save_buffer(p));
+        else if(type == "netron")
+            *os << make_netron_output(p) << std::endl;
     }
 };
@@ -482,6 +490,7 @@ struct compiler
     compiler_target ct;
     compile_options co;
     bool to_fp16 = false;
+    bool to_bf16 = false;
     bool to_fp8  = false;
     bool to_int8 = false;
     bool to_int4 = false;
@@ -506,6 +515,7 @@ struct compiler
        ap.help("Exhaustively search for best tuning parameters for kernels"),
        ap.set_value(true));
     ap(to_fp16, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(true));
+    ap(to_bf16, {"--bf16"}, ap.help("Quantize for bf16"), ap.set_value(true));
     ap(to_int8, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(true));
     ap(to_fp8, {"--fp8"}, ap.help("Quantize for fp8"), ap.set_value(true));
     ap(to_int4, {"--int4-weights"}, ap.help("Quantize weights for int4"), ap.set_value(true));
@@ -555,6 +565,10 @@ struct compiler
     {
         quantize_fp16(p);
     }
+    if(to_bf16)
+    {
+        quantize_bf16(p);
+    }
     if(to_int8)
     {
         quantize_int8(p, t, {host_params(p)});
@@ -639,6 +653,10 @@ struct verify : command
     {
         vo.quantize = precision::fp16;
     }
+    if(c.to_bf16)
+    {
+        vo.quantize = precision::bf16;
+    }
     if(c.to_int8)
     {
         vo.quantize = precision::int8;
diff --git a/src/driver/precision.hpp b/src/driver/precision.hpp
index d7d7cecf00e..9ed1f402f9d 100644
--- a/src/driver/precision.hpp
+++ b/src/driver/precision.hpp
@@ -32,6 +32,7 @@ enum class precision
 {
     fp32,
     fp16,
+    bf16,
     int8
 };
diff --git a/src/driver/verify.cpp b/src/driver/verify.cpp
index 92bae3eee86..14f9e71f70f 100644
--- a/src/driver/verify.cpp
+++ b/src/driver/verify.cpp
@@ -50,11 +50,14 @@ verify::tolerance get_tolerances(const program& p,
                                  std::optional<double> atol,
                                  std::optional<double> rtol)
 {
-    bool has_fp16 = any_of(p.get_modules(), [](auto&& m) {
-        return any_of(*m, [](auto&& ins) { return (ins.get_shape().type() == shape::half_type); });
+    bool has_16bit = any_of(p.get_modules(), [](auto&& m) {
+        return any_of(*m, [](auto&& ins) {
+            return (ins.get_shape().type() == shape::half_type or
+                    ins.get_shape().type() == shape::bf16_type);
+        });
     });
     migraphx::verify::tolerance result{};
-    if(has_fp16 or vo.quantize == precision::fp16)
+    if(has_16bit or vo.quantize == precision::fp16 or vo.quantize == precision::bf16)
     {
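+        // fp16 and bf16 keep at most 11 bits of significand precision, so the
+        // comparison tolerances below are much looser than the fp32 defaults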
         result.rms_tol = 8e-2;
         result.atol    = 4e-2;
@@ -100,6 +103,10 @@ std::vector<argument> run_target(program p,
     {
         quantize_fp16(p);
     }
+    if(vo.quantize == precision::bf16)
+    {
+        quantize_bf16(p);
+    }
     p.compile(t, options);
 
     parameter_map m;
diff --git a/src/fuse_pointwise_reduce.cpp b/src/fuse_pointwise_reduce.cpp
index dfd3f474ba2..eb2feb565b7 100644
--- a/src/fuse_pointwise_reduce.cpp
+++ b/src/fuse_pointwise_reduce.cpp
@@ -26,9 +26,20 @@
 #include
 #include
 #include
+#include <migraphx/split_reduce.hpp>
+#include <migraphx/env.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_SPLIT_REDUCE_SIZE);
+
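+// The default split size comes from fuse_pointwise_reduce::split_size (32768);
+// setting e.g. MIGRAPHX_SPLIT_REDUCE_SIZE=8192 in the environment overrides it.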
+static std::size_t get_split_size(std::size_t default_split)
+{
+    std::string value = string_value_of(MIGRAPHX_SPLIT_REDUCE_SIZE{});
+    if(value.empty())
+        return default_split;
+    return std::stoul(value);
+}
 
 void fuse_pointwise_reduce::apply(module_pass_manager& mpm) const
 {
@@ -36,6 +47,8 @@ void fuse_pointwise_reduce::apply(module_pass_manager& mpm) const
     mpm.run_pass(fuse_reduce{.enable_rewrite_reshapes = false});
     mpm.run_pass(fuse_pointwise{.enable_rewrite_reshapes = true});
     mpm.run_pass(fuse_reduce{.enable_rewrite_reshapes = true});
+    mpm.run_pass(split_reduce{.split_size = get_split_size(split_size)});
+    mpm.run_pass(fuse_pointwise{.enable_rewrite_broadcasts = true});
 }
 
 } // namespace MIGRAPHX_INLINE_NS
diff --git a/src/fuse_reduce.cpp b/src/fuse_reduce.cpp
index dc564acfbcb..5be8fd4a25e 100644
--- a/src/fuse_reduce.cpp
+++ b/src/fuse_reduce.cpp
@@ -174,12 +174,53 @@ static auto any_input(Ms... ms)
 {
     return match::any_of[match::inputs()](match::any(ms...).bind("input"));
 }
 
+bool is_valid_broadcast(const instruction_ref b, const std::vector<std::size_t>& reduce_axes)
+{
+    std::vector<std::size_t> broadcast_axes;
+    auto bstrides = b->get_shape().strides();
+
+    for(size_t i = 0; i < bstrides.size(); ++i)
+    {
+        if(bstrides.at(i) == 0)
+            broadcast_axes.push_back(i);
+    }
+
+    return broadcast_axes == reduce_axes;
+}
+
+template <class M>
+static auto match_broadcast_axes(M m)
+{
+    return match::make_basic_fun_matcher(
+        [=](match::matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
+            optional<instruction_ref> result = m.match(ctx, ins);
+            if(contains(ctx.instructions, "broadcast"))
+            {
+                instruction_ref reduce;
+                if(ins->get_operator().name() == "fused_reduce")
+                {
+                    reduce = ins;
+                }
+                else
+                {
+                    assert(contains(ctx.instructions, "reduce"));
+                    reduce = ctx.instructions["reduce"];
+                }
+                auto axes = reduce->get_operator().to_value().at("axes").to_vector<std::size_t>();
+                auto broadcast = ctx.instructions["broadcast"];
+                if(not is_valid_broadcast(broadcast, axes))
+                    return nullopt;
+            }
+            return result;
+        });
+}
+
 static auto match_broadcastable_input(const std::string& op, const std::string& name)
 {
     auto match_op = match::name(op)(used_once_except_broadcast()).bind(name);
     auto match_op_input = any_input(match_op, match::used_once());
     auto broadcast_match_op_input = any_input(match_broadcast(match_op), match::used_once());
-    return match::any_of(match_op_input, broadcast_match_op_input);
+    return match::any_of(match_op_input, match_broadcast_axes(broadcast_match_op_input));
 }
 
 static void finalize_reduce_module(module_ref m)
diff --git a/src/include/migraphx/base64.hpp b/src/include/migraphx/base64.hpp
new file mode 100644
index 00000000000..36035430826
--- /dev/null
+++ b/src/include/migraphx/base64.hpp
@@ -0,0 +1,39 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */ +#ifndef MIGRAPHX_GUARD_RTGLIB_BASE64_HPP +#define MIGRAPHX_GUARD_RTGLIB_BASE64_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +/// encode string to base64 +std::string base64_encode(const std::string& str); + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/bf16.hpp b/src/include/migraphx/bf16.hpp new file mode 100644 index 00000000000..26ecdd7c996 --- /dev/null +++ b/src/include/migraphx/bf16.hpp @@ -0,0 +1,39 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#ifndef MIGRAPHX_GUARD_RTGLIB_BF16_HPP +#define MIGRAPHX_GUARD_RTGLIB_BF16_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +using bf16 = migraphx::generic_float<7, 8>; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/bit_cast.hpp b/src/include/migraphx/bit_cast.hpp index 951b34bc340..fc4aab2e3b6 100644 --- a/src/include/migraphx/bit_cast.hpp +++ b/src/include/migraphx/bit_cast.hpp @@ -25,6 +25,7 @@ #if defined(__GNUC__) && !defined(__clang__) #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wstrict-aliasing" +#pragma GCC diagnostic ignored "-Wduplicated-branches" #endif #include diff --git a/src/include/migraphx/float8.hpp b/src/include/migraphx/float8.hpp index 445b8ebeb1e..98a4a7b10fa 100644 --- a/src/include/migraphx/float8.hpp +++ b/src/include/migraphx/float8.hpp @@ -42,6 +42,7 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -379,52 +380,73 @@ class numeric_limits // ================================================================================================= // define numeric limits for the new data type -// NOLINTBEGIN +// NOLINTBEGIN(cert-dcl58-cpp) namespace std { -#define MIGRAPHX_FP8_STD_OVERLOADS(T) \ - inline bool isfinite(T x) { return not x.is_inf() and not x.is_nan(); } \ - inline bool isnan(T x) { return x.is_nan(); } \ - template <> \ - class numeric_limits : public migraphx::fp8::numeric_limits \ - { \ - }; \ - template \ - struct common_type : std::common_type \ - { \ - }; \ - template \ - struct common_type : std::common_type \ - { \ - }; \ - template <> \ - struct common_type \ - { \ - using type = T; \ - }; -MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e4m3fn) 
-MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e5m2) -MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e4m3fnuz) -MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e5m2fnuz) - -// needed to resolve between multiple ambiguous definition from previous templates -#define MIGRAPHX_FP8_COMMON_TYPE_OVERLOAD_RESOLUTION(T, U) \ - template <> \ - struct common_type : std::common_type \ - { \ - }; \ - template <> \ - struct common_type : std::common_type \ - { \ - }; +template +inline bool isfinite(migraphx::fp8::float8 x) +{ + return not x.is_inf() and not x.is_nan(); +} + +template +inline bool isnan(migraphx::fp8::float8 x) +{ + return x.is_nan(); +} + +template +class numeric_limits> + : public migraphx::fp8::numeric_limits> +{ +}; +template +struct common_type, U> : std::common_type +{ +}; +template +struct common_type> : std::common_type +{ +}; +template +struct common_type, migraphx::fp8::float8> +{ + using type = migraphx::fp8::float8; +}; + +template +struct common_type, migraphx::fp8::float8> +{ + using type = float; +}; + +template +struct common_type, + migraphx::fp8::float8> +{ + using type = float; +}; + +template +struct common_type, + migraphx::generic_float> +{ + using type = float; +}; + +template +struct common_type, migraphx::fp8::float8> + : std::common_type +{ +}; + +template +struct common_type, migraphx::generic_float> + : std::common_type +{ +}; -MIGRAPHX_FP8_COMMON_TYPE_OVERLOAD_RESOLUTION(migraphx::fp8::fp8e4m3fn, migraphx::fp8::fp8e5m2) -MIGRAPHX_FP8_COMMON_TYPE_OVERLOAD_RESOLUTION(migraphx::fp8::fp8e4m3fn, migraphx::fp8::fp8e4m3fnuz) -MIGRAPHX_FP8_COMMON_TYPE_OVERLOAD_RESOLUTION(migraphx::fp8::fp8e4m3fn, migraphx::fp8::fp8e5m2fnuz) -MIGRAPHX_FP8_COMMON_TYPE_OVERLOAD_RESOLUTION(migraphx::fp8::fp8e5m2, migraphx::fp8::fp8e4m3fnuz) -MIGRAPHX_FP8_COMMON_TYPE_OVERLOAD_RESOLUTION(migraphx::fp8::fp8e5m2, migraphx::fp8::fp8e5m2fnuz) -MIGRAPHX_FP8_COMMON_TYPE_OVERLOAD_RESOLUTION(migraphx::fp8::fp8e4m3fnuz, migraphx::fp8::fp8e5m2fnuz) } // namespace std -// NOLINTEND +// NOLINTEND(cert-dcl58-cpp) // ================================================================================================= #endif // MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP diff --git a/src/include/migraphx/fp8_types.hpp b/src/include/migraphx/fp8_types.hpp index 7a6728b7ec4..d0ea85a6ece 100644 --- a/src/include/migraphx/fp8_types.hpp +++ b/src/include/migraphx/fp8_types.hpp @@ -28,8 +28,10 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { struct fp8_types { - const std::set types = { - shape::fp8e4m3fnuz_type, shape::fp8e4m3fn_type, shape::fp8e5m2_type}; + const std::set types = {shape::fp8e4m3fnuz_type, + shape::fp8e5m2fnuz_type, + shape::fp8e4m3fn_type, + shape::fp8e5m2_type}; std::set get() const { return types; } }; diff --git a/src/include/migraphx/fuse_pointwise_reduce.hpp b/src/include/migraphx/fuse_pointwise_reduce.hpp index 68bdc4e9951..63d78d2360b 100644 --- a/src/include/migraphx/fuse_pointwise_reduce.hpp +++ b/src/include/migraphx/fuse_pointwise_reduce.hpp @@ -35,6 +35,7 @@ struct module_pass_manager; struct MIGRAPHX_EXPORT fuse_pointwise_reduce { + std::size_t split_size = 32768; std::string name() const { return "fuse_pointwise_reduce"; } void apply(module_pass_manager& mpm) const; }; diff --git a/src/include/migraphx/generic_float.hpp b/src/include/migraphx/generic_float.hpp new file mode 100644 index 00000000000..c886bc035c9 --- /dev/null +++ b/src/include/migraphx/generic_float.hpp @@ -0,0 +1,476 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. 
All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#ifndef MIGRAPHX_GUARD_MIGRAPHX_GENERIC_FLOAT_HPP +#define MIGRAPHX_GUARD_MIGRAPHX_GENERIC_FLOAT_HPP + +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +template +constexpr unsigned int all_ones() noexcept +{ + return (1u << N) - 1u; +} + +template +constexpr int countl_zero(T value) +{ + unsigned int r = 0; + for(; value != 0u; value >>= 1u) + r++; + return 8 * sizeof(value) - r; +} + +constexpr std::size_t bit_ceil(std::size_t v) +{ + if(v <= 1) + return 1; + v--; + v |= v >> 1u; + v |= v >> 2u; + v |= v >> 4u; + v |= v >> 8u; + v |= v >> 16u; + v |= v >> 32u; + return v + 1; +} + +constexpr std::size_t integer_divide_ceil(std::size_t x, std::size_t y) +{ + return (x + y - std::size_t{1}) / y; +} + +template +struct unsigned_type +{ +}; + +template <> +struct unsigned_type<1> +{ + using type = std::uint8_t; +}; + +template <> +struct unsigned_type<2> +{ + using type = std::uint16_t; +}; + +template <> +struct unsigned_type<4> +{ + using type = std::uint32_t; +}; + +template <> +struct unsigned_type<8> +{ + using type = std::uint64_t; +}; + +struct float32_parts +{ + unsigned int mantissa : 23; + unsigned int exponent : 8; + unsigned int sign : 1; + + static constexpr unsigned int exponent_width() { return 8; } + + static constexpr unsigned int mantissa_width() { return 23; } + + static constexpr unsigned int max_exponent() { return all_ones<8>(); } + + static constexpr int exponent_bias() { return all_ones<7>(); } + + constexpr float to_float() const noexcept { return migraphx::bit_cast(*this); } +}; + +constexpr float32_parts get_parts(float f) { return migraphx::bit_cast(f); } + +template +struct __attribute__((packed, may_alias)) generic_float +{ + using type = typename unsigned_type::type; + + type mantissa : MantissaSize; + type exponent : ExponentSize; + type sign : 1; + + static constexpr int exponent_bias() { return all_ones(); } + + explicit constexpr generic_float(float f = 0.0) noexcept { from_float(get_parts(f)); } + + constexpr generic_float& operator=(float f) noexcept + { + from_float(get_parts(f)); + return *this; + } + + constexpr generic_float operator-() const noexcept + { + generic_float result = *this; + result.sign = not this->sign; + return result; + } + + constexpr generic_float operator+() const noexcept { return *this; } + + constexpr float to_float() const noexcept + { + 
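+        // widen to IEEE-754 binary32: re-bias the exponent and left-align the
+        // mantissa, special-casing subnormals and the all-ones inf/NaN exponent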
+        float32_parts f{};
+        f.sign = sign;
+
+        if(exponent == 0 and ExponentSize != float32_parts::exponent_width()) // subnormal fps
+        {
+
+            if(mantissa == 0)
+            {
+                f.exponent = 0;
+                f.mantissa = 0;
+            }
+            else
+            {
+                type shift = 0;
+                f.mantissa = mantissa;
+
+                if(MantissaSize < float32_parts::mantissa_width())
+                {
+                    shift = MantissaSize - ((sizeof(type) * 8) - countl_zero(mantissa));
+                    f.mantissa <<= (shift + 1u);
+                }
+
+                f.exponent = float32_parts::exponent_bias() - exponent_bias() - shift;
+                f.mantissa = f.mantissa << (float32_parts::mantissa_width() - MantissaSize);
+            }
+        }
+        else if(exponent == all_ones<ExponentSize>())
+        {
+            f.mantissa = mantissa << (float32_parts::mantissa_width() - MantissaSize);
+            f.exponent = float32_parts::max_exponent();
+        }
+        else
+        {
+            f.mantissa = mantissa << (float32_parts::mantissa_width() - MantissaSize);
+            constexpr const int diff = float32_parts::exponent_bias() - exponent_bias();
+            f.exponent = int(exponent) + diff;
+        }
+
+        return f.to_float();
+    }
+
+    constexpr void from_float(float32_parts f) noexcept
+    {
+        sign = f.sign;
+
+        if(f.exponent == 0)
+        {
+            exponent = 0;
+            mantissa = f.mantissa >> (float32_parts::mantissa_width() - MantissaSize);
+        }
+        else if(f.exponent == float32_parts::max_exponent())
+        {
+            exponent = all_ones<ExponentSize>();
+            mantissa = f.mantissa >> (float32_parts::mantissa_width() - MantissaSize);
+        }
+        else
+        {
+            constexpr const int diff = float32_parts::exponent_bias() - exponent_bias();
+            auto e = int(f.exponent) - diff;
+
+            if(e >= static_cast<int>(all_ones<ExponentSize>()))
+            {
+                exponent = all_ones<ExponentSize>();
+                mantissa = 0;
+            }
+            else if(e < 1)
+            {
+                exponent = 0;
+
+                auto shift = diff - int(f.exponent);
+                auto shift_amount = shift + (float32_parts::mantissa_width() - MantissaSize) + 1;
+
+                if(shift_amount < (sizeof(unsigned int) * 8))
+                {
+                    mantissa = (f.mantissa | (1u << float32_parts::mantissa_width())) >>
+                               (shift + (float32_parts::mantissa_width() - MantissaSize) + 1);
+                }
+                else
+                {
+                    mantissa = 0;
+                }
+            }
+            else
+            {
+                exponent = int(f.exponent) - diff;
+                mantissa = f.mantissa >> (float32_parts::mantissa_width() - MantissaSize);
+            }
+        }
+
+        exponent = std::min(exponent, all_ones<ExponentSize>());
+    }
+
+    constexpr bool is_normal() const noexcept
+    {
+        return exponent != all_ones<ExponentSize>() and exponent != 0;
+    }
+
+    constexpr bool is_inf() const noexcept
+    {
+        return exponent == all_ones<ExponentSize>() and mantissa == 0;
+    }
+
+    constexpr bool is_nan() const noexcept
+    {
+        return exponent == all_ones<ExponentSize>() and mantissa != 0;
+    }
+
+    constexpr bool is_finite() const noexcept { return exponent != all_ones<ExponentSize>(); }
+
+    constexpr operator float() const noexcept { return this->to_float(); }
+
+    static constexpr generic_float infinity()
+    {
+        generic_float x{};
+        x.exponent = all_ones<ExponentSize>();
+        return x;
+    }
+
+    static constexpr generic_float snan()
+    {
+        generic_float x{};
+        x.exponent = all_ones<ExponentSize>();
+        x.mantissa = 1u << (MantissaSize - 2u);
+        return x;
+    }
+
+    static constexpr generic_float qnan()
+    {
+        generic_float x{};
+        x.exponent = all_ones<ExponentSize>();
+        x.mantissa = 1u << (MantissaSize - 1u);
+        return x;
+    }
+
+    static constexpr generic_float min()
+    {
+        generic_float x{};
+        x.exponent = 1;
+        x.mantissa = 0;
+        return x;
+    }
+
+    static constexpr generic_float denorm_min()
+    {
+        generic_float x{};
+        x.exponent = 0;
+        x.mantissa = 1;
+        x.sign     = 0;
+        return x;
+    }
+
+    static constexpr generic_float lowest()
+    {
+        generic_float x{};
+        x.exponent = all_ones<ExponentSize>() - 1;
+        x.mantissa = all_ones<MantissaSize>();
+        x.sign     = 1;
+        return x;
+    }
+
+    static constexpr generic_float max()
+    {
+        generic_float x{};
+        x.exponent = all_ones<ExponentSize>() - 1;
+        x.mantissa = all_ones<MantissaSize>();
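+        // sign is cleared below: max() is the largest positive finite value,
+        // the exact negation of lowest() above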
+ x.sign = 0; + return x; + } + + static constexpr generic_float epsilon() + { + generic_float x{1.0}; + x.mantissa++; + return generic_float{x.to_float() - 1.0f}; + } +// NOLINTNEXTLINE +#define MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(op) \ + constexpr generic_float& operator op(const generic_float & rhs) \ + { \ + float self = *this; \ + float frhs = rhs; \ + self op frhs; \ + *this = generic_float(self); \ + return *this; \ + } + MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(*=) + MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(-=) + MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(+=) + MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(/=) +// NOLINTNEXTLINE +#define MIGRAPHX_GENERIC_FLOAT_BINARY_OP(op) \ + friend constexpr generic_float operator op(const generic_float& x, const generic_float& y) \ + { \ + return generic_float(float(x) op float(y)); \ + } + MIGRAPHX_GENERIC_FLOAT_BINARY_OP(*) + MIGRAPHX_GENERIC_FLOAT_BINARY_OP(-) + MIGRAPHX_GENERIC_FLOAT_BINARY_OP(+) + MIGRAPHX_GENERIC_FLOAT_BINARY_OP(/) +// NOLINTNEXTLINE +#define MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(op) \ + friend constexpr bool operator op(const generic_float& x, const generic_float& y) \ + { \ + return float(x) op float(y); \ + } + MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(<) + MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(<=) + MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(>) + MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(>=) + + friend constexpr bool operator==(const generic_float& x, const generic_float& y) + { + if(not x.is_finite() or not y.is_finite()) + return false; + + if((x.mantissa == 0 and x.exponent == 0) and (y.mantissa == 0 and y.exponent == 0)) + { + return true; + } + + return std::tie(x.mantissa, x.exponent, x.sign) == std::tie(y.mantissa, y.exponent, y.sign); + } + + friend constexpr bool operator!=(const generic_float& x, const generic_float& y) + { + return not(x == y); + } + + constexpr generic_float& operator++() noexcept + { + *this += generic_float(1.0f); + return *this; + } + + const generic_float operator++(int) noexcept // NOLINT(readability-const-return-type) + { + generic_float temp = *this; + *this += generic_float(1.0f); + return temp; + } +}; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +// NOLINTBEGIN(cert-dcl58-cpp) +namespace std { + +template +class numeric_limits> +{ + public: + static constexpr bool has_infinity = true; + static constexpr migraphx::generic_float epsilon() + { + return migraphx::generic_float::epsilon(); + } + + static constexpr migraphx::generic_float quiet_NaN() + { + return migraphx::generic_float::qnan(); + } + + static constexpr migraphx::generic_float signaling_NaN() + { + return migraphx::generic_float::snan(); + } + + static constexpr migraphx::generic_float max() + { + return migraphx::generic_float::max(); + } + + static constexpr migraphx::generic_float min() + { + return migraphx::generic_float::min(); + } + + static constexpr migraphx::generic_float lowest() + { + return migraphx::generic_float::lowest(); + } + + static constexpr migraphx::generic_float infinity() + { + return migraphx::generic_float::infinity(); + } + + static constexpr migraphx::generic_float denorm_min() + { + return migraphx::generic_float::denorm_min(); + } +}; + +template +struct common_type, T> : std::common_type +{ +}; + +template +struct common_type> : std::common_type +{ +}; + +template +struct common_type, migraphx::generic_float> +{ + using type = migraphx::generic_float; +}; + +template +struct common_type, migraphx::generic_float> +{ + using type = float; +}; + +} // namespace std +// NOLINTEND(cert-dcl58-cpp) + +#endif // 
MIGRAPHX_GUARD_MIGRAPHX_GENERIC_FLOAT_HPP diff --git a/src/include/migraphx/half.hpp b/src/include/migraphx/half.hpp index 3296e8c328d..b92942557a4 100644 --- a/src/include/migraphx/half.hpp +++ b/src/include/migraphx/half.hpp @@ -25,14 +25,14 @@ #ifndef MIGRAPHX_GUARD_RTGLIB_HALF_HPP #define MIGRAPHX_GUARD_RTGLIB_HALF_HPP -#include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -using half = half_float::half; +using half = migraphx::generic_float<10, 5>; namespace detail { template @@ -40,14 +40,6 @@ struct deduce { using type = T; }; - -#ifdef HAS_HALF_V1 -template <> -struct deduce -{ - using type = half; -}; -#endif } // namespace detail template @@ -56,60 +48,4 @@ using deduce = typename detail::deduce::type; } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx -namespace std { - -template -struct common_type : std::common_type // NOLINT -{ -}; - -template -struct common_type : std::common_type // NOLINT -{ -}; - -template <> -struct common_type -{ - using type = float; -}; - -template <> -struct common_type -{ - using type = float; -}; - -template <> -struct common_type -{ - using type = float; -}; - -template <> -struct common_type -{ - using type = float; -}; - -template <> -struct common_type -{ - using type = float; -}; - -template <> -struct common_type -{ - using type = float; -}; - -template <> -struct common_type -{ - using type = migraphx::half; -}; - -} // namespace std - #endif diff --git a/src/include/migraphx/layout_nhwc.hpp b/src/include/migraphx/layout_convolution.hpp similarity index 81% rename from src/include/migraphx/layout_nhwc.hpp rename to src/include/migraphx/layout_convolution.hpp index faf097a4d9d..9e45033a8db 100644 --- a/src/include/migraphx/layout_nhwc.hpp +++ b/src/include/migraphx/layout_convolution.hpp @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. 
*/ -#ifndef MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_NHWC_HPP -#define MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_NHWC_HPP +#ifndef MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_CONVOLUTION_HPP +#define MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_CONVOLUTION_HPP #include #include @@ -34,14 +34,15 @@ inline namespace MIGRAPHX_INLINE_NS { struct module_pass_manager; /** - * Transform convolutions to nhwc + * Transform convolutions layout */ -struct MIGRAPHX_EXPORT layout_nhwc +struct MIGRAPHX_EXPORT layout_convolution { - std::string name() const { return "layout_nhwc"; } + bool channels_last = false; + std::string name() const { return "layout_convolution"; } void apply(module_pass_manager& mpm) const; }; } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx -#endif // MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_NHWC_HPP +#endif // MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_CONVOLUTION_HPP diff --git a/src/include/migraphx/matcher.hpp b/src/include/migraphx/matcher.hpp index fdada026e66..0babe316dd1 100644 --- a/src/include/migraphx/matcher.hpp +++ b/src/include/migraphx/matcher.hpp @@ -378,6 +378,12 @@ matcher_result match_instruction(module& mod, instruction_ref ins, M&& m) return result; } +template +bool instruction_matches(module& mod, instruction_ref ins, M&& m) +{ + return match_instruction(mod, ins, std::forward(m)).result != mod.end(); +} + /// Find first instance of a matching instruction in a module template match::matcher_result find_match(module& modl, M&& m) diff --git a/src/include/migraphx/netron_output.hpp b/src/include/migraphx/netron_output.hpp new file mode 100644 index 00000000000..fb355a2d9f5 --- /dev/null +++ b/src/include/migraphx/netron_output.hpp @@ -0,0 +1,39 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#ifndef MIGRAPHX_GUARD_RTGLIB_NETRON_OUTPUT_HPP +#define MIGRAPHX_GUARD_RTGLIB_NETRON_OUTPUT_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +MIGRAPHX_EXPORT std::string make_netron_output(const program& prog); + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/bit_cast.hpp b/src/include/migraphx/op/bit_cast.hpp new file mode 100644 index 00000000000..eb233ad8b36 --- /dev/null +++ b/src/include/migraphx/op/bit_cast.hpp @@ -0,0 +1,104 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#ifndef MIGRAPHX_GUARD_OPERATORS_BIT_CAST_HPP +#define MIGRAPHX_GUARD_OPERATORS_BIT_CAST_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +/** + * Obtain a value of type `target_type` by reinterpreting + * the object represnetaion of the input. Originally used + * for casting from fp8e4m3fn to fp8e4m3fnuz. + */ +struct bit_cast : unary +{ + shape::type_t target_type; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.target_type, "target_type")); + } + + shape compute_shape(std::vector inputs) const + { + check_shapes{inputs, *this, true}.has(1); + auto input = inputs.at(0); + std::size_t target_type_size; + shape::visit(target_type, [&](auto as) { target_type_size = as.size(); }); + if(input.type_size() != target_type_size) + { + MIGRAPHX_THROW("BIT_CAST: target_type has different type_size from input's"); + } + if(input.dynamic()) + { + return {target_type, input.dyn_dims()}; + } + else + { + return {target_type, input.lens(), input.strides()}; + } + } + + std::string point_op() const + { + return "${function:bit_cast}<" + shape::cpp_type(target_type) + ">(${0})"; + } + + argument compute(const dyn_output& dyn_out, std::vector args) const + { + argument result{dyn_out.computed_shape}; + result.visit([&](auto output) { + using otype = typename decltype(output)::value_type; + args[0].visit([&](auto input) { + using itype = typename decltype(input)::value_type; + if constexpr(sizeof(otype) == sizeof(itype)) + { + par_transform(input.begin(), input.end(), output.begin(), [&](auto x) { + return migraphx::bit_cast(x); + }); + } + else + { + // not possible to hit this unless somehow the types change after compute_shape + // is called + MIGRAPHX_THROW("BIT_CAST: type size mismatch"); + } + }); + }); + return result; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/dequantizelinear.hpp b/src/include/migraphx/op/dequantizelinear.hpp index 3cd2d89fd96..60500b168d6 100644 --- a/src/include/migraphx/op/dequantizelinear.hpp +++ b/src/include/migraphx/op/dequantizelinear.hpp @@ -54,7 +54,7 @@ struct dequantizelinear { MIGRAPHX_THROW("DEQUANTIZELINEAR: Zero point and input should be the same type."); } - return {inputs[1].type(), inputs[0].lens(), inputs[0].strides()}; + return inputs[0].with_lens(inputs[1].type(), inputs[0].lens()); } argument 
compute(const shape& output_shape, std::vector args) const diff --git a/src/include/migraphx/op/flatten.hpp b/src/include/migraphx/op/flatten.hpp index 7f36b7623e3..55fab0d33f8 100644 --- a/src/include/migraphx/op/flatten.hpp +++ b/src/include/migraphx/op/flatten.hpp @@ -80,7 +80,6 @@ struct flatten } else { - check_shapes{inputs, *this}.standard(); auto&& lens = s.lens(); auto x = std::accumulate( lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{}); @@ -91,9 +90,14 @@ struct flatten } argument compute(const dyn_output& dyn_out, std::vector args) const { - return args[0].reshape(dyn_out.computed_shape); + assert(dyn_out.computed_shape.standard()); + argument result{dyn_out.computed_shape}; + + visit_all(result, args[0])([&](auto output, auto input) { + std::copy(input.begin(), input.end(), output.begin()); + }); + return result; } - std::ptrdiff_t output_alias(const std::vector&) const { return 0; } }; } // namespace op diff --git a/src/include/migraphx/op/quant_convolution.hpp b/src/include/migraphx/op/quant_convolution.hpp index 323a132ad4b..22cbd124b93 100644 --- a/src/include/migraphx/op/quant_convolution.hpp +++ b/src/include/migraphx/op/quant_convolution.hpp @@ -31,6 +31,7 @@ #include #include #include +#include #include #include @@ -89,8 +90,8 @@ struct quant_convolution // all input type must be int8_type or fp8 types // output should be float_type - std::set supported_types = { - shape::int8_type, shape::fp8e4m3fnuz_type, shape::fp8e4m3fn_type, shape::fp8e5m2_type}; + std::set supported_types = fp8_types{}.get(); + supported_types.insert(shape::int8_type); if(not contains(supported_types, t)) { MIGRAPHX_THROW("QUANT_CONVOLUTION: only accept input and weights of type int8 or fp8"); diff --git a/src/include/migraphx/op/quant_dot.hpp b/src/include/migraphx/op/quant_dot.hpp index b697851fe9d..e74daca2988 100644 --- a/src/include/migraphx/op/quant_dot.hpp +++ b/src/include/migraphx/op/quant_dot.hpp @@ -45,11 +45,9 @@ struct quant_dot const shape& a = inputs.at(0); const shape& b = inputs.at(1); auto t = a.type(); - std::set supported_types = {shape::int8_type, - shape::uint8_type, - shape::fp8e4m3fnuz_type, - shape::fp8e4m3fn_type, - shape::fp8e5m2_type}; + std::set supported_types = fp8_types{}.get(); + supported_types.insert(shape::int8_type); + supported_types.insert(shape::uint8_type); if(not contains(supported_types, t)) { MIGRAPHX_THROW("QUANT_DOT: only support data type int8_t, uint8_t and fp8 types"); diff --git a/src/include/migraphx/op/quantizelinear.hpp b/src/include/migraphx/op/quantizelinear.hpp index 77208444bfa..7a0de31cf5a 100644 --- a/src/include/migraphx/op/quantizelinear.hpp +++ b/src/include/migraphx/op/quantizelinear.hpp @@ -63,13 +63,9 @@ struct quantizelinear } if(inputs.size() == 3) { - return {inputs[2].type(), inputs[0].lens(), inputs[0].strides()}; + return inputs[0].with_lens(inputs[2].type(), inputs[0].lens()); } - if(out_type.has_value()) - { - return {out_type.value(), inputs[0].lens(), inputs[0].strides()}; - } - return {shape::uint8_type, inputs[0].lens(), inputs[0].strides()}; + return inputs[0].with_lens(out_type.value_or(shape::uint8_type), inputs[0].lens()); } argument compute(const shape& output_shape, std::vector args) const diff --git a/src/include/migraphx/output_iterator.hpp b/src/include/migraphx/output_iterator.hpp index 7aced4a08a3..e4d670b8537 100644 --- a/src/include/migraphx/output_iterator.hpp +++ b/src/include/migraphx/output_iterator.hpp @@ -72,6 +72,12 @@ auto join_back_inserter(Container& c) [&](const auto& r) { 
c.insert(c.end(), r.begin(), r.end()); }); } +template +auto push_inserter(Container& c) +{ + return make_function_output_iterator([&](const auto& x) { c.push(x); }); +} } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx + #endif // MIGRAPHX_GUARD_MIGRAPHX_OUTPUT_ITERATOR_HPP diff --git a/src/include/migraphx/quantization.hpp b/src/include/migraphx/quantization.hpp index d849023b6cf..eead5e40ba1 100644 --- a/src/include/migraphx/quantization.hpp +++ b/src/include/migraphx/quantization.hpp @@ -51,6 +51,9 @@ quantize_fp8(program& prog, const target& t, const std::vector& c MIGRAPHX_EXPORT void quantize_int4_weights(program& prog); +MIGRAPHX_EXPORT void quantize_bf16(program& prog, + const std::vector& ins_names = {"all"}); + } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/include/migraphx/rewrite_reshapes.hpp b/src/include/migraphx/rewrite_reshapes.hpp index e44413a09a1..c064091c279 100644 --- a/src/include/migraphx/rewrite_reshapes.hpp +++ b/src/include/migraphx/rewrite_reshapes.hpp @@ -33,6 +33,7 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -72,18 +73,19 @@ struct rewrite_reshapes auto matcher() const { - auto reshape = - match::name("reshape", "squeeze", "unsqueeze", "flatten")(match::used_once()); - auto skip_contiguous_broadcast = - match::skip(match::name("contiguous", "multibroadcast")(match::used_once())); - auto skip_contiguous_broadcast_arg = [&](auto... ms) { - return match::arg(0)(skip_contiguous_broadcast(ms...)); - }; + auto reshapes = match::name("reshape", + "squeeze", + "unsqueeze", + "flatten", + "transpose", + "contiguous", + "multibroadcast", + "broadcast")(match::used_once()); auto pointwise = match::name(op1)(match::used_once()); - auto reshape_pointwise = - reshape(skip_contiguous_broadcast_arg(pointwise.bind("x"))).bind("reshape"); - return match::name(op2)(match::any_of[match::inputs()]( - skip_contiguous_broadcast(reshape_pointwise).bind("input"))); + auto reshapes_pointwise = + reshapes(match::arg(0)(match::skip(reshapes())(pointwise.bind("x")))); + return match::name(op2)( + match::any_of[match::inputs()](reshapes_pointwise.bind("input"))); } template @@ -100,6 +102,12 @@ struct rewrite_reshapes return last; } + template + static bool any_input_of(instruction_ref start, instruction_ref last, F f) + { + return find_input_if(start, last, f) != last; + } + static bool match_input(instruction_ref ins, instruction_ref x_ins) { if(ins->inputs().empty()) @@ -120,23 +128,19 @@ struct rewrite_reshapes return result; } + static bool is_broadcast(instruction_ref ins) { return ins->name() == "multibroadcast"; } + void apply(module_pass_manager& mpm, const match::matcher_result& r) const { auto ins = r.result; auto x_ins = r.instructions["x"]; - auto reshape_ins = r.instructions["reshape"]; auto input_ins = r.instructions["input"]; - const auto has_broadcast_before_reshape = is_broadcasted(reshape_ins, x_ins); - const auto has_broadcast_after_reshape = is_broadcasted(input_ins, reshape_ins); - if(not has_broadcast_before_reshape.has_value()) - return; - if(not has_broadcast_after_reshape.has_value()) + // If its just a broadcast then skip + if(not any_input_of(input_ins, x_ins, [](instruction_ref x) { + return not contains({"multibroadcast", "broadcast", "contiguous"}, x->name()); + })) return; - if(*has_broadcast_after_reshape and *has_broadcast_before_reshape) - return; - const bool has_broadcast = - *has_broadcast_after_reshape or *has_broadcast_before_reshape; auto dims1 = 
T::base_dims(ins); auto dims2 = T::base_dims(x_ins); @@ -144,41 +148,56 @@ struct rewrite_reshapes if(elements(dims1) != elements(dims2)) return; - auto cd = common_dims::compute(T::base_dims(ins), T::base_dims(x_ins)); - if(cd.dims.empty()) - return; + std::vector ops; + auto next_ins = input_ins; + while(next_ins != x_ins) + { + ops.push_back(next_ins->get_operator()); + next_ins = next_ins->inputs().front(); + } + assert(next_ins == x_ins); + std::reverse(ops.begin(), ops.end()); - if(ins->name() != "pointwise" and not T::supports(ins, cd.dims, cd.axes_map1)) + auto desc = + shape_transform_descriptor::create(x_ins->get_shape().lens(), ops).rebase(dims2); + if(desc.empty()) return; - if(x_ins->name() != "pointwise" and not T::supports(x_ins, cd.dims, cd.axes_map2)) - return; - - auto reshape_input = [&](const auto& ins_to_insert) { - return [&](auto input) { - auto dims = cd.get_dimensions_for(input->get_shape().lens()); - return mpm.get_module().insert_instruction( - ins_to_insert, make_op("reshape", {{"dims", dims}}), input); + auto cdims = desc.common_dims(); + auto reshape_input = [&](const auto& ins_to_insert, auto generate) { + return [&, generate](auto input) { + auto gops = std::invoke(generate, desc, input->get_shape().lens()); + auto start = input; + for(const auto& op : gops) + { + start = mpm.get_module().insert_instruction(ins_to_insert, op, start); + } + return start; }; }; auto x_inputs = x_ins->inputs(); std::transform( - x_inputs.begin(), x_inputs.end(), x_inputs.begin(), reshape_input(x_ins)); - auto new_x_ins = insert(mpm, x_ins, x_inputs, cd.axes_map2); - if(has_broadcast) + x_inputs.begin(), + x_inputs.end(), + x_inputs.begin(), + reshape_input(x_ins, &shape_transform_descriptor::generate_common_from_src)); + auto new_x_ins = insert(mpm, x_ins, x_inputs, desc.common_axes_map_from_src()); + if(new_x_ins->get_shape().lens() != cdims) { new_x_ins = mpm.get_module().insert_instruction( - x_ins, make_op("multibroadcast", {{"out_lens", cd.dims}}), new_x_ins); + x_ins, make_op("multibroadcast", {{"out_lens", cdims}}), new_x_ins); } auto inputs = ins->inputs(); std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) { if(input == input_ins) return new_x_ins; - return reshape_input(ins)(input); + return reshape_input(ins, + &shape_transform_descriptor::generate_common_from_dst)(input); }); - auto pw = insert(mpm, ins, inputs, cd.axes_map1); - mpm.get_module().replace_instruction( - ins, make_op("reshape", {{"dims", ins->get_shape().lens()}}), pw); + auto pw = insert(mpm, ins, inputs, desc.common_axes_map_from_dst()); + auto rins = + reshape_input(ins, &shape_transform_descriptor::generate_dst_from_common)(pw); + mpm.get_module().replace_instruction(ins, rins); } static bool same_dims(instruction_ref ins) diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp index e15c4dece44..65eb04d2b0a 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -34,6 +34,7 @@ #include #include #include +#include #include #include #include @@ -64,8 +65,9 @@ struct MIGRAPHX_EXPORT shape m(uint64_type, uint64_t) \ m(fp8e4m3fnuz_type, migraphx::fp8::fp8e4m3fnuz) \ m(fp8e4m3fn_type, migraphx::fp8::fp8e4m3fn) \ - m(fp8e5m2_type, migraphx::fp8::fp8e5m2) -// clang-format on + m(fp8e5m2_type, migraphx::fp8::fp8e5m2) \ + m(bf16_type, bf16) \ + m(fp8e5m2fnuz_type, migraphx::fp8::fp8e5m2fnuz) // clang-format on #define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x, enum type_t diff --git 
a/src/include/migraphx/shape_transform_descriptor.hpp b/src/include/migraphx/shape_transform_descriptor.hpp index 4160759d47d..dd1ff256b9b 100644 --- a/src/include/migraphx/shape_transform_descriptor.hpp +++ b/src/include/migraphx/shape_transform_descriptor.hpp @@ -74,6 +74,11 @@ struct MIGRAPHX_EXPORT shape_transform_descriptor shape_transform_descriptor() = default; explicit shape_transform_descriptor(const std::vector& dims); + static shape_transform_descriptor create(const std::vector& dims, + const std::vector& ops); + + shape_transform_descriptor rebase(const std::vector& dims) const; + bool apply(const std::vector& ops); bool apply_reshape(const std::vector& rdims); bool apply_reshape_impl(const std::vector& rdims); @@ -84,6 +89,22 @@ struct MIGRAPHX_EXPORT shape_transform_descriptor std::size_t elements() const; std::vector generate() const; + bool has_broadcast() const; + void flatten_broadcast(); + + std::vector common_dims(const std::vector& input_dims = {}) const; + std::vector + generate_common_from_src(const std::vector& input_dims = {}) const; + std::vector + generate_common_from_dst(const std::vector& input_dims = {}) const; + std::vector + generate_dst_from_common(const std::vector& input_dims = {}) const; + std::vector> common_axes_map_from_src() const; + std::vector> common_axes_map_from_dst() const; + + bool empty() const; + std::vector lens() const; + struct MIGRAPHX_EXPORT dimension { void simplify(); @@ -98,7 +119,15 @@ struct MIGRAPHX_EXPORT shape_transform_descriptor // the axis. However, it still needs to accounted for. After we // generate the broadcast we will set the axis to the hidden // axis, and then length to 1. - optional hidden_axis = nullopt; + std::vector hidden_axis = {}; + + const std::vector& origin_axis() const; + bool has_hidden_axis() const; + + void add_split_axis(std::size_t i); + + void expose(); + void hide(); MIGRAPHX_EXPORT friend bool operator==(const sub& x, const sub& y); MIGRAPHX_EXPORT friend bool operator!=(const sub& x, const sub& y); diff --git a/src/include/migraphx/quantize_fp16.hpp b/src/include/migraphx/truncate_float.hpp similarity index 82% rename from src/include/migraphx/quantize_fp16.hpp rename to src/include/migraphx/truncate_float.hpp index 7233fdf2e2e..426a445c02a 100644 --- a/src/include/migraphx/quantize_fp16.hpp +++ b/src/include/migraphx/truncate_float.hpp @@ -21,12 +21,13 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. 
*/ -#ifndef MIGRAPHX_GUARD_RTGLIB_QUANTIZE_FP16_HPP -#define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_FP16_HPP +#ifndef MIGRAPHX_GUARD_RTGLIB_TRUNCATE_FLOAT_HPP +#define MIGRAPHX_GUARD_RTGLIB_TRUNCATE_FLOAT_HPP #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -35,12 +36,13 @@ struct program; struct module; /** - * quantize a program to fp16 + * quantize a program to the given float type */ -struct MIGRAPHX_EXPORT quantize_fp16_pass +struct MIGRAPHX_EXPORT truncate_float_pass { std::vector<std::string> ins_names = {"all"}; - std::string name() const { return "quantize_fp16"; } + shape::type_t float_type = shape::float_type; + std::string name() const { return "truncate_float"; } void apply(module& m) const; }; diff --git a/src/include/migraphx/type_traits.hpp b/src/include/migraphx/type_traits.hpp index 908f67d9f14..ea42f83af9b 100644 --- a/src/include/migraphx/type_traits.hpp +++ b/src/include/migraphx/type_traits.hpp @@ -27,6 +27,7 @@ #include #include +#include #include #include @@ -53,10 +54,18 @@ MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, half) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half) +MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, bf16) +MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, bf16) +MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, bf16) + MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, migraphx::fp8::fp8e4m3fnuz) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, migraphx::fp8::fp8e4m3fnuz) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, migraphx::fp8::fp8e4m3fnuz) +MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, migraphx::fp8::fp8e5m2fnuz) +MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, migraphx::fp8::fp8e5m2fnuz) +MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, migraphx::fp8::fp8e5m2fnuz) + MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, migraphx::fp8::fp8e4m3fn) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, migraphx::fp8::fp8e4m3fn) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, migraphx::fp8::fp8e4m3fn) diff --git a/src/instruction.cpp b/src/instruction.cpp index 47bea70379e..235c551d709 100644 --- a/src/instruction.cpp +++ b/src/instruction.cpp @@ -26,7 +26,8 @@ #include #include #include -#include +#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -58,22 +59,43 @@ instruction::instruction(literal l) { } +struct replace_shape_order +{ + instruction_ref start; + + std::size_t location(instruction_ref x) const { return std::distance(start, x); } + + bool operator()(instruction_ref x, instruction_ref y) const + { + return location(x) > location(y); + } +}; + void instruction::replace(const shape& r) { if(r != result) { result = r; - std::deque<instruction_ref> q(output.begin(), output.end()); + if(output.empty()) + { + return; + } + auto start = std::find_if(output.front()->inputs().begin(), + output.front()->inputs().end(), + [&](instruction_ref x) { return this == as_address(x); }); + assert(as_address(*start) == this); + std::priority_queue<instruction_ref, std::vector<instruction_ref>, replace_shape_order> q( + output.begin(), output.end(), replace_shape_order{*start}); while(not q.empty()) { - instruction_ref ins = q.front(); - q.pop_front(); + instruction_ref ins = q.top(); + q.pop(); assert(ins->name() == "@return" or ins->name().front() != '@'); shape new_r = compute_shape(ins->op, ins->arguments, ins->module_args); if(new_r != ins->result) { ins->result = new_r; - std::copy(ins->output.begin(), ins->output.end(), std::back_inserter(q)); + std::copy(ins->output.begin(), ins->output.end(), migraphx::push_inserter(q));
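// ---- Editor's illustration (not part of the patch) ----
// A minimal standalone sketch of why push_inserter is needed in the std::copy
// above: std::copy requires an output iterator, but std::priority_queue exposes
// push() rather than push_back(), so std::back_inserter no longer applies once
// the deque is replaced by a priority queue. This mirrors, under that
// assumption, what make_function_output_iterator provides in MIGraphX; all
// names below are illustrative only.
#include <algorithm>
#include <cstddef>
#include <iterator>
#include <queue>
#include <vector>

template <class Container>
struct push_insert_iterator
{
    using iterator_category = std::output_iterator_tag;
    using value_type        = void;
    using difference_type   = std::ptrdiff_t;
    using pointer           = void;
    using reference         = void;

    Container* c;

    // Assigning through the iterator forwards the value to the container's push()
    push_insert_iterator& operator=(const typename Container::value_type& x)
    {
        c->push(x);
        return *this;
    }
    push_insert_iterator& operator*() { return *this; }
    push_insert_iterator& operator++() { return *this; }
    push_insert_iterator operator++(int) { return *this; }
};

template <class Container>
push_insert_iterator<Container> push_inserter_sketch(Container& c)
{
    return {&c};
}

int main()
{
    std::vector<int> outputs = {3, 1, 2};
    std::priority_queue<int> q;
    // Same shape as the patch's:
    //   std::copy(ins->output.begin(), ins->output.end(), migraphx::push_inserter(q));
    std::copy(outputs.begin(), outputs.end(), push_inserter_sketch(q)); // q now yields 3, 2, 1
}
// ---- End of illustration ----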
} } } diff --git a/src/layout_nhwc.cpp b/src/layout_convolution.cpp similarity index 62% rename from src/layout_nhwc.cpp rename to src/layout_convolution.cpp index 9d2a0083a34..83acb839ce6 100644 --- a/src/layout_nhwc.cpp +++ b/src/layout_convolution.cpp @@ -21,7 +21,7 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ -#include +#include #include #include #include @@ -32,49 +32,61 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -template -std::vector find_lasts(const module& m, Predicate pred) +namespace { +std::vector get_permutation(instruction_ref ins, const layout_convolution& lc) { - std::vector result; - fix([&](auto self, auto ins) { - if(pred(ins)) - { - result.push_back(ins); - return; - } - for(auto input : ins->inputs()) - self(input); - })(std::prev(m.end())); - return result; + if(lc.channels_last) + { + std::vector perm(ins->get_shape().ndim()); + std::iota(perm.begin() + 1, perm.end() - 1, 2); + perm.back() = 1; + return perm; + } + return find_permutation(ins->inputs().front()->get_shape()); +} + +bool skip_layout(const shape& s) +{ + return s.ndim() == 1 or s.dynamic() or s.type() == shape::tuple_type; } void preserve_output_layout(module& m) { auto last = std::prev(m.end()); - std::vector outputs; if(last->name() == "@return") - outputs = last->inputs(); - else - outputs = {last}; - - for(auto output : outputs) { - auto permutation = find_permutation(output->get_shape()); - auto layout = m.insert_instruction( - std::next(output), make_op("layout", {{"permutation", permutation}}), output); - m.replace_instruction(output, layout); + std::vector outputs; + std::transform(last->inputs().begin(), + last->inputs().end(), + std::back_inserter(outputs), + [&](instruction_ref ins) { + if(skip_layout(ins->get_shape())) + return ins; + auto permutation = find_permutation(ins->get_shape()); + return m.insert_instruction( + last, make_op("layout", {{"permutation", permutation}}), ins); + }); + m.replace_return(outputs); + } + else if(not skip_layout(last->get_shape())) + { + auto permutation = find_permutation(last->get_shape()); + m.add_instruction(make_op("layout", {{"permutation", permutation}}), last); } } -void transform_convolutions(module& m) +void transform_convolutions(module& m, const layout_convolution& lc) { for(auto ins : iterator_for(m)) { - if(ins->name() != "convolution") + if(not contains({"convolution", "quant_convolution"}, ins->name())) + continue; + if(ins->get_shape().dynamic()) continue; if(ins->get_shape().lens().size() != 4) continue; @@ -82,8 +94,9 @@ void transform_convolutions(module& m) if(v.at("group").to() > 1) continue; auto args = ins->inputs(); + auto perm = get_permutation(ins, lc); std::transform(args.begin(), args.end(), args.begin(), [&](const auto& i) { - return m.insert_instruction(ins, make_op("layout", {{"permutation", {0, 2, 3, 1}}}), i); + return m.insert_instruction(ins, make_op("layout", {{"permutation", perm}}), i); }); auto conv = m.insert_instruction(ins, ins->get_operator(), args); auto c = m.insert_instruction(ins, make_op("contiguous"), conv); @@ -102,11 +115,12 @@ void remove_layout(module& m) m.replace_instruction(ins, ins->inputs().front()); } } +} // namespace -void layout_nhwc::apply(module_pass_manager& mpm) const +void layout_convolution::apply(module_pass_manager& mpm) const { preserve_output_layout(mpm.get_module()); - transform_convolutions(mpm.get_module()); + transform_convolutions(mpm.get_module(), *this); 
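// ---- Editor's illustration (not part of the patch) ----
// A small self-contained sketch of the channels_last branch of
// get_permutation() in layout_convolution.cpp above: perm[0] stays 0 (batch),
// the spatial axes are filled with 2, 3, ... via std::iota, and channel axis 1
// moves to the back, producing the NHWC-style {0, 2, 3, 1} that the old pass
// hard coded for 4-D convolutions. The function name is hypothetical.
#include <iostream>
#include <numeric>
#include <vector>

std::vector<std::size_t> channels_last_permutation(std::size_t ndim)
{
    std::vector<std::size_t> perm(ndim);            // perm[0] is value-initialized to 0
    std::iota(perm.begin() + 1, perm.end() - 1, 2); // spatial dims come forward
    perm.back() = 1;                                // channels go last
    return perm;
}

int main()
{
    for(std::size_t ndim : {4, 5})
    {
        for(auto d : channels_last_permutation(ndim))
            std::cout << d << ' ';
        std::cout << '\n'; // prints "0 2 3 1" then "0 2 3 4 1"
    }
}
// ---- End of illustration ----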
mpm.run_pass(dead_code_elimination{}); mpm.run_pass(eliminate_contiguous{"contiguous"}); mpm.run_pass(dead_code_elimination{}); diff --git a/src/module.cpp b/src/module.cpp index 44148fd73fa..7e02478b385 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -355,7 +355,6 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref { impl->changed.notify(); assert(has_instruction(ins)); - assert(has_instruction(rep)); assert(ins != rep); if(ins == std::prev(this->end())) @@ -541,7 +540,6 @@ instruction_ref module::insert_parameter(instruction_ref ins, std::string name, instruction_ref module::replace_return(std::vector args) { impl->changed.notify(); - assert(std::all_of(args.begin(), args.end(), [&](auto ins) { return has_instruction(ins); })); auto last = std::prev(this->end()); // If there is no return then add a return if(last->name() != "@return") @@ -1124,7 +1122,7 @@ void module::debug_print(instruction_ref ins, std::cout << "Instruction not part of module" << std::endl; return; } - std::stringstream ss; + names = this->print( [&](auto x, auto ins_names) { if(x == ins) diff --git a/src/netron_output.cpp b/src/netron_output.cpp new file mode 100644 index 00000000000..64403dd7090 --- /dev/null +++ b/src/netron_output.cpp @@ -0,0 +1,266 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace { + +// from https://onnx.ai/onnx/intro/concepts.html +int get_onnx_type(shape::type_t s_type) +{ + switch(s_type) + { + case shape::float_type: return 1; + case shape::uint8_type: return 2; + case shape::int8_type: return 3; + case shape::uint16_type: return 4; + case shape::int16_type: return 5; + case shape::int32_type: return 6; + case shape::int64_type: return 7; + case shape::bool_type: return 9; + case shape::half_type: return 10; + case shape::double_type: return 11; + case shape::uint32_type: return 12; + case shape::uint64_type: return 13; + case shape::bf16_type: return 16; + case shape::fp8e4m3fn_type: return 17; + case shape::fp8e4m3fnuz_type: return 18; + case shape::fp8e5m2_type: return 19; + case shape::fp8e5m2fnuz_type: return 20; + case shape::tuple_type: return 0; + } + MIGRAPHX_THROW("MIGraphX type " + std::to_string(s_type) + " not supported"); +} + +auto make_attribute(const migraphx::value& val) +{ + value attribute; + attribute["name"] = val.get_key(); + auto val_string = val.to(); + val_string = val_string.substr(val_string.find(":") + 1); + attribute["s"] = base64_encode(val_string); + attribute["type"] = "STRING"; + return attribute; +} + +/// Returns a value with the JSON structure needed for a node +auto make_onnx_json_node(instruction_ref ins, + std::unordered_map ins_uids) +{ + value node; + // TODO add support for module inputs + value input_arr; + for(instruction_ref input_ins : ins->inputs()) + { + auto name = input_ins->name(); + if(name == "@literal" or name == "@param") + { + input_arr.push_back(ins_uids.at(input_ins)); + } + // TODO make a better process for handling nodes to ignore + else if(name.find("hip::hip_allocate_memory") != std::string::npos) + { + continue; + } + else + { + input_arr.push_back(ins_uids.at(input_ins) + "->" + ins_uids.at(ins)); + } + } + value output_arr; + for(instruction_ref output_ins : ins->outputs()) + { + if(output_ins->name() == "@return") + { + output_arr.push_back(ins_uids.at(output_ins)); + } + else + { + output_arr.push_back(ins_uids.at(ins) + "->" + ins_uids.at(output_ins)); + } + } + node["input"] = input_arr; + node["output"] = output_arr; + node["name"] = ins_uids.at(ins); + node["opType"] = ins->name(); + value op_attribute_arr; + auto op_value = ins->get_operator().to_value(); + std::for_each(op_value.begin(), op_value.end(), [&](auto v) { + const std::string& attr_key = v.get_key(); + if(v.is_binary()) + { + return; + } + else if(attr_key == "symbol_name" or attr_key == "name") + { + node["opType"] = migraphx::from_value(v); + } + else + { + op_attribute_arr.push_back(make_attribute(v)); + } + }); + node["attribute"] = op_attribute_arr; + return node; +} + +// ONNX graph constant data called "initializer" +auto make_onnx_json_literal(instruction_ref ins, + std::unordered_map ins_uids) +{ + value lit; + lit["dims"] = ins->get_shape().lens(); + lit["dataType"] = get_onnx_type(ins->get_shape().type()); + lit["name"] = ins_uids.at(ins); + // ignoring literal data, setting to "NULL" in base64 + lit["rawData"] = "TlVMTA=="; + return lit; +} + +// TODO handle dynamic shapes +// TODO handle subshapes +auto make_onnx_json_shape(const shape& s) +{ + value ret; + value dim; + auto shape_lens = s.lens(); + std::transform(shape_lens.begin(), + shape_lens.end(), + std::back_inserter(dim), + [](std::size_t len) { return len; }); + ret["dim"] = dim; + return ret; +} + +// ONNX 
graph edges called "valueInfo" +auto make_onnx_json_edge(instruction_ref ins, + instruction_ref out_ins, + std::unordered_map<instruction_ref, std::string> ins_uids) +{ + value ret; + shape ins_shape = ins->get_shape(); + ret["name"] = ins_uids.at(ins) + "->" + ins_uids.at(out_ins); + value type = {{"tensorType", + {{"elemType", get_onnx_type(ins_shape.type())}, + {"shape", make_onnx_json_shape(ins_shape)}}}}; + ret["type"] = type; + return ret; +} + +auto make_onnx_json_in_out(instruction_ref ins, + std::unordered_map<instruction_ref, std::string> ins_uids) +{ + value ret; + shape ins_shape = ins->get_shape(); + ret["name"] = ins_uids.at(ins); + value type = {{"tensorType", + {{"elemType", get_onnx_type(ins_shape.type())}, + {"shape", make_onnx_json_shape(ins_shape)}}}}; + ret["type"] = type; + return ret; +} + +std::unordered_map<instruction_ref, std::string> make_ins_uids(const module& mod) +{ + std::unordered_map<instruction_ref, std::string> ret; + int count = 0; + for(auto ins : iterator_for(mod)) + { + std::string var_name; + var_name = mod.name() + ":"; + var_name.append(ins->name() + ":"); + var_name.append("@" + std::to_string(count)); + count++; + ret.emplace(ins, var_name); + } + return ret; +} + +value make_graph(const module* mod) +{ + value graph = { + {"node", {}}, {"initializer", {}}, {"input", {}}, {"output", {}}, {"valueInfo", {}}}; + auto ins_uids = make_ins_uids(*mod); + for(auto ins = mod->begin(); ins != mod->end(); ++ins) + { + const auto& name = ins->name(); + if(name == "@literal") + { + graph["initializer"].push_back(make_onnx_json_literal(ins, ins_uids)); + } + else if(name == "@param") + { + graph["input"].push_back(make_onnx_json_in_out(ins, ins_uids)); + } + else if(name == "@return") + { + graph["output"].push_back(make_onnx_json_in_out(ins, ins_uids)); + } + else if(name.find("hip::hip_allocate_memory") != std::string::npos) + { + continue; + } + else + { + graph["node"].push_back(make_onnx_json_node(ins, ins_uids)); + const auto& outputs = ins->outputs(); + for(auto out_ins : outputs) + { + if(out_ins->name() != "@return") + { + graph["valueInfo"].push_back(make_onnx_json_edge(ins, out_ins, ins_uids)); + } + } + } + } + return graph; +} + +} // namespace + +std::string make_netron_output(const program& prog) +{ + value output; + auto prog_value = prog.to_value(); + output["irVersion"] = prog_value.at("version").to(); + output["producerName"] = "AMDMIGraphX"; + output["producerVersion"] = prog_value.at("migraphx_version").to(); + for(auto& mod : prog.get_modules()) + { + auto graph = make_graph(mod); + output["graph"] = graph; + } + return to_json_string(output); +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_matmul.cpp b/src/onnx/parse_matmul.cpp index 9ded26a4b8f..302dc6b943e 100644 --- a/src/onnx/parse_matmul.cpp +++ b/src/onnx/parse_matmul.cpp @@ -35,7 +35,9 @@ struct parse_matmul : op_parser<parse_matmul> { std::vector<op_desc> operators() const { - return {{"MatMul", "dot"}, {"MatMulInteger", "quant_dot"}}; + return {{"MatMul", "dot"}, + {"MatMulInteger", "quant_dot"}, + {"MatMulIntegerToFloat", "quant_dot_scaled"}}; } static void broadcast_dimensions(const onnx_parser::node_info& info, @@ -106,7 +108,82 @@ struct parse_matmul : op_parser<parse_matmul> return all_zeros; } - static instruction_ref set_bias_arg(const std::vector<instruction_ref>& args, + static instruction_ref set_scale_arg(const onnx_parser::node_info& info, + const std::vector<instruction_ref>& args, + const instruction_ref& mat_input, + const int index) + { + instruction_ref scale_arg = args[index]; + std::set<migraphx::shape::type_t> supported_dq_types = {migraphx::shape::float_type, + migraphx::shape::half_type}; + + auto scale_shape =
scale_arg->get_shape(); + + if(not(contains(supported_dq_types, scale_shape.type()))) + { + MIGRAPHX_THROW("PARSE_QUANT_DOT_SCALED: Scales must be float or half_type"); + } + + if(scale_shape.lens().at(0) != *(mat_input->get_shape().lens().rbegin()) and + not scale_shape.scalar()) + { + MIGRAPHX_THROW("PARSE_QUANT_DOT_SCALED: Scale must have same dim as matrix column"); + } + + if(scale_shape.lens().size() > 1 and not scale_shape.scalar()) + { + MIGRAPHX_THROW("PARSE_QUANT_DOT_SCALED: Scales shape must be scalar or 1-D tensor"); + } + + if(scale_shape.scalar()) + { + scale_arg = info.add_instruction(make_op("unsqueeze", {{"axes", {0}}}), scale_arg); + scale_shape = scale_arg->get_shape(); + } + + scale_arg = info.add_instruction(make_op("unsqueeze", {{"axes", {0}}}), scale_arg); + + return scale_arg; + } + + static instruction_ref set_scale_bias(const std::vector<instruction_ref>& args, + const int index, + const migraphx::shape& scale_arg_shape, + const instruction_ref& compare_arg, + bool& has_valid_scale_bias) + { + has_valid_scale_bias = false; + + if(args.size() > index) + { + instruction_ref scale_bias_arg = args[index]; + std::set<migraphx::shape::type_t> supported_dq_types = {migraphx::shape::float_type, + migraphx::shape::half_type}; + + if(not(contains(supported_dq_types, scale_bias_arg->get_shape().type()))) + { + MIGRAPHX_THROW("PARSE_QUANT_DOT_SCALED: Bias must be float or half_type"); + } + + if(scale_bias_arg->get_shape().type() != scale_arg_shape.type()) + { + MIGRAPHX_THROW("PARSE_QUANT_DOT_SCALED: Bias must be the same type as scales"); + } + + if(scale_bias_arg->get_shape().lens().at(0) != + *(compare_arg->get_shape().lens().rbegin())) + { + MIGRAPHX_THROW("PARSE_QUANT_DOT_SCALED: Bias must have same dim as matrix B column"); + } + + has_valid_scale_bias = true; + return scale_bias_arg; + } + return compare_arg; + } + + static instruction_ref set_bias_arg(const std::string& name, + const std::vector<instruction_ref>& args, const int index, const instruction_ref& input, bool& has_valid_bias) @@ -118,7 +195,7 @@ struct parse_matmul instruction_ref bias_arg = args[index]; if(bias_arg->get_shape().type() != input->get_shape().type()) { - MIGRAPHX_THROW("PARSE_QUANT_DOT: zero point must be the same type as data"); + MIGRAPHX_THROW(name + ": zero point must be the same type as data"); } // Don't return zero point if it will cause symmetric zero point.
No need to bias @@ -148,11 +225,124 @@ struct parse_matmul : op_parser } } + static void handle_scaled_transposes(const onnx_parser::node_info& info, + instruction_ref& scale, + instruction_ref& zp, + bool no_zp) + { + if(no_zp) + { + scale = info.add_instruction(make_op("transpose", {{"permutation", {0, 1}}}), scale); + } + else + { + scale = info.add_instruction(make_op("transpose", {{"permutation", {0, 1}}}), scale); + zp = info.add_instruction(make_op("transpose", {{"permutation", {1, 0}}}), zp); + } + } + + static instruction_ref handle_dequantized(const onnx_parser::node_info& info, + const instruction_ref& a0, + const instruction_ref& scale_a0, + const instruction_ref& zp_a0, + bool no_zp) + { + instruction_ref dequantized_op; + + if(no_zp) + { + auto bc_scale_a0 = info.add_instruction( + make_op("multibroadcast", {{"out_lens", a0->get_shape().lens()}}), scale_a0); + dequantized_op = info.add_instruction(make_op("dequantizelinear"), a0, bc_scale_a0); + } + else + { + auto bc_scale_a0 = info.add_instruction( + make_op("multibroadcast", {{"out_lens", a0->get_shape().lens()}}), scale_a0); + + auto bc_zp_a0 = info.add_instruction( + make_op("multibroadcast", {{"out_lens", a0->get_shape().lens()}}), zp_a0); + + dequantized_op = + info.add_instruction(make_op("dequantizelinear"), a0, bc_scale_a0, bc_zp_a0); + } + return dequantized_op; + } + + static instruction_ref handle_scaled_output(const onnx_parser::node_info& info, + const instruction_ref& a0, + const instruction_ref& a1, + const instruction_ref& scale_a0, + const instruction_ref& scale_a1, + const instruction_ref& zp_a0, + const instruction_ref& zp_a1, + const instruction_ref& scaled_bias, + const bool has_scale_bias) + { + + instruction_ref unsq_zp_a0; + instruction_ref unsq_zp_a1; + + bool a0_has_no_zp = (a0 == zp_a0); + bool a1_has_no_zp = (a1 == zp_a1); + + if(not a0_has_no_zp) + { + unsq_zp_a0 = info.add_instruction(make_op("unsqueeze", {{"axes", {0}}}), zp_a0); + if(zp_a0->get_shape().scalar()) + { + unsq_zp_a0 = + info.add_instruction(make_op("unsqueeze", {{"axes", {0}}}), unsq_zp_a0); + } + } + + if(not a1_has_no_zp) + { + unsq_zp_a1 = info.add_instruction(make_op("unsqueeze", {{"axes", {0}}}), zp_a1); + if(zp_a1->get_shape().scalar()) + { + unsq_zp_a1 = + info.add_instruction(make_op("unsqueeze", {{"axes", {0}}}), unsq_zp_a1); + } + } + + auto dq_a0 = handle_dequantized(info, a0, scale_a0, unsq_zp_a0, a0_has_no_zp); + auto dq_a1 = handle_dequantized(info, a1, scale_a1, unsq_zp_a1, a1_has_no_zp); + auto res = info.add_instruction(make_op("dot"), dq_a0, dq_a1); + + // Handle case of the bias after scaling + if(has_scale_bias) + res = info.add_common_op("sub", res, scaled_bias); + + return res; + } + + static void handle_uint8_input(const onnx_parser::node_info& info, + const bool has_bias, + const instruction_ref& offset_op, + instruction_ref& arg, + instruction_ref& bias_arg) + { + auto arg_type = arg->get_shape().type(); + // always convert uint8 to int8 to avoid rollover + if(arg_type == migraphx::shape::uint8_type) + { + shift_input_and_bias(info, offset_op, has_bias, arg, bias_arg); + } + + // subtract bias from result after conversion + if(has_bias) + { + bias_arg = info.add_common_op("sub", arg, bias_arg); + } + } + instruction_ref parse(const op_desc& opd, const onnx_parser& /*parser*/, const onnx_parser::node_info& info, std::vector args) const { + std::string op_name{opd.op_name}; auto a0 = args[0]; auto a1 = args[1]; auto s0 = a0->get_shape(); @@ -172,13 +362,17 @@ struct parse_matmul : op_parser a1 = 
info.add_instruction(make_op("unsqueeze", {{"axes", {1}}}), args[1]); } - auto is_quant_dot = opd.op_name == "quant_dot"; + auto is_quant_dot = opd.op_name == "quant_dot"; + auto is_quant_dot_scaled = opd.op_name == "quant_dot_scaled"; + auto is_dot = opd.op_name == "dot"; + if(s0.dynamic() or s1.dynamic()) { - if(is_quant_dot) + if(is_quant_dot or is_quant_dot_scaled) { - MIGRAPHX_THROW("PARSE_MATMUL: dynamic MatMulInteger not supported"); + MIGRAPHX_THROW(op_name + ": dynamic inputs not supported"); } + auto s0_dds = a0->get_shape().to_dynamic().dyn_dims(); auto s1_dds = a1->get_shape().to_dynamic().dyn_dims(); @@ -200,15 +394,44 @@ struct parse_matmul auto s0_lens = a0->get_shape().lens(); auto s1_lens = a1->get_shape().lens(); - if(not is_quant_dot and args.size() > 2) + if(is_dot and args.size() > 2) { - MIGRAPHX_THROW("PARSE_MATMUL: Bias Args not supported for MatMul"); + MIGRAPHX_THROW(op_name + ": Bias Args not supported"); } bool has_ba0 = false; bool has_ba1 = false; - instruction_ref ba0 = set_bias_arg(args, 2, a0, has_ba0); - instruction_ref ba1 = set_bias_arg(args, 3, a1, has_ba1); + bool has_scale_bias = false; + + int a0_zp_index = 2; + int a1_zp_index = 3; + + instruction_ref scale_a0; + instruction_ref scale_a1; + // Handle the case when scales are present in the operator + if(is_quant_dot_scaled) + { + a0_zp_index = 4; + a1_zp_index = 5; + scale_a0 = set_scale_arg(info, args, a0, 2); + scale_a1 = set_scale_arg(info, args, a1, 3); + if(scale_a0->get_shape().type() != scale_a1->get_shape().type()) + { + MIGRAPHX_THROW(op_name + ": Scales must be the same type"); + } + } + + instruction_ref ba0 = set_bias_arg(op_name, args, a0_zp_index, a0, has_ba0); + instruction_ref ba1 = set_bias_arg(op_name, args, a1_zp_index, a1, has_ba1); + + // handle optional bias arg to the result + instruction_ref scaled_bias; + if(is_quant_dot_scaled) + { + auto scaled_index = 6; + scaled_bias = + set_scale_bias(args, scaled_index, scale_a1->get_shape(), a1, has_scale_bias); + } // Only INT8 or UINT8 type currently supported std::set<migraphx::shape::type_t> supported_types = {migraphx::shape::uint8_type, @@ -216,45 +439,35 @@ struct parse_matmul const auto a0_type = a0->get_shape().type(); const auto a1_type = a1->get_shape().type(); - if(is_quant_dot and + if((not is_dot) and (not contains(supported_types, a0_type) or not contains(supported_types, a1_type))) { - MIGRAPHX_THROW("PARSE_MATMULINTEGER: Unsupported type"); + MIGRAPHX_THROW(op_name + ": Unsupported type"); } - instruction_ref offset_op; - if(is_quant_dot and ((a0_type == migraphx::shape::uint8_type) or - (a1_type == migraphx::shape::uint8_type))) + if((is_quant_dot and ((a0_type == migraphx::shape::uint8_type) or + (a1_type == migraphx::shape::uint8_type)))) { - offset_op = info.add_literal( + auto offset_op = info.add_literal( migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {-128}}); + handle_uint8_input(info, has_ba0, offset_op, a0, ba0); + handle_uint8_input(info, has_ba1, offset_op, a1, ba1); } - // always convert uint8 to int8 to avoid rollover - if(is_quant_dot and (a0_type == migraphx::shape::uint8_type)) - { - shift_input_and_bias(info, offset_op, has_ba0, a0, ba0); - } - - if(is_quant_dot and (a1_type == migraphx::shape::uint8_type)) - { - shift_input_and_bias(info, offset_op, has_ba1, a1, ba1); - } + broadcast_dimensions(info, s0_lens, s1_lens, a0, a1, ba0, ba1); - // subtract bias from result after conversion - if(is_quant_dot and has_ba0) + // Apply the scales to dequantize the inputs and then perform a simple dot
+ // after the zero points are applied; otherwise get an int32 output from the quantized + // equivalent. Ensure these are broadcast accordingly before we perform the dot + if(is_quant_dot_scaled) { - ba0 = info.add_common_op("sub", a0, ba0); + dot_res = handle_scaled_output( + info, a0, a1, scale_a0, scale_a1, ba0, ba1, scaled_bias, has_scale_bias); } - - if(is_quant_dot and has_ba1) + else { - ba1 = info.add_common_op("sub", a1, ba1); + dot_res = info.add_instruction(make_op(opd.op_name), ba0, ba1); } - - broadcast_dimensions(info, s0_lens, s1_lens, a0, a1, ba0, ba1); - - dot_res = info.add_instruction(make_op(opd.op_name), ba0, ba1); } // squeeze the appended or prepended dimensions diff --git a/src/onnx/parse_matmulnbits.cpp b/src/onnx/parse_matmulnbits.cpp index af9f09790aa..e5fb336dfad 100644 --- a/src/onnx/parse_matmulnbits.cpp +++ b/src/onnx/parse_matmulnbits.cpp @@ -66,10 +66,10 @@ struct parse_matmulnbits : op_parser<parse_matmulnbits> to_string_range(expected_b_lens) + ". Actual dims: " + to_string_range(args[1]->get_shape().lens())); - std::vector<std::size_t> expected_scales_lens{n * n_blocks_per_col}; - if(args[2]->get_shape().lens() != expected_scales_lens) + const size_t expected_scales_lens = n * n_blocks_per_col; + if(args[2]->get_shape().elements() != expected_scales_lens) MIGRAPHX_THROW("MatMulNBits: Input scales does not match expected dims: " + - to_string_range(expected_scales_lens) + + to_string(expected_scales_lens) + ". Actual dims: " + to_string_range(args[2]->get_shape().lens())); if(args.size() > 3) diff --git a/src/program.cpp b/src/program.cpp index cac833803b3..2d43f3f8d55 100644 --- a/src/program.cpp +++ b/src/program.cpp @@ -44,6 +44,7 @@ #include #include #include +#include #include #include #include @@ -845,6 +846,31 @@ double common_average(const std::vector<double>& v) return total / std::distance(v.begin() + n, v.end() - n); } +double mean(const std::vector<double>& v) +{ + double total = std::accumulate(v.begin(), v.end(), 0.0); + return total / v.size(); +} + +double median(const std::vector<double>& v) +{ + size_t mid = v.size() / 2; + if(v.size() % 2 == 0) + { + return (v[mid - 1] + v[mid]) / 2.0; + } + else + { + return v[mid]; + } +} + +double percentile(const std::vector<double>& v, double percentile) +{ + size_t index = (percentile * (v.size() - 1)); + return v[index]; +} + std::string perf_group(instruction_ref ins, bool detailed) { std::string result; @@ -925,8 +951,14 @@ void program::perf_report( { overhead_vec.push_back(time([&] { dry_run(params); })); } - double total_time = common_average(total_vec); + double min_time = total_vec.front(); + double max_time = total_vec.back(); + double mean_time = mean(total_vec); + double median_time = median(total_vec); + double percentile_90_time = percentile(total_vec, 0.90); + double percentile_95_time = percentile(total_vec, 0.95); + double percentile_99_time = percentile(total_vec, 0.99); double rate = 1000.0 / total_time; double overhead_time = common_average(overhead_vec); double overhead_percent = overhead_time * 100.0 / total_time; @@ -978,7 +1010,14 @@ void program::perf_report( os << "Batch size: " << batch << std::endl; os << "Rate: " << rate * batch << " inferences/sec" << std::endl; - os << "Total time: " << total_time << "ms" << std::endl; + os << "Total time: " << total_time << "ms "; + os << "(Min: " << min_time << "ms, "; + os << "Max: " << max_time << "ms, "; + os << "Mean: " << mean_time << "ms, "; + os << "Median: " << median_time << "ms)" << std::endl; + os << "Percentiles (90%, 95%, 99%): ("; + os << percentile_90_time << "ms, " <<
percentile_95_time << "ms, " << percentile_99_time + << "ms)" << std::endl; os << "Total instructions time: " << total_instruction_time << "ms" << std::endl; os << "Overhead time: " << overhead_time << "ms" << ", " << calculate_overhead_time << "ms" << std::endl; @@ -1005,7 +1044,6 @@ void program::debug_print(instruction_ref ins) const return; } - std::stringstream ss; this->print(names, [&](auto x, auto ins_names) { if(x == ins) { diff --git a/src/py/backend/backend.py b/src/py/backend/backend.py index 4ad6cc2305d..e2b32a5c984 100755 --- a/src/py/backend/backend.py +++ b/src/py/backend/backend.py @@ -110,7 +110,8 @@ def prepare(cls, model, device=None, **kwargs): "Incompatible device expected '{0}', got '{1}'".format( device, get_device())) inf = migraphx.parse_onnx_buffer(model) - cls._prog_string = str("\nProgram =\n{}".format(inf)) + cls._prog_string = str("\nPython =\n{}\nProgram =\n{}".format( + inf.to_py(), inf)) device = cls._device cls._input_names = inf.get_parameter_names() inf.compile(migraphx.get_target(device.lower())) diff --git a/src/py/migraphx_py.cpp b/src/py/migraphx_py.cpp index 04daa5e35a3..75f7fab09d9 100644 --- a/src/py/migraphx_py.cpp +++ b/src/py/migraphx_py.cpp @@ -48,7 +48,7 @@ #include #endif -using half = half_float::half; +using half = migraphx::half; namespace py = pybind11; #ifdef __clang__ @@ -158,6 +158,17 @@ struct npy_format_descriptor static constexpr auto name() { return _("fp8e4m3fnuz"); } }; +template <> +struct npy_format_descriptor +{ + static std::string format() + { + // TODO: no standard format in numpy for fp8 + return "z"; + } + static constexpr auto name() { return _("fp8e5m2fnuz"); } +}; + template <> struct npy_format_descriptor { @@ -180,6 +191,17 @@ struct npy_format_descriptor static constexpr auto name() { return _("fp8e5m2"); } }; +template <> +struct npy_format_descriptor +{ + static std::string format() + { + // TODO: no standard format in numpy for bf16 + return "z"; + } + static constexpr auto name() { return _("bf16"); } +}; + } // namespace detail } // namespace pybind11 @@ -241,6 +263,13 @@ migraphx::shape to_shape(const py::buffer_info& info) { migraphx::shape::type_t t; std::size_t n = 0; + // Unsupported pybuffer types lead to undefined behaviour when comparing with migraphx type enum + if(info.format == "z") + { + MIGRAPHX_THROW( + "MIGRAPHX PYTHON: Unsupported data type. 
For fp8 and bf16 literals try using " + "migraphx.generate_argument with migraphx.add_literal"); + } visit_types([&](auto as) { if(info.format == py::format_descriptor::format() or (info.format == "l" and py::format_descriptor::format() == "q") or @@ -366,6 +395,12 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) py::arg("op"), py::arg("args"), py::arg("mod_args") = std::vector{}) + .def( + "add_literal", + [](migraphx::module& mm, migraphx::argument a) { + return mm.add_literal(a.get_shape(), a.data()); + }, + py::arg("data")) .def( "add_literal", [](migraphx::module& mm, py::buffer data) { @@ -446,6 +481,12 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) migraphx::any_ptr(reinterpret_cast(stream), stream_name), true}; return p.eval(pm, exec_env); }) + .def("to_py", + [](const migraphx::program& p) { + std::stringstream ss; + p.print_py(ss); + return ss.str(); + }) .def("sort", &migraphx::program::sort) .def("print", [](const migraphx::program& p) { std::cout << p << std::endl; }) .def("__eq__", std::equal_to{}) @@ -623,6 +664,10 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) }, "Auto-convert FP8 parameters and return values to Float for MIGraphX Program", py::arg("prog")); + m.def("quantize_bf16", + &migraphx::quantize_bf16, + py::arg("prog"), + py::arg("ins_names") = std::vector<std::string>{"all"}); #ifdef HAVE_GPU m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false); diff --git a/src/quantization.cpp b/src/quantization.cpp index a9b47d1d503..276012bbf73 100644 --- a/src/quantization.cpp +++ b/src/quantization.cpp @@ -24,7 +24,7 @@ #include #include #include -#include +#include #include #include #include @@ -69,7 +69,17 @@ void quantize_fp16(program& prog, const std::vector<std::string>& ins_names) run_passes(prog, {normalize_ops{}, optimize_module{{"quantizelinear", "dequantizelinear"}}, - quantize_fp16_pass{ins_names}, + truncate_float_pass{ins_names, shape::half_type}, + optimize_module{{"quantizelinear", "dequantizelinear"}}}, + quant_tracer()); +} + +void quantize_bf16(program& prog, const std::vector<std::string>& ins_names) +{ + run_passes(prog, + {normalize_ops{}, + optimize_module{{"quantizelinear", "dequantizelinear"}}, + truncate_float_pass{ins_names, shape::bf16_type}, optimize_module{{"quantizelinear", "dequantizelinear"}}}, quant_tracer()); } diff --git a/src/shape_transform_descriptor.cpp b/src/shape_transform_descriptor.cpp index 51802ef0fae..987cc891573 100644 --- a/src/shape_transform_descriptor.cpp +++ b/src/shape_transform_descriptor.cpp @@ -28,10 +28,12 @@ #include #include #include +#include #include #include #include #include +#include #include namespace migraphx { @@ -84,6 +86,41 @@ static std::vector get_all_subdimensions(const std::vector +template <class Dimensions, class Range, class F> +static void for_each_subdimension(Dimensions&& dimensions, Range&& r, F f) +{ + auto start = r.begin(); + auto last = r.end(); + for(auto& dim : dimensions) + { + for(auto& s : dim.subdimensions) + { + if(start == last) + return; + f(s, *start); + start++; + } + } +} + +// Group all axes into a map keyed by the axis, where the value is a vector of +// all subdimensions that have that axis.
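// ---- Editor's illustration (not part of the patch) ----
// A standalone sketch, with simplified stand-in types, of the grouping that
// group_axes() below performs: each subdimension is bucketed under the first
// entry of its origin axis, and broadcast subdimensions (empty axis) are
// skipped. For a split such as {3 from axis 0, 4 and 5 both from axis 1,
// 2 broadcast}, axis 1 ends up with two entries.
#include <cstddef>
#include <iostream>
#include <map>
#include <vector>

struct sub_sketch
{
    std::size_t len;
    std::vector<std::size_t> axis; // front() is the originating axis; empty means broadcast
};

std::map<std::size_t, std::vector<sub_sketch*>> group_axes_sketch(std::vector<sub_sketch>& subs)
{
    std::map<std::size_t, std::vector<sub_sketch*>> axes_map;
    for(auto& s : subs)
    {
        if(s.axis.empty())
            continue; // no origin axis to group under
        axes_map[s.axis.front()].push_back(&s);
    }
    return axes_map;
}

int main()
{
    std::vector<sub_sketch> subs = {{3, {0}}, {4, {1, 0}}, {5, {1, 1}}, {2, {}}};
    for(const auto& [axis, group] : group_axes_sketch(subs))
        std::cout << "axis " << axis << ": " << group.size() << " subdimension(s)\n";
    // prints: axis 0: 1 subdimension(s)
    //         axis 1: 2 subdimension(s)
}
// ---- End of illustration ----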
+static std::map<std::size_t, std::vector<dimension::sub*>> +group_axes(std::vector<dimension>& dimensions) +{ + std::map<std::size_t, std::vector<dimension::sub*>> axes_map; + for(auto& d : dimensions) + { + for(auto& s : d.subdimensions) + { + if(s.origin_axis().empty()) + continue; + axes_map[s.origin_axis().front()].push_back(&s); + } + } + return axes_map; +} + std::vector<std::size_t> compute_dims(const operation& op, const std::vector<std::size_t>& idims) { shape s{shape::float_type, idims}; @@ -99,6 +136,66 @@ std::vector<std::size_t> compute_dims(const std::vector<operation>& ops, return s.lens(); } +shape_transform_descriptor shape_transform_descriptor::create(const std::vector<std::size_t>& dims, + const std::vector<operation>& ops) +{ + shape_transform_descriptor result{dims}; + if(not result.apply(ops)) + return {}; + result.simplify(); + assert(compute_dims(ops, dims) == compute_dims(result.generate(), dims)); + return result; +} + +shape_transform_descriptor +shape_transform_descriptor::rebase(const std::vector<std::size_t>& dims) const +{ + auto result = *this; + auto axes_map = group_axes(result.dimensions); + for(auto& [axis, subs] : axes_map) + { + assert(axis < dims.size()); + auto dim = dims[axis]; + auto final_dim = transform_accumulate(subs.begin(), + subs.end(), + std::size_t{1}, + std::multiplies<>{}, + [](const dimension::sub* s) { return s->len; }); + if(dim == final_dim) + { + for(auto* sub : subs) + sub->expose(); + } + else if(dim == 1) + { + for(auto* sub : subs) + { + if(not sub->has_hidden_axis()) + sub->len = 1; + } + } + else if(subs.size() == 1) + { + subs.front()->len = dim; + subs.front()->expose(); + } + else + MIGRAPHX_THROW("Invalid rebase"); + } + result.simplify(); + + return result; +} +static dimension::sub* get_last_subdimension(std::vector<dimension>& dims) +{ + if(dims.empty()) + return {}; + auto& d = dims.back(); + if(d.subdimensions.empty()) + return nullptr; + return &d.subdimensions.back(); +} + bool shape_transform_descriptor::apply(const std::vector<operation>& ops) { std::vector<std::size_t> dims; @@ -196,8 +293,7 @@ bool shape_transform_descriptor::apply_reshape_impl(const std::vector dimension { auto new_sub = sub; - if(not new_sub.axis.empty()) - new_sub.axis.push_back(j); + new_sub.add_split_axis(j); new_sub.len = start[j]; return {{new_sub}}; }); @@ -209,12 +305,20 @@ bool shape_transform_descriptor::apply_reshape_impl(const std::vector{} : sub->axis; + auto trailing_dims = range(rdims.begin() + new_dims.size(), rdims.end()); + if(any_of(trailing_dims, [](auto d) { return d != 1; })) + return false; + if(distance(trailing_dims) > 1) + sub->add_split_axis(0); + transform(range(distance(trailing_dims)), + std::back_inserter(new_dims), + [&](std::size_t j) -> dimension { + dimension::sub s{1, axis}; + s.add_split_axis(j + 1); + return {{s}}; + }); } assert(rdims.size() == new_dims.size()); if(rdims.size() != new_dims.size()) @@ -252,7 +356,16 @@ bool shape_transform_descriptor::apply_broadcast(const std::vector<std::size_t>& return dim; if(dim.len() != 1) MIGRAPHX_THROW("Wrong out_lens for broadcast"); - return {{dimension::sub{len, {}}}}; + auto new_subs = dim.subdimensions; + if(not new_subs.empty()) + { + new_subs.front().len = len; + } + for(auto& s : new_subs) + { + s.hide(); + } + return {new_subs}; }); std::transform(out_lens.begin() + offset + dimensions.size(), out_lens.end(), @@ -281,14 +394,19 @@ void dimension::simplify() remove_1_sub_dims(subdimensions); // Flatten adjacent dimensions adjacent_for_each(subdimensions.begin(), subdimensions.end(), [&](sub& d1, sub& d2) { - if(d1.axis.size() < 2) + if(d1.origin_axis().size() < 2) + return; + if(d2.origin_axis().size() < 2) return; - if(d2.axis.size() < 2) + if(d1.has_hidden_axis() !=
d2.has_hidden_axis()) return; - if(not std::equal(d1.axis.begin(), d1.axis.end() - 1, d2.axis.begin(), d2.axis.end() - 1)) + if(not std::equal(d1.origin_axis().begin(), + d1.origin_axis().end() - 1, + d2.origin_axis().begin(), + d2.origin_axis().end() - 1)) return; - auto a1 = d1.axis.back(); - auto a2 = d2.axis.back(); + auto a1 = d1.origin_axis().back(); + auto a2 = d2.origin_axis().back(); assert(a2 != a1); if(a2 <= a1) return; @@ -347,7 +465,7 @@ static bool missing_leading_axis(const dimension& d) if(d.subdimensions.empty()) return true; const auto& sub = d.subdimensions.front(); - return sub.axis.empty(); + return sub.origin_axis().empty(); } static void set_broadcast_dim(dimension& d, std::size_t axis) @@ -355,25 +473,73 @@ static void set_broadcast_dim(dimension& d, std::size_t axis) if(d.subdimensions.empty()) d.subdimensions.push_back({1, {axis}}); else - d.subdimensions.front().hidden_axis = axis; + { + assert(d.subdimensions.front().hidden_axis.empty()); + d.subdimensions.front().hidden_axis = {axis}; + } } -// Group all axes into a map with a key of the axis and the value is vector of -// all subdimensions that have that axis. -static std::map> -group_axes(std::vector& dimensions) +static void set_origin_axis(dimension::sub& s, const std::vector<std::size_t>& axis) { - std::map> axes_map; - for(auto& d : dimensions) + if(s.has_hidden_axis()) + s.hidden_axis = axis; + else + s.axis = axis; +} + +// If an axis is split and some dimensions are hidden and others are not, then +// remove the hidden axis so only the non-hidden axis is used in +// simplification +static void remove_split_hidden_axes(std::map<std::size_t, std::vector<dimension::sub*>>& axes_map) +{ + for(auto&& p : axes_map) { - for(auto& s : d.subdimensions) + auto& subs = p.second; + if(std::all_of(subs.begin(), subs.end(), [](const dimension::sub* s) { + return s->has_hidden_axis(); + })) + continue; + for(auto* sub : subs) { - if(s.axis.empty()) + if(not sub->has_hidden_axis()) continue; - axes_map[s.axis.front()].push_back(&s); + sub->hidden_axis.clear(); } + // Remove the subdimensions that no longer have an axis + subs.erase(std::remove_if(subs.begin(), + subs.end(), + [](const dimension::sub* s) { + return s->axis.empty() and s->hidden_axis.empty(); + }), + subs.end()); + } + // Remove axis from group if empty + erase_if(axes_map, [](auto&& p) { return p.second.empty(); }); +} + +// If this is scalar, then remove all axes +static void remove_scalar_axis(std::vector<dimension>& dimensions) +{ + dimension::sub* s = nullptr; + for(auto& d : dimensions) + { + auto has_axis = [](const dimension::sub& x) { return not x.origin_axis().empty(); }; + auto it = std::find_if(d.subdimensions.begin(), d.subdimensions.end(), has_axis); + if(it == d.subdimensions.end()) + continue; + if(s != nullptr) + return; + if(std::count_if(std::next(it), d.subdimensions.end(), has_axis) > 0) + return; + s = &*it; + } + if(s != nullptr) + { + if(s->has_hidden_axis()) + s->hidden_axis.clear(); + if(s->len == 1) + s->axis.clear(); } - return axes_map; } // Renumber all axes while preserving the order of the axes @@ -385,15 +551,15 @@ static void renumber_axes(std::map>& a auto& subs = p.second; if(subs.size() == 1) { - subs[0]->axis = {axis}; + set_origin_axis(*subs[0], {axis}); } else { std::sort(subs.begin(), subs.end(), by(std::less<>{}, [](const dimension::sub* s) { - return s->axis; + return s->origin_axis(); })); for(std::size_t i : range(subs.size())) - subs[i]->axis = {axis, i}; + set_origin_axis(*subs[i], {axis, i}); } } } @@ -437,6 +603,8 @@ void shape_transform_descriptor::simplify() for(auto&
d : dimensions) d.simplify(); + remove_scalar_axis(dimensions); + std::map missing_axes; std::vector last_axis; { @@ -445,6 +613,7 @@ void shape_transform_descriptor::simplify() if(axes_map.empty()) return; + remove_split_hidden_axes(axes_map); renumber_axes(axes_map); // Find last axis @@ -471,8 +640,8 @@ void shape_transform_descriptor::simplify() { assert(not last->subdimensions.empty()); const auto& sub = last->subdimensions.front(); - assert(not sub.axis.empty()); - axis = sub.axis.front(); + assert(not sub.origin_axis().empty()); + axis = sub.origin_axis().front(); } std::deque dims(std::distance(start, last)); std::iota(dims.begin(), dims.end(), std::distance(dimensions.begin(), start)); @@ -518,18 +687,18 @@ void shape_transform_descriptor::simplify() // Search for the subdimension that has the next axis and try to // insert the axis before it will be in order. auto [sub, it, prev] = find_subdimension(*this, [&](const dimension::sub& s) { - if(s.axis.empty()) + if(s.origin_axis().empty()) return false; - if(s.axis.front() != next_axis) + if(s.origin_axis().front() != next_axis) return false; - if(s.axis.size() == 1) + if(s.origin_axis().size() == 1) return true; - assert(s.axis.size() == 2); - return s.axis.back() == 0; + assert(s.origin_axis().size() == 2); + return s.origin_axis().back() == 0; }); bool in_order = false; - if(prev.has_value() and not(*prev)->axis.empty()) - in_order = (*prev)->axis.front() == missing_axis - 1; + if(prev.has_value() and not(*prev)->origin_axis().empty()) + in_order = (*prev)->origin_axis().front() == missing_axis - 1; // If the axis is not inorder then see if we can find a broadcast axis to place it auto bdims = in_order ? broadcast_dims_map.end() : broadcast_dims_map.upper_bound(missing_axis); @@ -549,6 +718,22 @@ void shape_transform_descriptor::simplify() collapse_1_dims(dimensions); } +static std::size_t get_len(const dimension::sub& s, const std::vector& input_dims) +{ + if(input_dims.empty()) + return s.len; + if(s.axis.empty()) + return s.len; + auto dim = input_dims.at(s.axis.front()); + if(dim == 0) + return s.len; + if(dim == 1) + return 1; + if(s.axis.size() == 1) + return dim; + return s.len; +} + static operation make_reshape_squeeze(const std::vector& new_dims) { // Can use squeeze @@ -611,17 +796,15 @@ static void flatten_broadcasted_dim(dimension::sub& s) if(s.axis.empty()) { s.len = 1; - if(s.hidden_axis.has_value()) - { - s.axis = {s.hidden_axis.value()}; - s.hidden_axis = nullopt; - } + s.expose(); } } -static operation make_reshape_unsqueeze(const std::vector& subs) +static operation make_reshape_unsqueeze(const std::vector& subs, + const std::vector& input_dims = {}) { bool use_reshape = false; + std::unordered_set all_1s; // Check if split dimensions are all additional 1s if(std::any_of( subs.begin(), subs.end(), [](const dimension::sub& s) { return s.axis.size() > 1; })) @@ -643,8 +826,11 @@ static operation make_reshape_unsqueeze(const std::vector& subs) if(n < 2) return; // Number of elements that are 1 - auto n1 = - std::count_if(start, last, [](const dimension::sub& s) { return s.len == 1; }); + auto n1 = std::count_if(start, last, [&](const dimension::sub& s) { + return get_len(s, input_dims) == 1; + }); + if(n == n1 and not start->axis.empty()) + all_1s.insert(start->axis.front()); use_reshape |= std::max(0, n - n1 - 1) > 0; }, by_axis); @@ -655,10 +841,10 @@ static operation make_reshape_unsqueeze(const std::vector& subs) std::transform(subs.begin(), subs.end(), std::back_inserter(dims), - [](const dimension::sub& s) 
-> std::size_t { + [&](const dimension::sub& s) -> std::size_t { if(s.axis.empty()) return 1; - return s.len; + return get_len(s, input_dims); }); return make_op("reshape", {{"dims", dims}}); } @@ -670,7 +856,9 @@ static operation make_reshape_unsqueeze(const std::vector<dimension::sub>& subs) const auto& sub = subs[i]; if(sub.axis.size() == 1) continue; - if(sub.len != 1 and not sub.axis.empty()) + if(get_len(sub, input_dims) != 1 and not sub.axis.empty()) + continue; + if(not sub.axis.empty() and contains(all_1s, sub.axis.front()) and sub.axis.back() == 0) continue; axes.push_back(i); } @@ -678,10 +866,26 @@ static operation make_reshape_unsqueeze(const std::vector<dimension::sub>& subs) } } +namespace { +struct operation_list +{ + std::vector<operation> ops; + + void push_back(const operation& op) { ops.push_back(op); } + + std::vector<operation> to_vector() && + { + std::reverse(ops.begin(), ops.end()); + return std::move(ops); + } +}; + +} // namespace + static bool has_no_axes(const dimension& d) { return std::all_of(d.subdimensions.begin(), d.subdimensions.end(), [](const dimension::sub& s) { - return s.axis.empty() and not s.hidden_axis.has_value(); + return s.axis.empty() and s.hidden_axis.empty(); }); } static bool has_axes(const dimension& d) @@ -691,6 +895,59 @@ static bool has_axes(const dimension& d) }); } +static void generate_from_subdimensions(operation_list& result, + std::vector<dimension::sub> subs, + const std::vector<std::size_t>& input_dims = {}) +{ + // Need multibroadcast + if(std::any_of(subs.begin(), subs.end(), [&](const dimension::sub& s) { + return s.axis.empty() and get_len(s, input_dims) != 1; + })) + { + std::vector<std::size_t> out_lens; + std::transform(subs.begin(), + subs.end(), + std::back_inserter(out_lens), + [&](const dimension::sub& s) { return get_len(s, input_dims); }); + result.push_back(make_op("multibroadcast", {{"out_lens", out_lens}})); + } + + // Flatten broadcasted subdimensions + std::for_each(subs.begin(), subs.end(), &flatten_broadcasted_dim); + + auto tsubs = subs; + // Inject additional axis to compute transpose permutation better + auto is_empty_axis = [](const auto& s) { return s.axis.empty(); }; + group_find(tsubs.begin(), tsubs.end(), is_empty_axis, [&](auto start, auto last) { + if(start == tsubs.begin()) + return; + auto base = std::prev(start); + auto axis = base->axis; + axis.push_back(0); + std::for_each(start, last, [&](auto& s) { + s.axis = axis; + axis.back()++; + }); + }); + + auto compare_sub = [](auto f) { + return by(f, [](const dimension::sub& s) -> const auto& { return s.axis; }); + }; + // Need transpose + if(not std::is_sorted(tsubs.begin(), tsubs.end(), compare_sub(std::less<>{}))) + { + auto permutation = sort_permutation(tsubs, compare_sub(std::less<>{})); + result.push_back(make_op("transpose", {{"permutation", invert_permutation(permutation)}})); + subs = reorder_dims(subs, permutation); + } + // Need reshape unsqueeze + if(std::any_of( + subs.begin(), subs.end(), [](const dimension::sub& s) { return s.axis.size() != 1; })) + { + result.push_back(make_reshape_unsqueeze(subs, input_dims)); + } +} + // This will generate the operators to apply the shape transformation that is // represented by this class. This is the order of operators that will be // generated if needed: @@ -706,7 +963,7 @@ static bool has_axes(const dimension& d) // dimensions.
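// ---- Editor's illustration (not part of the patch) ----
// A hedged usage sketch of the create()/generate() round trip added in this
// patch, mirroring what optimize_shape_transforms() does further below: model
// an input shape plus a chain of shape-transform operators, then emit a
// simplified equivalent chain. Header paths are assumed from the files
// touched in this diff.
#include <migraphx/operation.hpp>
#include <migraphx/shape_transform_descriptor.hpp>
#include <vector>

std::vector<migraphx::operation>
simplify_transforms_sketch(const std::vector<std::size_t>& dims,
                           const std::vector<migraphx::operation>& ops)
{
    // create() applies the ops and simplifies; it returns an empty descriptor
    // when the op chain cannot be modeled
    auto sd = migraphx::shape_transform_descriptor::create(dims, ops);
    if(sd.empty())
        return ops;       // fall back to the original operators
    return sd.generate(); // a minimal equivalent operator chain
}
// ---- End of illustration ----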
std::vector shape_transform_descriptor::generate() const { - std::vector result; + operation_list result; std::vector new_dims = dimensions; // Need broadcast if(std::any_of(new_dims.begin(), new_dims.end(), &is_broadcast_dim)) @@ -755,54 +1012,136 @@ std::vector shape_transform_descriptor::generate() const } auto subs = get_all_subdimensions(new_dims); - // Need multibroadcast - if(std::any_of(subs.begin(), subs.end(), [](const dimension::sub& s) { - return s.axis.empty() and s.len != 1; + generate_from_subdimensions(result, subs); + return std::move(result).to_vector(); +} + +bool shape_transform_descriptor::has_broadcast() const +{ + return std::any_of(dimensions.begin(), dimensions.end(), [&](const dimension& d) { + return std::any_of(d.subdimensions.begin(), + d.subdimensions.end(), + [&](const dimension::sub& s) { return s.axis.empty() and s.len != 1; }); + }); +} +void shape_transform_descriptor::flatten_broadcast() +{ + for(auto& d : dimensions) + std::for_each(d.subdimensions.begin(), d.subdimensions.end(), &flatten_broadcasted_dim); +} + +std::vector shape_transform_descriptor::generate_common_from_src( + const std::vector& input_dims) const +{ + operation_list result; + auto subs = get_all_subdimensions(dimensions); + generate_from_subdimensions(result, subs, input_dims); + return std::move(result).to_vector(); +} +std::vector shape_transform_descriptor::generate_common_from_dst( + const std::vector& input_dims) const +{ + // Need reshape + if(std::all_of(dimensions.begin(), dimensions.end(), [](const dimension& d) { + return d.subdimensions.size() == 1; })) + return {}; + std::vector subs; + // Update axes to point to the destination + for(std::size_t i : range(dimensions.size())) { - std::vector out_lens; - std::transform(subs.begin(), - subs.end(), - std::back_inserter(out_lens), - [](const dimension::sub& s) { return s.len; }); - result.push_back(make_op("multibroadcast", {{"out_lens", out_lens}})); + const auto& d = dimensions[i]; + std::transform(d.subdimensions.begin(), + d.subdimensions.end(), + range(d.subdimensions.size()).begin(), + std::back_inserter(subs), + [&](dimension::sub s, auto j) { + s.axis = {i}; + if(d.subdimensions.size() > 1) + s.axis.push_back(j); + return s; + }); } + return {make_reshape_unsqueeze(subs, input_dims)}; +} +std::vector shape_transform_descriptor::generate_dst_from_common( + const std::vector& input_dims) const +{ + std::vector result; + std::vector new_dims = dimensions; + for_each_subdimension(new_dims, input_dims, [&](auto& s, auto dim) { s.len = dim; }); - // Flatten broadcasted subdimensions - std::for_each(subs.begin(), subs.end(), &flatten_broadcasted_dim); - - auto tsubs = subs; - // Inject additonal axis to compute transpose permutation better - auto is_empty_axis = [](const auto& s) { return s.axis.empty(); }; - group_find(tsubs.begin(), tsubs.end(), is_empty_axis, [&](auto start, auto last) { - if(start == tsubs.begin()) - return; - auto base = std::prev(start); - auto axis = base->axis; - axis.push_back(0); - std::for_each(start, last, [&](auto& s) { - s.axis = axis; - axis.back()++; - }); - }); + // Remove broadcasted dimensions + for(auto& d : new_dims) + { + if(d.subdimensions.size() != 1) + continue; + auto& s = d.subdimensions.front(); + s.expose(); + } + // Need squeeze reshape + if(std::any_of(new_dims.begin(), new_dims.end(), [](const dimension& d) { + if(d.subdimensions.size() != 1) + return true; + return is_broadcast_dim(d); + })) + { + result.push_back(make_reshape_squeeze(new_dims)); + } + return result; +} - auto 
compare_sub = [](auto f) { - return by(f, [](const dimension::sub& s) -> const auto& { return s.axis; }); - }; - // Need transpose - if(not std::is_sorted(tsubs.begin(), tsubs.end(), compare_sub(std::less<>{}))) +std::vector> shape_transform_descriptor::common_axes_map_from_src() const +{ + std::vector> result; + auto subs = get_all_subdimensions(dimensions); + std::map> axes_map; + for(const auto& s : subs) { - auto permutation = sort_permutation(tsubs, compare_sub(std::less<>{})); - result.push_back(make_op("transpose", {{"permutation", invert_permutation(permutation)}})); - subs = reorder_dims(subs, permutation); + if(not s.origin_axis().empty()) + axes_map[s.origin_axis().front()].push_back(&s); } - // Need reshape unsqueeze - if(std::any_of( - subs.begin(), subs.end(), [](const dimension::sub& s) { return s.axis.size() != 1; })) + for(auto&& p : axes_map) + { + std::sort(p.second.begin(), p.second.end(), by(std::less<>{}, [](const dimension::sub* s) { + return s->axis; + })); + } + assert(not axes_map.empty()); + auto max_axis = std::prev(axes_map.end())->first; + result.resize(max_axis + 1); + for(auto&& p : axes_map) + { + assert(p.first < result.size()); + std::transform(p.second.begin(), + p.second.end(), + std::back_inserter(result[p.first]), + [&](const dimension::sub* s) { return s - subs.data(); }); + } + return result; +} +std::vector> shape_transform_descriptor::common_axes_map_from_dst() const +{ + std::vector> result; + std::size_t start = 0; + for(const auto& d : dimensions) { - result.push_back(make_reshape_unsqueeze(subs)); + auto& v = result.emplace_back(d.subdimensions.size()); + std::iota(v.begin(), v.end(), start); + start += d.subdimensions.size(); } - std::reverse(result.begin(), result.end()); + return result; +} + +bool shape_transform_descriptor::empty() const { return dimensions.empty(); } + +std::vector shape_transform_descriptor::lens() const +{ + std::vector result; + std::transform(dimensions.begin(), + dimensions.end(), + std::back_inserter(result), + [](const dimension& d) { return d.len(); }); return result; } @@ -823,6 +1162,54 @@ std::size_t shape_transform_descriptor::elements() const std::multiplies<>{}, [](const auto& s) { return s.len(); }); } +std::vector +shape_transform_descriptor::common_dims(const std::vector& input_dims) const +{ + std::vector result; + for(const auto& d : dimensions) + { + std::transform(d.subdimensions.begin(), + d.subdimensions.end(), + std::back_inserter(result), + [&](const dimension::sub& s) { return get_len(s, input_dims); }); + } + return result; +} + +const std::vector& shape_transform_descriptor::dimension::sub::origin_axis() const +{ + return axis.empty() ? 
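
// Worked example for common_axes_map_from_src above (made-up shapes): if a
// {6, 4} input is split into the common shape {2, 3, 4}, source axis 0 owns
// common subdimensions 0 and 1 while source axis 1 owns subdimension 2, so
// the returned map is {{0, 1}, {2}}. common_axes_map_from_dst simply numbers
// each destination dimension's subdimensions consecutively.
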
hidden_axis : axis; +} +bool shape_transform_descriptor::dimension::sub::has_hidden_axis() const +{ + return axis.empty() and not hidden_axis.empty(); +} + +void shape_transform_descriptor::dimension::sub::add_split_axis(std::size_t i) +{ + if(not axis.empty()) + axis.push_back(i); + if(not hidden_axis.empty()) + hidden_axis.push_back(i); +} + +void shape_transform_descriptor::dimension::sub::expose() +{ + if(has_hidden_axis()) + { + axis = hidden_axis; + hidden_axis.clear(); + } +} + +void shape_transform_descriptor::dimension::sub::hide() +{ + if(not has_hidden_axis()) + { + hidden_axis = axis; + axis.clear(); + } +} bool operator==(const dimension::sub& x, const dimension::sub& y) { @@ -833,8 +1220,8 @@ bool operator!=(const dimension::sub& x, const dimension::sub& y) { return not(x std::ostream& operator<<(std::ostream& os, const dimension::sub& x) { os << x.len << ":" << to_string_range(x.axis, "x"); - if(x.hidden_axis.has_value()) - os << "$" << x.hidden_axis.value(); + if(not x.hidden_axis.empty()) + os << "$" << to_string_range(x.hidden_axis, "x"); return os; } bool operator==(const dimension& x, const dimension& y) @@ -866,13 +1253,10 @@ std::ostream& operator<<(std::ostream& os, const shape_transform_descriptor& x) std::vector optimize_shape_transforms(const std::vector& dims, const std::vector& ops) { - shape_transform_descriptor sd{dims}; - if(not sd.apply(ops)) + auto sd = shape_transform_descriptor::create(dims, ops); + if(sd.empty()) return ops; - sd.simplify(); - auto result = sd.generate(); - assert(compute_dims(ops, dims) == compute_dims(result, dims)); - return result; + return sd.generate(); } } // namespace MIGRAPHX_INLINE_NS diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp index a4fc6f6c041..6216900c98e 100644 --- a/src/simplify_algebra.cpp +++ b/src/simplify_algebra.cpp @@ -57,6 +57,22 @@ auto conv_const_weights() match::args(match::none_of(match::is_constant()), match::is_constant().bind("w"))); } +auto from_int4() +{ + return match::make_predicate_matcher([](instruction_ref start) { + return fix([&](auto self, instruction_ref ins) { + auto alias = instruction::get_output_alias(ins); + if(contains({"reshape", "dequantizelinear"}, alias->name())) + return self(alias->inputs().front()); + if(alias->name() == "concat") + return all_of(alias->inputs(), self); + return alias->name() == "unpack_int4"; + })(start); + }); +} + +auto not_from_int4() { return match::none_of(from_int4()); } + auto reduction() { return match::name_contains("reduce"); } // conv(x, w) * a => conv(x, a * w) @@ -208,8 +224,8 @@ struct find_mul_dot { auto matcher() const { - auto is_dot_const_inputs = - match::name("dot")(match::any_of[match::inputs()](match::is_constant())); + auto constant = match::is_constant(not_from_int4()); + auto is_dot_const_inputs = match::name("dot")(match::any_of[match::inputs()](constant)); return match::name("mul")(match::either_arg(0, 1)( is_dot_const_inputs.bind("dot"), match::name("broadcast", "multibroadcast").bind("c"))); } @@ -358,7 +374,8 @@ struct find_dot_mul match::used_once(), match::either_arg(0, 1)(const_broadcast.bind("d"), match::none_of(match::is_constant()).bind("z"))); - return match::name("dot")(match::either_arg(0, 1)(mul, match::is_constant().bind("c"))); + return match::name("dot")( + match::either_arg(0, 1)(mul, match::is_constant(not_from_int4()).bind("c"))); } void apply(module& m, const match::matcher_result& r) const @@ -915,7 +932,8 @@ struct find_concat_op auto matcher() const { return 
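
// from_int4 above recurses through output aliases with MIGraphX's fix
// combinator. A minimal standalone illustration of the idiom (a generic
// sketch, not the MIGraphX helper itself):
#include <utility>

template <class F>
auto fix(F f)
{
    // y-combinator style: hand the lambda a handle to itself as `self`.
    return [=](auto&&... xs) { return f(fix(f), std::forward<decltype(xs)>(xs)...); };
}

struct node
{
    node* next = nullptr;
};

// Walk a linked structure recursively, as from_int4 walks output aliases.
inline int depth(node* n)
{
    return fix([](auto self, node* p) -> int { return p == nullptr ? 0 : 1 + self(p->next); })(n);
}
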
match::name("concat")(match::any_of[match::inputs()]( - match::any_of(match::pointwise(), match::name("broadcast", "multibroadcast")), + match::any_of(match::pointwise(), + match::name("broadcast", "multibroadcast", "unpack_int4")), match::used_once())); } @@ -935,7 +953,7 @@ struct find_concat_op static bool is_valid_op(const operation& op) { - return contains({"broadcast", "multibroadcast"}, op.name()) or + return contains({"broadcast", "multibroadcast", "unpack_int4"}, op.name()) or op.attributes().contains("pointwise"); } @@ -951,6 +969,17 @@ struct find_concat_op }); } + static bool rejected_inputs(const std::vector& inputs) + { + if(inputs.empty()) + return true; + if(inputs.size() < 3) + return false; + auto nonconst = std::count_if( + inputs.begin(), inputs.end(), [](instruction_ref ins) { return not ins->can_eval(); }); + return nonconst > 2; + } + void apply(module& m, const match::matcher_result& r) const { auto ins = r.result; @@ -960,7 +989,7 @@ struct find_concat_op if(std::distance(start, last) < 2) return {start, last}; auto x = *start; - if(x->inputs().size() > 2 or x->inputs().empty() or x->outputs().size() > 1) + if(x->outputs().size() > 1 or rejected_inputs(x->inputs())) return {start, last}; auto op = x->get_operator(); if(not is_valid_op(op)) diff --git a/src/simplify_qdq.cpp b/src/simplify_qdq.cpp index 3fb781a0f8d..86c2100a995 100644 --- a/src/simplify_qdq.cpp +++ b/src/simplify_qdq.cpp @@ -35,6 +35,7 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -143,10 +144,8 @@ struct match_find_quantizable_ops auto zp1 = r.instructions["zp1"]; auto zp2 = r.instructions["zp2"]; // Only INT8 or FP8 type currently supported - std::set supported_types = {migraphx::shape::fp8e4m3fnuz_type, - migraphx::shape::fp8e4m3fn_type, - migraphx::shape::fp8e5m2_type, - migraphx::shape::int8_type}; + std::set supported_types = fp8_types{}.get(); + supported_types.insert(migraphx::shape::int8_type); if(not contains(supported_types, dq1->inputs().front()->get_shape().type()) or not contains(supported_types, dq2->inputs().front()->get_shape().type())) return; @@ -416,6 +415,28 @@ void remove_qdq_pairs(module& m) } } +void remove_zero_point(module& m) +{ + for(auto ins : iterator_for(m)) + { + if(ins->name() != "dequantizelinear") + continue; + if(ins->inputs().size() != 3) + continue; + auto zp = ins->inputs().at(2); + if(not zp->can_eval()) + continue; + auto a = zp->eval(); + bool is_zero = false; + a.visit([&](auto t) { + is_zero = std::all_of(t.begin(), t.end(), [](auto x) { return float_equal(x, 0); }); + }); + if(not is_zero) + continue; + m.replace_instruction(ins, ins->get_operator(), ins->inputs().at(0), ins->inputs().at(1)); + } +} + void add_int4_pack_unpack_pair(module& m) { for(auto ins : iterator_for(m)) @@ -446,6 +467,8 @@ void simplify_qdq::apply(module& m) const remove_qdq_pairs(m); migraphx::run_passes(m, {migraphx::dead_code_elimination{}}); match::find_matches(m, match_qlinear_reused{}); + migraphx::run_passes(m, {migraphx::dead_code_elimination{}}); + remove_zero_point(m); } } // namespace MIGRAPHX_INLINE_NS diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index c642c9ee31a..d74fc48c3f9 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -462,6 +462,74 @@ struct find_concat_transpose } }; +struct find_concat_reshape +{ + auto matcher() const + { + return match::name("concat")(match::all_of[match::inputs()]( + match::name("reshape", "unsqueeze", "squeeze", "lazy_reshape"))); + } + + 
void apply(module& m, const match::matcher_result& mr) const + { + auto ins = mr.result; + auto concat_shape = ins->get_shape(); + auto reshapes = ins->inputs(); + if(reshapes.empty()) + return; + auto input_shape = reshapes.front()->inputs().front()->get_shape(); + // All inputs should have the same dimensions + if(not std::all_of( + std::next(reshapes.begin()), reshapes.end(), [&](instruction_ref reshape) { + return reshape->inputs().front()->get_shape().lens() == input_shape.lens(); + })) + return; + // axis could be a negative value + auto op = any_cast(ins->get_operator()); + int64_t n_dim = reshapes.front()->get_shape().lens().size(); + auto axis = tune_axis(n_dim, op.axis, op.name()); + + auto predims = std::accumulate(concat_shape.lens().begin(), + concat_shape.lens().begin() + axis, + std::size_t{1}, + std::multiplies<>{}); + auto postdims = std::accumulate(concat_shape.lens().begin() + axis + 1, + concat_shape.lens().end(), + std::size_t{1}, + std::multiplies<>{}); + + // Find the axis on the input + std::size_t x = 1; + auto it = std::find_if(input_shape.lens().begin(), input_shape.lens().end(), [&](auto d) { + x *= d; + return x > predims; + }); + if(it == input_shape.lens().end()) + return; + op.axis = it - input_shape.lens().begin(); + auto ipredims = std::accumulate(input_shape.lens().begin(), + input_shape.lens().begin() + op.axis, + std::size_t{1}, + std::multiplies<>{}); + if(ipredims != predims) + return; + auto ipostdims = std::accumulate(input_shape.lens().begin() + op.axis + 1, + input_shape.lens().end(), + std::size_t{1}, + std::multiplies<>{}); + if(ipostdims != postdims) + return; + + std::vector inputs; + std::transform(reshapes.begin(), + reshapes.end(), + std::back_inserter(inputs), + [&](instruction_ref i) { return i->inputs().front(); }); + auto concat = m.insert_instruction(ins, op, inputs); + m.replace_instruction(ins, make_op("reshape", {{"dims", concat_shape.lens()}}), concat); + } +}; + struct find_nested_concat { auto matcher() const @@ -1107,6 +1175,19 @@ struct find_mul_add_shape_op_dot } }; +struct find_flatten +{ + auto matcher() const { return match::name("flatten"); } + + void apply(module& m, const match::matcher_result& r) const + { + auto flatten = r.result; + m.replace_instruction(flatten, + make_op("reshape", {{"dims", flatten->get_shape().lens()}}), + flatten->inputs()); + } +}; + void simplify_reshapes::apply(module& m) const { m.repeat_while_changes(depth, [&] { @@ -1114,10 +1195,12 @@ void simplify_reshapes::apply(module& m) const find_where_op{}, find_resize{}, find_nop_reshapes{}, + find_flatten{}, find_reshape_cont{}, find_nested_shape_transforms{}, find_concat_slice{}, find_concat_transpose{}, + find_concat_reshape{}, find_concat_multibroadcasts{}, find_nested_slice{}, find_nested_concat{}, diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp index 91bdfc9924f..3188b00563f 100644 --- a/src/split_reduce.cpp +++ b/src/split_reduce.cpp @@ -237,8 +237,6 @@ void split_reduce::apply(module_pass_manager& mpm) const assert(replaced.size() == 1); mpm.get_module().replace_instruction(ins, replaced.front()); } - - mpm.run_pass(fuse_pointwise{.enable_rewrite_broadcasts = true}); } } // namespace MIGRAPHX_INLINE_NS diff --git a/src/targets/cpu/include/migraphx/cpu/dnnl.hpp b/src/targets/cpu/include/migraphx/cpu/dnnl.hpp index 877eab95398..b05cad85246 100644 --- a/src/targets/cpu/include/migraphx/cpu/dnnl.hpp +++ b/src/targets/cpu/include/migraphx/cpu/dnnl.hpp @@ -167,9 +167,10 @@ struct dnnl_op : auto_register_op auto desc = 
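
// Worked example for find_concat_reshape above (made-up shapes): two {2,3,4}
// tensors are each reshaped to {2,12} and concatenated on axis 0 into {4,12}.
// The pass moves the concat in front of the reshapes when the element counts
// before and after the concat axis line up, as this standalone check shows:
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

int main()
{
    const std::vector<std::size_t> concat_lens = {4, 12};   // concat output, axis = 0
    const std::vector<std::size_t> input_lens  = {2, 3, 4}; // shape before each reshape
    const std::size_t axis                     = 0;
    auto predims  = std::accumulate(concat_lens.begin(),
                                   concat_lens.begin() + axis,
                                   std::size_t{1},
                                   std::multiplies<>{}); // 1
    auto postdims = std::accumulate(concat_lens.begin() + axis + 1,
                                    concat_lens.end(),
                                    std::size_t{1},
                                    std::multiplies<>{}); // 12
    // The cumulative product of input_lens first exceeds predims at index 0,
    // so the concat moves to axis 0 on the {2,3,4} inputs; ipredims == 1 and
    // ipostdims == 12 both match, and the graph becomes concat(axis=0) ->
    // {4,3,4} followed by reshape(dims={4,12}).
    return (predims == 1 and postdims == 12) ? 0 : 1;
}
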
prim.get_primitive_desc(); const char* str = nullptr; #ifdef MIGRAPHX_ENABLE_ZENDNN - zendnn_primitive_desc_query(desc, zendnn_query_impl_info_str, 0, &str); + zendnn_primitive_desc_query( + desc, zendnn_query_impl_info_str, 0, reinterpret_cast(&str)); #else - dnnl_primitive_desc_query(desc, dnnl_query_impl_info_str, 0, &str); + dnnl_primitive_desc_query(desc, dnnl_query_impl_info_str, 0, reinterpret_cast(&str)); #endif return str == nullptr ? "" : str; } diff --git a/src/targets/cpu/target.cpp b/src/targets/cpu/target.cpp index 6e4e4051a80..e148aa5b6f3 100644 --- a/src/targets/cpu/target.cpp +++ b/src/targets/cpu/target.cpp @@ -33,7 +33,6 @@ #include #include #include -#include #include #include #include diff --git a/src/targets/gpu/CMakeLists.txt b/src/targets/gpu/CMakeLists.txt index 6b5a0521785..82cc1fb0a3c 100644 --- a/src/targets/gpu/CMakeLists.txt +++ b/src/targets/gpu/CMakeLists.txt @@ -85,7 +85,7 @@ add_library(migraphx_device ${DEVICE_GPU_SRCS}) add_library(compile_for_gpu INTERFACE) target_compile_features(compile_for_gpu INTERFACE cxx_std_17) -target_compile_options(compile_for_gpu INTERFACE -fno-gpu-rdc -Wno-cuda-compat -Wno-unused-command-line-argument -Xclang -fallow-half-arguments-and-returns) +target_compile_options(compile_for_gpu INTERFACE -fno-gpu-rdc -Wno-cuda-compat -Wno-unused-command-line-argument -Xclang -fnative-half-arguments-and-returns) target_link_options(compile_for_gpu INTERFACE -fno-gpu-rdc -Wno-invalid-command-line-argument -Wno-unused-command-line-argument -Wno-option-ignored) target_link_libraries(compile_for_gpu INTERFACE hip::device) check_cxx_compiler_flag("--cuda-host-only -fhip-lambda-host-device -x hip" HAS_HIP_LAMBDA_HOST_DEVICE) @@ -118,6 +118,7 @@ foreach(KERNEL_FILE ${KERNEL_FILES}) endforeach() target_compile_definitions(kernel_file_check PRIVATE -DMIGRAPHX_NLOCAL=256) +target_compile_definitions(kernel_file_check PRIVATE -DMIGRAPHX_WAVEFRONTSIZE=64) target_include_directories(kernel_file_check PRIVATE $) target_link_libraries(kernel_file_check compile_for_gpu) if(MIGRAPHX_USE_COMPOSABLEKERNEL) @@ -148,6 +149,7 @@ add_library(migraphx_gpu compile_gen.cpp compile_hip.cpp compile_hip_code_object.cpp + compile_hipblaslt.cpp compile_miopen.cpp compile_pointwise.cpp compiler.cpp diff --git a/src/targets/gpu/compile_hip.cpp b/src/targets/gpu/compile_hip.cpp index 08c7a8c7d99..58b51872552 100644 --- a/src/targets/gpu/compile_hip.cpp +++ b/src/targets/gpu/compile_hip.cpp @@ -201,16 +201,6 @@ std::vector> compile_hip_src_with_hiprtc(std::vector 0); assert(options.local > 0); @@ -191,6 +192,8 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option options.emplace_param("-DMIGRAPHX_NGLOBAL=" + std::to_string(options.global)); options.emplace_param("-DMIGRAPHX_NLOCAL=" + std::to_string(options.local)); + options.emplace_param("-DMIGRAPHX_WAVEFRONTSIZE=" + + std::to_string(ctx.get_current_device().get_wavefront_size())); const auto& warnings = compiler_warnings(); options.params.insert(options.params.end(), warnings.begin(), warnings.end()); options.emplace_param("-ftemplate-backtrace-limit=0"); diff --git a/src/targets/gpu/compile_hipblaslt.cpp b/src/targets/gpu/compile_hipblaslt.cpp new file mode 100644 index 00000000000..c320e6b7dae --- /dev/null +++ b/src/targets/gpu/compile_hipblaslt.cpp @@ -0,0 +1,78 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. 
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+#if MIGRAPHX_USE_HIPBLASLT
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {
+namespace gpu {
+
+static size_t compile(migraphx::context& ctx, operation& op, instruction_ref ins)
+{
+    auto v = op.compile(ctx, ins->get_shape(), to_shapes(ins->inputs()));
+    return v.get("workspace", 0);
+}
+
+void compile_hipblaslt::apply(module& m) const
+{
+    assert(ctx);
+    for(auto ins : iterator_for(m))
+    {
+        if(ins->name() != "gpu::hipblaslt_op")
+            continue;
+        auto op     = any_cast<hipblaslt_op>(ins->get_operator()).op;
+        auto inputs = ins->inputs();
+
+        std::size_t ws = hipblaslt_workspace_size;
+
+        auto alloc = m.insert_instruction(
+            ins, make_op("allocate", {{"shape", to_value(shape{shape::uint8_type, {ws}})}}));
+        inputs.insert(std::prev(inputs.end()), alloc);
+        m.replace_instruction(ins, op, inputs);
+
+        // Calculate workspace size
+        ws = compile(*ctx, op, ins);
+        auto alloc_after = m.insert_instruction(
+            ins, make_op("allocate", {{"shape", to_value(shape{shape::uint8_type, {ws}})}}));
+
+        // Replace the workspace size with the actual workspace size needed.
+        auto it = std::find(inputs.begin(), inputs.end(), alloc);
+        if(it != inputs.end())
+        {
+            *it = alloc_after; // Replace `alloc` with `alloc_after`
+        }
+        m.replace_instruction(ins, op, inputs);
+    }
+}
+
+} // namespace gpu
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
+#endif // MIGRAPHX_USE_HIPBLASLT
diff --git a/src/targets/gpu/device/include/migraphx/gpu/device/float_equal.hpp b/src/targets/gpu/device/include/migraphx/gpu/device/float_equal.hpp
index 9fb6f858d18..a5f18fc5aa8 100644
--- a/src/targets/gpu/device/include/migraphx/gpu/device/float_equal.hpp
+++ b/src/targets/gpu/device/include/migraphx/gpu/device/float_equal.hpp
@@ -44,6 +44,16 @@ __device__ bool float_equal_device(T x, T y)
         std::nextafter(x, std::numeric_limits<T>::max()) >= y;
 }
 
+template <>
+__device__ bool float_equal_device(__bf16 x, __bf16 y) // NOLINT(misc-definitions-in-headers)
+{
+    float xf = x;
+    float yf = y;
+    return std::isfinite(xf) and std::isfinite(yf) and
+           std::nextafter(xf, std::numeric_limits<float>::lowest()) <= yf and
+           std::nextafter(xf, std::numeric_limits<float>::max()) >= yf;
+}
+
 template {})>
 __device__ bool float_equal_device(T x, T y)
 {
diff --git a/src/targets/gpu/device/include/migraphx/gpu/device/types.hpp b/src/targets/gpu/device/include/migraphx/gpu/device/types.hpp
index 19fb02763fb..c9f2e3d7cd4 100644
--- a/src/targets/gpu/device/include/migraphx/gpu/device/types.hpp
+++ b/src/targets/gpu/device/include/migraphx/gpu/device/types.hpp
@@ -27,6 +27,7 @@
 #include
 #include
+#include
 #include
 #include
@@ -67,6 +68,7 @@ auto pack_vec(Ts... xs)
 }
 
 using gpu_half = __fp16;
+using gpu_bf16 = __bf16;
 
 namespace detail {
 template
@@ -87,6 +89,12 @@ struct device_type
     using type = gpu_half;
 };
 
+template <>
+struct device_type<bf16>
+{
+    using type = gpu_bf16;
+};
+
 template
 struct host_type
 {
@@ -99,6 +107,12 @@ struct host_type
     using type = half;
 };
 
+template <>
+struct host_type<gpu_bf16>
+{
+    using type = bf16;
+};
+
 } // namespace detail
 
 template
@@ -143,23 +157,53 @@ __device__ __host__ T to_hip_type(T x)
     return x;
 }
 
-// Hip doens't support __fp16
+// HIP doesn't support __fp16 and __bf16
 inline __device__ __host__ float to_hip_type(gpu_half x) { return x; }
+inline __device__ __host__ float to_hip_type(gpu_bf16 x) { return x; }
+
+template <class T>
+struct is_floating_point : std::is_floating_point<T>
+{
+};
+
+template <>
+struct is_floating_point<__fp16> : std::true_type
+{
+};
+
+template <class T>
+struct is_signed : std::is_signed<T>
+{
+};
+
+template <>
+struct is_signed<__fp16> : std::true_type
+{
+};
+
+template <class T>
+struct is_arithmetic : std::is_arithmetic<T>
+{
+};
+
+template <>
+struct is_arithmetic<__fp16> : std::true_type
+{
+};
 
-#define MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
-    template \
-    struct trait : std::trait \
-    { \
-    }; \
-    \
-    template <> \
-    struct trait : std::true_type \
-    { \
-    };
-
-MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, __fp16)
-MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(is_signed, __fp16)
-MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, __fp16)
+// Redo for __bf16
+template <>
+struct is_floating_point<__bf16> : std::true_type
+{
+};
+template <>
+struct is_signed<__bf16> : std::true_type
+{
+};
+template <>
+struct is_arithmetic<__bf16> : std::true_type
+{
+};
 } // namespace device
 } // namespace gpu
diff --git a/src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp b/src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
index 18981399364..78f28a552bd 100644
--- a/src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
+++ 
b/src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp @@ -98,6 +98,10 @@ template <> struct is_hip_type : std::true_type { }; +template <> +struct is_hip_type : std::true_type +{ +}; template {})> void hip_visitor_invoke(T as, V&& v) diff --git a/src/targets/gpu/device_name.cpp b/src/targets/gpu/device_name.cpp index e3f47a4f3b0..c717742e2c7 100644 --- a/src/targets/gpu/device_name.cpp +++ b/src/targets/gpu/device_name.cpp @@ -52,13 +52,15 @@ std::string get_device_name() bool gfx_has_fp8fnuz_intrinsics() { const auto device_name = trim(split_string(get_device_name(), ':').front()); - return (starts_with(device_name, "gfx9") and device_name >= "gfx940"); + return (starts_with(device_name, "gfx94")); } bool gfx_has_fp8ocp_intrinsics() { const auto device_name = trim(split_string(get_device_name(), ':').front()); - return (starts_with(device_name, "gfx12") and device_name >= "gfx1200"); + bool is_navi_with_fp8ocp = starts_with(device_name, "gfx12") and device_name >= "gfx1200"; + bool is_mi_with_fp8ocp = starts_with(device_name, "gfx9") and device_name >= "gfx950"; + return (is_navi_with_fp8ocp or is_mi_with_fp8ocp); } } // namespace gpu diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 4376bb323cc..65a27a76ad1 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -32,6 +32,7 @@ #include #include #include +#include #include #include #include @@ -130,6 +131,8 @@ bool mlir_attention_enabled(context* ctx) #ifdef MIGRAPHX_MLIR if(not mlir_enabled()) return false; + if(specific_op("attention")) + return false; // Enable attention by default for mi300 if(ctx != nullptr and starts_with(ctx->get_current_device().get_gfx_name(), "gfx94")) return true; @@ -152,14 +155,40 @@ struct mlir_op return pack(f(self.op, "op")); } + // Check if the shape can be created from a transpose/broadcast/slice + static bool is_mlir_compatible(const shape& s) + { + if(s.standard() or s.packed() or s.scalar() or s.ndim() == 1) + return true; + auto ns = reorder_shape(s, find_permutation(s)); + std::vector stride_ratios; + auto last = std::find(ns.strides().begin(), ns.strides().end(), 0); + if(*std::prev(last) != 1) + return false; + std::adjacent_difference(ns.strides().begin(), + last, + std::back_inserter(stride_ratios), + [](auto y, auto x) -> std::size_t { + assert(y != 0); + if((x % y) != 0) + return 0; + return x / y; + }); + return std::equal(stride_ratios.begin() + 1, + stride_ratios.end(), + ns.lens().begin() + 1, + [](auto ratio, auto len) { return ratio >= len; }); + } + shape compute_shape(const std::vector& inputs, const std::vector& mods) const { module_ref mod = mods[0]; - check_shapes{inputs, *this}.packed_or_broadcasted(); + check_shapes{inputs, *this}.has_at_least(1); if(mods.size() != 1) MIGRAPHX_THROW("should have one submodule."); - if(inputs.empty()) - MIGRAPHX_THROW("should have at least one input."); + + if(not std::all_of(inputs.begin(), inputs.end(), &is_mlir_compatible)) + MIGRAPHX_THROW("Shape is not mlir compatible."); auto result = mod->compute_shapes(inputs, {.name = name(), .strict_type = true, .strict_lens = true}); @@ -210,7 +239,7 @@ void fuse_input_ops(module_ref mm, std::unordered_map* map_ins) { assert(map_ins != nullptr); - size_t input_cnt = 0; + size_t input_cnt = mm->get_parameters().size(); for(instruction_ref input : inputs) { if(contains(*map_ins, input)) @@ -298,8 +327,8 @@ auto is_mlir_conv(mlir_mode mode) // Avoid MLIR assertion: Index < Length && "Invalid index!" 
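
// Standalone illustration of the stride-ratio test in is_mlir_compatible
// above (made-up shapes): a {2,2} view with strides {4,1}, i.e. a column
// slice of a {2,4} buffer, passes because the innermost stride is 1 and the
// adjacent stride ratio 4/1 = 4 is >= the inner length 2.
#include <algorithm>
#include <cstddef>
#include <iterator>
#include <numeric>
#include <vector>

// Assumes lens/strides are already permuted so strides decrease and contain
// no broadcast zeros, mirroring the reorder_shape step above.
inline bool slice_like(const std::vector<std::size_t>& lens,
                       const std::vector<std::size_t>& strides)
{
    if(strides.back() != 1)
        return false;
    std::vector<std::size_t> ratios;
    std::adjacent_difference(strides.begin(),
                             strides.end(),
                             std::back_inserter(ratios),
                             [](auto y, auto x) -> std::size_t {
                                 // ratio of each stride to the next-inner one
                                 return (x % y) != 0 ? 0 : x / y;
                             });
    return std::equal(ratios.begin() + 1,
                      ratios.end(),
                      lens.begin() + 1,
                      [](auto ratio, auto len) { return ratio >= len; });
}
// slice_like({2, 2}, {4, 1}) == true; slice_like({2, 2}, {3, 2}) == false.
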
if(ins->get_shape().lens().size() != 4 and group > 1) return false; - std::set supported_types = { - shape::fp8e4m3fnuz_type, shape::fp8e4m3fn_type, shape::fp8e5m2_type, shape::int8_type}; + std::set supported_types = fp8_types{}.get(); + supported_types.insert(shape::int8_type); if(contains(supported_types, input.type())) return true; if(mode == mlir_mode::all) @@ -361,12 +390,16 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i) const auto& name = i.name(); const auto result_type = i.get_shape().type(); const std::initializer_list allowed_types = {type_t::float_type, + type_t::bf16_type, type_t::half_type, type_t::fp8e4m3fnuz_type, + type_t::fp8e5m2fnuz_type, type_t::fp8e4m3fn_type, type_t::fp8e5m2_type, type_t::int8_type, + type_t::uint8_type, type_t::int32_type, + type_t::uint32_type, type_t::bool_type}; // Preliminary type check. if(not contains(allowed_types, result_type)) @@ -407,7 +440,9 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i) }; std::set float_types = {type_t::float_type, type_t::half_type, + type_t::bf16_type, type_t::fp8e4m3fnuz_type, + type_t::fp8e5m2fnuz_type, type_t::fp8e4m3fn_type, type_t::fp8e5m2_type}; bool is_float = contains(float_types, result_type); @@ -426,7 +461,8 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i) return false; } // else return std::all_of(i.inputs().begin(), i.inputs().end(), [](const auto& arg) { - return contains({type_t::float_type, type_t::half_type}, arg->get_shape().type()); + return contains({type_t::float_type, type_t::half_type, type_t::bf16_type}, + arg->get_shape().type()); }); } return false; @@ -437,8 +473,14 @@ bool is_reduce_op_supported_by_mlir(const instruction& i) using type_t = shape::type_t; const auto& name = i.name(); const auto result_type = i.get_shape().type(); - const std::initializer_list allowed_types = { - type_t::float_type, type_t::half_type, type_t::fp8e4m3fnuz_type}; + const std::initializer_list allowed_types = {type_t::float_type, + type_t::half_type, + type_t::bf16_type, + type_t::fp8e4m3fnuz_type, + type_t::fp8e5m2fnuz_type, + type_t::fp8e4m3fn_type, + type_t::fp8e5m2_type}; + // Preliminary type check. 
if(not contains(allowed_types, result_type)) { @@ -695,8 +737,10 @@ struct find_mlir_standalone_op if(std::any_of(gemm_based_op->inputs().begin(), gemm_based_op->inputs().end(), [&](auto i) { return not contains({shape::type_t::float_type, shape::type_t::half_type, + shape::type_t::bf16_type, shape::type_t::int8_type, shape::type_t::fp8e4m3fnuz_type, + shape::type_t::fp8e5m2fnuz_type, shape::type_t::fp8e4m3fn_type, shape::type_t::fp8e5m2_type}, i->get_shape().type()); @@ -869,10 +913,11 @@ struct find_mlir_attention_fused_ops : public find_mlir_standalone_attention_op struct find_pointwise_mlir { + auto supported_pointwise() const { return mlir_input_pointwise(match::used_once()); } + auto matcher() const { - return match::name("gpu::mlir_op")(match::any_of[match::inputs()]( - mlir_input_pointwise(match::used_once()).bind("pointwise"))); + return match::name("gpu::mlir_op")(match::any_of[match::inputs()](supported_pointwise())); } static bool is_simple_op(const_module_ref pm, std::initializer_list op_names) @@ -905,23 +950,44 @@ struct find_pointwise_mlir void apply(module_pass_manager& mpm, const match::matcher_result& r) const { auto ins = r.result; - auto pw = r.instructions["pointwise"]; auto* mm = ins->module_inputs().front(); - auto* pm = pw->module_inputs().front(); + std::vector pws; + std::copy_if( + ins->inputs().begin(), + ins->inputs().end(), + std::back_inserter(pws), + [&](instruction_ref input) { + if(not match::instruction_matches(mpm.get_module(), input, supported_pointwise())) + return false; + auto* pm = input->module_inputs().front(); + if(input->inputs().size() > 1 and not is_simple_op(pm, {"dequantizelinear"})) + { + if(not enabled(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION{})) + return false; + } + return true; + }); + if(pws.empty()) + return; - if(pw->inputs().size() > 1 and not is_simple_op(pm, {"dequantizelinear"})) - { - if(not enabled(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION{})) - return; - } + std::string module_name; + std::transform( + pws.begin(), pws.end(), join_back_inserter(module_name), [](instruction_ref pw) { + return pw->module_inputs().front()->name() + ":"; + }); + module_name += mm->name(); + module_ref m = mpm.create_module(module_name); + m->set_bypass(); std::unordered_map map_ins; - module_ref m = mpm.create_module(pm->name() + ":" + mm->name()); - m->set_bypass(); - fuse_input_ops(m, pw->inputs(), &map_ins); - auto rins = m->fuse(*pm, pw->inputs(), &map_ins, &insert_pointwise).front(); - map_ins[pw] = rins; + for(auto pw : pws) + { + auto* pm = pw->module_inputs().front(); + fuse_input_ops(m, pw->inputs(), &map_ins); + auto rins = m->fuse(*pm, pw->inputs(), &map_ins, &insert_pointwise).front(); + map_ins[pw] = rins; + } auto ret = m->fuse(*mm, ins->inputs(), &map_ins); m->add_return({ret}); diff --git a/src/targets/gpu/fuse_ops.cpp b/src/targets/gpu/fuse_ops.cpp index 170f67f3fb2..5e93ccf5ecf 100644 --- a/src/targets/gpu/fuse_ops.cpp +++ b/src/targets/gpu/fuse_ops.cpp @@ -24,11 +24,13 @@ #include #include #include +#include #include #include #include #include #include +#include #include #include #include @@ -41,6 +43,7 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace gpu { +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_HIPBLASLT_GEMM) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MIOPEN_FUSION) #if MIGRAPHX_USE_MIOPEN struct fusion @@ -555,20 +558,9 @@ struct find_conv_pointwise }; #endif -#if MIGRAPHX_USE_ROCBLAS -struct find_gemm_pointwise +#if MIGRAPHX_USE_ROCBLAS or MIGRAPHX_USE_HIPBLASLT +struct gemm_pointwise { - auto matcher() const - { 
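
// Sketch of the module naming introduced in find_pointwise_mlir above
// (hypothetical module names): when two pointwise producers main:pointwise0
// and main:pointwise1 feed one gpu::mlir_op, the fused module is created
// under a name such as "main:pointwise0:main:pointwise1:<mlir module name>",
// one segment per fused producer (joined with ':' via join_back_inserter),
// and each producer module is inlined with module::fuse before the mlir
// module itself is fused.
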
- auto gemm_op = match::name("gpu::gemm")(match::nargs(3), match::used_once()).bind("gemm"); - auto binary_op = match::all_of( - match::nargs(3), - match::either_arg(0, 1)( - match::any_of(match::standard_shape(), match::is_constant()).bind("c"), gemm_op)); - auto unary_op = match::all_of(match::nargs(2), match::arg(0)(gemm_op)); - return precompile_name("pointwise")(match::any_of(binary_op, unary_op)); - } - // TODO: Move to matcher.hpp static auto match_param(const std::string& name) { @@ -642,6 +634,22 @@ struct find_gemm_pointwise return false; } } +}; +#endif + +#if MIGRAPHX_USE_ROCBLAS +struct find_rocblas_gemm_pointwise : gemm_pointwise +{ + auto matcher() const + { + auto gemm_op = match::name("gpu::gemm")(match::nargs(3), match::used_once()).bind("gemm"); + auto binary_op = match::all_of( + match::nargs(3), + match::either_arg(0, 1)( + match::any_of(match::standard_shape(), match::is_constant()).bind("c"), gemm_op)); + auto unary_op = match::all_of(match::nargs(2), match::arg(0)(gemm_op)); + return precompile_name("pointwise")(match::any_of(binary_op, unary_op)); + } void apply(module& m, const match::matcher_result& r) const { @@ -669,7 +677,7 @@ struct find_gemm_pointwise shape s = c_ins->get_shape(); // const-fold input if not standard shape since rocblas can't handle it // Updated for a case where "standard" shape has out-of-sequence strides - if(not s.standard() or s.normalize_standard() != s) + if(not s.standard()) { auto c = make_op("contiguous"); auto l = c.compute(c.compute_shape({c_ins->get_shape()}), {c_ins->eval()}); @@ -685,16 +693,71 @@ struct find_gemm_pointwise }; #endif -struct find_contiguous_tranpose_gemm +#if MIGRAPHX_USE_HIPBLASLT +struct find_hipblas_gemm_pointwise : gemm_pointwise { auto matcher() const { - return match::name("gpu::contiguous")(match::arg(0)( - match::name("transpose")( - match::arg(0)(match::name("gpu::gemm")(match::used_once()).bind("gemm"))) - .bind("transpose"))); + auto gemm_op = + match::name("gpu::hipblaslt_op")(match::nargs(3), match::used_once()).bind("hip_gemm"); + auto binary_op = match::all_of( + match::nargs(3), + match::either_arg(0, 1)( + match::any_of(match::standard_shape(), match::is_constant()).bind("c"), gemm_op)); + auto unary_op = match::all_of(match::nargs(2), match::arg(0)(gemm_op)); + return precompile_name("pointwise")(match::any_of(binary_op, unary_op)); + } + + void apply(module& m, const match::matcher_result& r) const + { + auto ins = r.result; + auto gemm_ins = r.instructions["hip_gemm"]; + + auto gemm_op = any_cast(gemm_ins->get_operator()).op; + + if(gemm_op.name() != "gpu::hip_gemm") + return; + + auto gemm = any_cast>(gemm_op); + + // Already fused gemm + if(not float_equal(gemm.beta, 0)) + return; + if(ins->inputs().size() == 3) + gemm.beta = 1; + if(not update_gemm( + gemm, ins->module_inputs().front(), ins->inputs().front() == gemm_ins ? 
0 : 1))
+        {
+            return;
+        }
+        auto inputs = gemm_ins->inputs();
+        inputs.pop_back();
+        if(ins->inputs().size() == 3)
+        {
+            auto c_ins = r.instructions["c"];
+            shape s    = c_ins->get_shape();
+            // const-fold input if not standard shape
+            // Updated for a case where "standard" shape has out-of-sequence strides
+            if(not s.standard())
+            {
+                auto c = make_op("contiguous");
+                auto l = c.compute(c.compute_shape({c_ins->get_shape()}), {c_ins->eval()});
+                c_ins  = m.add_literal(l.get_shape(), l.data());
+            }
+            inputs.push_back(c_ins);
+        }
+        inputs.push_back(ins->inputs().back());
+
+        operation new_gemm_op = gemm;
+        auto new_ins          = m.insert_instruction(
+            ins, make_op("gpu::hipblaslt_op", {{"op", to_value(new_gemm_op)}}), inputs);
+        m.replace_instruction(ins, new_ins);
     }
+};
+#endif
+struct contiguous_transpose_gemm
+{
     template <class Vector>
     static bool is_swapped(const Vector& perm, std::size_t i, std::size_t j)
     {
@@ -705,6 +768,17 @@ struct find_contiguous_tranpose_gemm
         std::swap(perm2[i], perm2[j]);
         return perm2 == perm;
     }
+};
+
+struct find_contiguous_transpose_rocblas_gemm : contiguous_transpose_gemm
+{
+    auto matcher() const
+    {
+        return match::name("gpu::contiguous")(match::arg(0)(
+            match::name("transpose")(
+                match::arg(0)(match::name("gpu::gemm")(match::used_once()).bind("gemm")))
+                .bind("transpose")));
+    }
 
     void apply(module& m, const match::matcher_result& r) const
     {
@@ -743,6 +817,67 @@ struct find_contiguous_tranpose_gemm
     }
 };
 
+#if MIGRAPHX_USE_HIPBLASLT
+struct find_contiguous_transpose_hip_gemm : contiguous_transpose_gemm
+{
+    auto matcher() const
+    {
+        return match::name("gpu::contiguous")(match::arg(0)(
+            match::name("transpose")(
+                match::arg(0)(
+                    match::name("gpu::hipblaslt_op")(match::used_once()).bind("hip_gemm")))
+                .bind("transpose")));
+    }
+
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto ins      = r.result;
+        auto gemm_ins = r.instructions["hip_gemm"];
+        auto gemm_op  = any_cast<hipblaslt_op>(gemm_ins->get_operator()).op;
+
+        if(gemm_op.name() != "gpu::hip_gemm")
+            return;
+
+        auto gemm = any_cast<hip_gemm<op::dot>>(gemm_op);
+
+        auto alloc     = gemm_ins->inputs().back();
+        auto transpose = r.instructions["transpose"];
+        auto perm      = transpose->get_operator().to_value()["permutation"].to_vector<std::int64_t>();
+        auto iperm     = invert_permutation(perm);
+
+        if(perm.size() < 3)
+            return;
+
+        if(not is_swapped(perm, perm.size() - 3, perm.size() - 2))
+            return;
+
+        auto lens = gemm_ins->get_shape().lens();
+        if(lens.size() > 3 and
+           not std::all_of(lens.begin(), lens.end() - 3, [](auto i) { return i == 1; }))
+            return;
+
+        gemm.trans_batch = 1;
+
+        auto s = shape{alloc->get_shape().type(), reorder_dims(alloc->get_shape().lens(), iperm)};
+        auto new_alloc =
+            m.insert_instruction(gemm_ins, make_op("allocate", {{"shape", to_value(s)}}));
+
+        auto alloc_transpose = m.insert_instruction(
+            gemm_ins, make_op("transpose", {{"permutation", perm}}), new_alloc);
+
+        auto inputs   = gemm_ins->inputs();
+        inputs.back() = alloc_transpose;
+        operation new_gemm_op = gemm;
+        auto new_gemm = m.insert_instruction(
+            gemm_ins, make_op("gpu::hipblaslt_op", {{"op", to_value(new_gemm_op)}}), inputs);
+
+        auto gemm_transpose = m.insert_instruction(gemm_ins, transpose->get_operator(), new_gemm);
+
+        m.replace_instruction(ins, gemm_transpose);
+    }
+};
+#endif
+
 struct find_commutative_broadcast
 {
     auto matcher() const
@@ -835,7 +970,7 @@ struct find_layernorm_pointwise
 {
     auto matcher() const
     {
-        return precompile_name("pointwise")(match::any_of[match::inputs()](
+        return precompile_name("pointwise")(match::arg(0)(
             precompile_name("gpu::prelayernorm", 
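
// Worked example for find_contiguous_transpose_hip_gemm above (hypothetical
// rank-3 case): a gpu::hipblaslt_op gemm producing {B, M, N}, followed by
// transpose(permutation = {1, 0, 2}) and gpu::contiguous, swaps only the
// batch dim with the row dim, so is_swapped(perm, 0, 1) holds. Setting
// gemm.trans_batch = 1 lets the gemm write straight into an allocation laid
// out for the transposed result, and the explicit contiguous copy goes away.
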
"gpu::preadd_layernorm").bind("layernorm"))); } @@ -903,13 +1038,19 @@ void fuse_ops::apply(module& m) const match::find_matches(m, find_conv_pointwise{ctx}, find_conv_bias_relu{ctx}, find_conv_bias{ctx}); run_passes(m, {dead_code_elimination{}}); #endif - match::find_matches(m, #if MIGRAPHX_USE_ROCBLAS - find_gemm_pointwise{}, + match::find_matches(m, find_rocblas_gemm_pointwise{}); +#endif +#if MIGRAPHX_USE_HIPBLASLT + match::find_matches(m, find_hipblas_gemm_pointwise{}); #endif + match::find_matches(m, find_layernorm_pointwise{}, find_concat_pointwise{}, - find_contiguous_tranpose_gemm{}, + find_contiguous_transpose_rocblas_gemm{}, +#if MIGRAPHX_USE_HIPBLASLT + find_contiguous_transpose_hip_gemm{}, +#endif find_commutative_broadcast{}); match::find_matches(m, find_contiguous{}); } diff --git a/src/targets/gpu/gemm_impl.cpp b/src/targets/gpu/gemm_impl.cpp index 19d4f056deb..d0f750a2501 100644 --- a/src/targets/gpu/gemm_impl.cpp +++ b/src/targets/gpu/gemm_impl.cpp @@ -64,6 +64,7 @@ rocblas_datatype get_type(shape::type_t type) case shape::int32_type: return rocblas_datatype_i32_r; case shape::uint32_type: return rocblas_datatype_u32_r; case shape::fp8e4m3fnuz_type: return rocblas_datatype_f8_r; + case shape::fp8e5m2fnuz_type: return rocblas_datatype_bf8_r; case shape::fp8e4m3fn_type: case shape::fp8e5m2_type: case shape::tuple_type: @@ -72,15 +73,17 @@ rocblas_datatype get_type(shape::type_t type) case shape::int16_type: case shape::int64_type: case shape::uint64_type: MIGRAPHX_THROW("ROCBLAS_GEMM: data type not supported!"); + case shape::bf16_type: return rocblas_datatype_bf16_r; } MIGRAPHX_THROW("ROCBLAS_GEMM: data type not supported!"); } -void blas_shape(const shape& s) +void blas_shape(const shape& in_shape) { - if(s.lens().size() < 2) + if(in_shape.lens().size() < 2) return; + auto s = in_shape.normalize_standard(); if(std::none_of(s.strides().end() - 2, s.strides().end(), [](auto i) { return i == 1; })) MIGRAPHX_THROW("GPU_GEMM: needs to have one matrix stride as 1"); if(std::any_of(s.strides().end() - 2, s.strides().end(), [](auto i) { return i == 0; })) @@ -221,7 +224,7 @@ struct gemm_impl compute_type = rb_compute_type{output_type}; if(compute_fp32) { - if(arg_type == rocblas_datatype_f16_r) + if(arg_type == rocblas_datatype_f16_r or arg_type == rocblas_datatype_bf16_r) compute_type = rocblas_datatype_f32_r; } if(arg_type == rocblas_datatype_f8_r) @@ -589,7 +592,7 @@ void gemm_compute(context& ctx, std::transform(args.begin(), args.end(), std::back_inserter(input_shapes), - [](const argument& x) { return x.get_shape(); }); + [](const argument& x) { return x.get_shape().normalize_standard(); }); auto gemm_item = gemm_impl(output_shape, input_shapes, alpha, beta, compute_fp32); gemm_item.run(ctx, args, solution_idx); } @@ -606,7 +609,7 @@ void gemm_compute(context& ctx, std::transform(args.begin(), args.end(), std::back_inserter(input_shapes), - [](const argument& x) { return x.get_shape(); }); + [](const argument& x) { return x.get_shape().normalize_standard(); }); auto gemm_item = gemm_impl(output_shape, input_shapes, alpha, beta, compute_fp32); gemm_item.run(ctx, args, solution_idx); } diff --git a/src/targets/gpu/hip_gemm_impl.cpp b/src/targets/gpu/hip_gemm_impl.cpp index f5ec898d8d5..966927da7b5 100644 --- a/src/targets/gpu/hip_gemm_impl.cpp +++ b/src/targets/gpu/hip_gemm_impl.cpp @@ -31,6 +31,7 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -70,23 +71,33 @@ hipDataType get_type_hipblas(shape::type_t type) case 
shape::int32_type: return HIP_R_32I;
     case shape::uint32_type: return HIP_R_32U;
     case shape::fp8e4m3fnuz_type: return HIP_R_8F_E4M3_FNUZ;
+    case shape::fp8e5m2fnuz_type:
+        return HIP_R_8F_E5M2_FNUZ;
+// TODO: remove this preprocessor conditional once the HIP version includes these types by default
+#ifdef ROCM_USE_FLOAT8
+    case shape::fp8e4m3fn_type: return HIP_R_8F_E4M3;
+    case shape::fp8e5m2_type: return HIP_R_8F_E5M2;
+#else
     case shape::fp8e4m3fn_type:
     case shape::fp8e5m2_type:
+#endif
     case shape::tuple_type:
     case shape::bool_type:
     case shape::uint16_type:
     case shape::int16_type:
     case shape::int64_type:
     case shape::uint64_type: MIGRAPHX_THROW("HIPBLAS_GEMM: data type not supported!");
+    case shape::bf16_type: return HIP_R_16BF;
     }
     MIGRAPHX_THROW("HIPBLAS_GEMM: data type not supported!");
 }
 
-void blas_shape_hip(const shape& s)
+void blas_shape_hip(const shape& in_shape)
 {
-    if(s.lens().size() < 2)
+    if(in_shape.lens().size() < 2)
         return;
+    auto s = in_shape.normalize_standard();
     if(std::none_of(s.strides().end() - 2, s.strides().end(), [](auto i) { return i == 1; }))
         MIGRAPHX_THROW("GPU_GEMM: needs to have one matrix stride as 1");
     if(std::any_of(s.strides().end() - 2, s.strides().end(), [](auto i) { return i == 0; }))
@@ -101,6 +112,19 @@ void blas_shape_hip(const shape& s)
         MIGRAPHX_THROW("GPU_GEMM: Batch dimension is not collapsible");
 }
 
+shape transpose_batch_hip(const shape& s, unsigned trans_batch)
+{
+    if(trans_batch == 0)
+        return s;
+    if(s.lens().size() < 3)
+        return s;
+    auto batch = s.lens().size() - 3;
+    std::vector<int64_t> perm(s.lens().size());
+    std::iota(perm.begin(), perm.end(), 0);
+    std::swap(perm[batch], perm[batch + trans_batch]);
+    return shape::from_permutation(s.type(), s.lens(), perm);
+}
+
 static bool is_transposed_hip(const shape& s) { return s.transposed() and s.strides().back() != 1; }
 
 static int32_t get_batch_stride_hip(const shape& s)
@@ -393,7 +417,8 @@ struct hip_gemm_impl
                            const std::vector<argument>& args,
                            int32_t solution_idx)
     {
-        auto* algo = &solution.get_result(ctx, *this, solution_idx)[0].algo;
+        auto* algo            = &solution.get_result(ctx, *this, solution_idx)[0].algo;
+        size_t workspace_size = ((is_3inputs ? args[3] : args[2]).get_shape()).bytes();
         return pack(ctx.get_stream().get_hipblaslt(),
                     hipblaslt_desc,
                     get_alpha(), // alpha
@@ -408,7 +433,7 @@ struct hip_gemm_impl
                     is_3inputs ? mat_d : mat_c,                   // Ddesc
                     algo,                                         // algo
                     is_3inputs ? args[3].data() : args[2].data(), // workspace
-                    algo->max_workspace_bytes,                    // workspaceSizeInBytes
+                    workspace_size,                               // workspaceSizeInBytes
                     ctx.get_stream().get()                        // stream
         );
     }
@@ -462,10 +487,9 @@ struct hip_gemm_impl
     int32_t validate(context& ctx, const std::vector<argument>& input_args, int32_t solution_idx) // const
     {
-        hipblasStatus_t check_valid(HIPBLAS_STATUS_SUCCESS);
         auto common_args = create_hipblaslt_args_common(ctx, input_args, solution_idx);
-        check_valid      = hipblaslt_invoke(&hipblasLtMatmul, common_args);
-        if(check_valid == HIPBLAS_STATUS_SUCCESS)
+        auto check_valid = hipblaslt_invoke(&hipblasLtMatmul, common_args, false);
+        if(check_valid != HIPBLAS_STATUS_SUCCESS)
         {
             std::cerr << "WARNING: tuned solution is invalid; reverting to default" << std::endl;
             return 0;
@@ -473,6 +497,53 @@ struct hip_gemm_impl
         return solution_idx;
     }
 
+    /**
+     * Get the workspace size for a solution index: gets the algo for that index
+     * and calls matmulIsAlgoSupported() to query the workspace size it needs.
+ */ + + size_t get_workspace_size(context& ctx, + const std::vector& input_shapes, + int32_t solution_idx) const + { + size_t workspace_size = hipblaslt_workspace_size; + std::vector input_args; + std::transform(input_shapes.begin(), + input_shapes.end(), + std::back_inserter(input_args), + [](const shape& x) { return to_gpu(generate_argument(x)); }); + + std::vector algo_index = {solution_idx}; + std::vector heuristic_result; + + hipblaslt_invoke([&]() { + return hipblaslt_ext::getAlgosFromIndex( + ctx.get_stream().get_hipblaslt(), algo_index, heuristic_result); + }); + assert(heuristic_result.size() == 1); + + auto algo = heuristic_result[0].algo; + size_t ret_workspace_size = 0; + auto supporting_args = + create_hipblaslt_supporting_args_common(ctx, input_args, algo, ret_workspace_size); + + auto status = + hipblaslt_invoke(&hipblaslt_ext::matmulIsAlgoSupported, supporting_args, false); + + // If algo is supported, update the workspace size to the actual size needed. + // Otherwise, use the default workspace size. + if(status == HIPBLAS_STATUS_SUCCESS) + { + // TODO: Remove this check once issues with '0' workspace size are resolved. + // Temporarily, we use the approach where, if the returned workspace size is '0', + // we use the default workspace size. + // Otherwise, we use the returned workspace size. + if(ret_workspace_size != 0) + workspace_size = ret_workspace_size; + } + return workspace_size; + } + /** * Find best hipBLASLt solution: Get list of solutions and try them all, returning the index * of the fastest one. @@ -526,6 +597,13 @@ struct hip_gemm_impl // Initialize to default solution index int32_t best_sol = 0; + // If no valid/supported solution is returned, use hipblasLtMatmulAlgoGetHeuristic + // to get an algo and use solution index from that algo. + if(solution_indices.empty()) + { + auto algo = solution.get_result(ctx, *this, 0)[0].algo; + solution_indices.push_back(hipblaslt_ext::getIndexFromAlgo(algo)); + } for(auto sol : solution_indices) { // Warmup: the first call to an op. may not be representative since there is @@ -606,7 +684,7 @@ void hip_gemm_compute(context& ctx, std::transform(args.begin(), args.end(), std::back_inserter(input_shapes), - [](const argument& x) { return x.get_shape(); }); + [](const argument& x) { return x.get_shape().normalize_standard(); }); auto gemm_item = hip_gemm_impl(output_shape, input_shapes, alpha, beta); gemm_item.run(ctx, args, solution_idx); } @@ -633,10 +711,19 @@ int32_t hip_gemm_finalize(context& ctx, float beta, int32_t solution_idx) { - auto gemm_item = hip_gemm_impl(output_shape, input_shapes, alpha, beta); - int32_t solution = gemm_item.tune(ctx, input_shapes); - hip_gemm_save_solution(ctx, output_shape, input_shapes, solution_idx); - return solution; + auto gemm_item = hip_gemm_impl(output_shape, input_shapes, alpha, beta); + if(solution_idx == 0) + { + solution_idx = gemm_item.tune(ctx, input_shapes); + hip_gemm_save_solution(ctx, output_shape, input_shapes, solution_idx); + } + // If a tuned solution index is already given, don't tune again but validate + // in case the data was tuned with a different hipBLASLt version. 
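    // (validate() above re-runs hipblasLtMatmul once with fatal_error = false
    // via hipblaslt_invoke; on failure it prints a warning and returns the
    // default solution index 0 instead of throwing, so a tuning cache written
    // by a different hipBLASLt version cannot abort compilation.)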
+ else + { + solution_idx = gemm_item.validate(ctx, input_shapes, solution_idx); + } + return solution_idx; } int32_t hip_gemm_default_solution(context& ctx, @@ -650,6 +737,17 @@ int32_t hip_gemm_default_solution(context& ctx, return 0; } +size_t hip_gemm_workspace_size(context& ctx, + const shape& output_shape, + const std::vector& input_shapes, + float alpha, + float beta, + int32_t solution_idx) +{ + auto gemm_item = hip_gemm_impl(output_shape, input_shapes, alpha, beta); + return gemm_item.get_workspace_size(ctx, input_shapes, solution_idx); +} + } // namespace gpu } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/targets/gpu/include/migraphx/gpu/code_object_op.hpp b/src/targets/gpu/include/migraphx/gpu/code_object_op.hpp index cdedd9cfb07..8186767289f 100644 --- a/src/targets/gpu/include/migraphx/gpu/code_object_op.hpp +++ b/src/targets/gpu/include/migraphx/gpu/code_object_op.hpp @@ -55,7 +55,8 @@ struct code_object_op f(self.global, "global"), f(self.local, "local"), f(self.expected_inputs, "expected_inputs"), - f(self.output, "output")); + f(self.output, "output"), + f(self.output_arg, "output_arg")); } value attributes() const { return {{"group", group()}}; } @@ -83,6 +84,8 @@ struct code_object_op os << "symbol_name=" << op.symbol_name << ","; os << "global=" << op.global << ","; os << "local=" << op.local << ","; + if(op.output_arg != -1) + os << "output_arg=" << op.output_arg << ","; os << "]"; return os; } diff --git a/src/targets/gpu/include/migraphx/gpu/compile_hip.hpp b/src/targets/gpu/include/migraphx/gpu/compile_hip.hpp index b3f654366e5..d2fa4bcb6c3 100644 --- a/src/targets/gpu/include/migraphx/gpu/compile_hip.hpp +++ b/src/targets/gpu/include/migraphx/gpu/compile_hip.hpp @@ -39,7 +39,6 @@ namespace gpu { #ifdef MIGRAPHX_USE_HIPRTC MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_HIPRTC); -MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS); #endif struct hiprtc_src_file diff --git a/src/targets/gpu/include/migraphx/gpu/compile_hip_code_object.hpp b/src/targets/gpu/include/migraphx/gpu/compile_hip_code_object.hpp index 01d501a93dd..60b8f20a818 100644 --- a/src/targets/gpu/include/migraphx/gpu/compile_hip_code_object.hpp +++ b/src/targets/gpu/include/migraphx/gpu/compile_hip_code_object.hpp @@ -73,7 +73,8 @@ struct hip_compile_options MIGRAPHX_GPU_EXPORT std::function compute_global_for(context& ctx, std::size_t n, std::size_t over = 1); -MIGRAPHX_GPU_EXPORT operation compile_hip_code_object(const std::string& content, +MIGRAPHX_GPU_EXPORT operation compile_hip_code_object(context& ctx, + const std::string& content, hip_compile_options options); MIGRAPHX_GPU_EXPORT std::size_t diff --git a/src/targets/gpu/include/migraphx/gpu/compile_hipblaslt.hpp b/src/targets/gpu/include/migraphx/gpu/compile_hipblaslt.hpp new file mode 100644 index 00000000000..380fafa44ba --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/compile_hipblaslt.hpp @@ -0,0 +1,77 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#ifndef MIGRAPHX_GUARD_GPU_COMPILE_HIPBLASLT_HPP +#define MIGRAPHX_GUARD_GPU_COMPILE_HIPBLASLT_HPP + +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct module; +struct context; +struct operation; + +namespace gpu { + +struct hipblaslt_op +{ + operation op = op::identity{}; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.op, "op")); + } + + std::string name() const { return "gpu::hipblaslt_op"; } + + shape compute_shape(std::vector inputs) const + { + inputs.push_back(inputs.back()); + return op.compute_shape(inputs); + } + + std::ptrdiff_t output_alias(const std::vector& shapes) const + { + return shapes.size() - 1; + } +}; +MIGRAPHX_REGISTER_OP(hipblaslt_op); + +struct compile_hipblaslt +{ + context* ctx = nullptr; + std::string name() const { return "gpu::compile_hipblaslt"; } + void apply(module& m) const; +}; + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_GPU_COMPILE_HIPBLASLT_HPP diff --git a/src/targets/gpu/include/migraphx/gpu/hip_gemm.hpp b/src/targets/gpu/include/migraphx/gpu/hip_gemm.hpp index a4b14b9991a..8c3d67bcd93 100644 --- a/src/targets/gpu/include/migraphx/gpu/hip_gemm.hpp +++ b/src/targets/gpu/include/migraphx/gpu/hip_gemm.hpp @@ -41,6 +41,7 @@ namespace gpu { struct context; void blas_shape_hip(const shape& s); +shape transpose_batch_hip(const shape& s, unsigned trans_batch); template struct hip_gemm @@ -48,13 +49,16 @@ struct hip_gemm Op op; float alpha = 1; float beta = 0; + unsigned trans_batch = 0; int32_t solution_idx = 0; + template static auto reflect(Self& self, F f) { return pack_join(migraphx::reflect(self.op, f), pack(f(self.alpha, "alpha"), f(self.beta, "beta"), + f(self.trans_batch, "trans_batch"), f(self.solution_idx, "solution_idx"))); } @@ -98,10 +102,10 @@ struct hip_gemm to_string(cmat_shape.type()) + ", it must be: " + to_string(op_out_shape.type())); } - return op_out_shape; + return transpose_batch_hip(op_out_shape, trans_batch); } - return op.compute_shape(in_shapes); + return transpose_batch_hip(op.compute_shape(in_shapes), trans_batch); } argument @@ -126,6 +130,15 @@ struct hip_gemm hip_gemm_finalize(ctx, output_shape, input_shapes, alpha, beta, solution_idx); } } + + value + compile(migraphx::context& ctx, const shape& output, const std::vector& input_shapes) + { + finalize(any_cast(ctx), output, 
input_shapes); + size_t ws = hip_gemm_workspace_size( + any_cast(ctx), output, input_shapes, alpha, beta, solution_idx); + return {{"workspace", ws}}; + } }; } // namespace gpu } // namespace MIGRAPHX_INLINE_NS diff --git a/src/targets/gpu/include/migraphx/gpu/hip_gemm_impl.hpp b/src/targets/gpu/include/migraphx/gpu/hip_gemm_impl.hpp index 7f3fa907384..f26d594d8b5 100644 --- a/src/targets/gpu/include/migraphx/gpu/hip_gemm_impl.hpp +++ b/src/targets/gpu/include/migraphx/gpu/hip_gemm_impl.hpp @@ -68,6 +68,13 @@ int32_t hip_gemm_default_solution(context& ctx, const shape& output_shape, const std::vector& input_shapes); +size_t hip_gemm_workspace_size(context& ctx, + const shape& output_shape, + const std::vector& input_shapes, + float alpha, + float beta, + int32_t solution_idx); + } // namespace gpu } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/targets/gpu/include/migraphx/gpu/hipblaslt.hpp b/src/targets/gpu/include/migraphx/gpu/hipblaslt.hpp index 8b9ec2cef63..49d41bf4dcd 100644 --- a/src/targets/gpu/include/migraphx/gpu/hipblaslt.hpp +++ b/src/targets/gpu/include/migraphx/gpu/hipblaslt.hpp @@ -75,15 +75,20 @@ inline auto hipblaslt_invoke(F f, Ts... xs) return status; } +// Invoke a hipBLASLt call. If used to validate a call, set fatal_error = false to prevent +// throwing an exception on failure. template -auto hipblaslt_invoke(F f, Pack p, Ts... xs) +auto hipblaslt_invoke(F f, Pack p, Ts... xs, bool fatal_error = true) { return p([=](auto... ws) { auto status = f(ws..., xs...); if(status != HIPBLAS_STATUS_SUCCESS) { - MIGRAPHX_THROW("hipblaslt_invoke: hipBlasLt call failed with status " + - std::to_string(status)); + if(fatal_error) + { + MIGRAPHX_THROW("hipblaslt_invoke: hipBlasLt call failed with status " + + std::to_string(status)); + } } return status; }); diff --git a/src/targets/gpu/include/migraphx/gpu/miopen.hpp b/src/targets/gpu/include/migraphx/gpu/miopen.hpp index fb61103538d..87a561ad6f4 100644 --- a/src/targets/gpu/include/migraphx/gpu/miopen.hpp +++ b/src/targets/gpu/include/migraphx/gpu/miopen.hpp @@ -143,6 +143,8 @@ inline tensor_descriptor make_tensor(const migraphx::shape& os) d = miopenInt32; else if(s.type() == shape::int8_type) d = miopenInt8; + else if(s.type() == shape::bf16_type) + d = miopenBFloat16; else MIGRAPHX_THROW("MAKE_TENSOR: unsupported type"); miopenSetTensorDescriptor(t.get(), d, s.lens().size(), lens.data(), strides.data()); diff --git a/src/targets/gpu/jit/ck_gemm.cpp b/src/targets/gpu/jit/ck_gemm.cpp index 7cd20b3931d..392eaa0c67b 100644 --- a/src/targets/gpu/jit/ck_gemm.cpp +++ b/src/targets/gpu/jit/ck_gemm.cpp @@ -175,7 +175,7 @@ struct ck_gemm_compiler : compiler {"preamble", v.get("preamble", std::string{})}, {"kernel", options.kernel_name}}); - return compile_hip_code_object(src, options); + return compile_hip_code_object(ctx, src, options); } value create_settings(instruction_ref ins, const operation& op) const diff --git a/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp b/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp index 5fe60372b94..693153d0982 100644 --- a/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp +++ b/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp @@ -175,7 +175,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler {"preamble", v.get("preamble", std::string{})}, {"kernel", options.kernel_name}}); - return compile_hip_code_object(src, options); + return compile_hip_code_object(ctx, src, options); } value create_settings(instruction_ref ins, const operation& op) const diff --git 
a/src/targets/gpu/jit/compute_attention_probabilities.cpp b/src/targets/gpu/jit/compute_attention_probabilities.cpp index 4bf03ff01d8..8a0c722078f 100644 --- a/src/targets/gpu/jit/compute_attention_probabilities.cpp +++ b/src/targets/gpu/jit/compute_attention_probabilities.cpp @@ -98,7 +98,7 @@ struct compute_attention_probabilities_compiler : compiler {"transformers", make_transformer_args(vec)}, {"preamble", v.get("preamble", std::string{})}, {"axis", std::to_string(concat_axis)}}); - return compile_hip_code_object(src, options); + return compile_hip_code_object(ctx, src, options); } compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const diff --git a/src/targets/gpu/jit/concat_past_present.cpp b/src/targets/gpu/jit/concat_past_present.cpp index 6d7f29b481c..b18d7010809 100644 --- a/src/targets/gpu/jit/concat_past_present.cpp +++ b/src/targets/gpu/jit/concat_past_present.cpp @@ -103,7 +103,7 @@ struct concat_past_present_compiler : compiler {"args", enum_params(inputs.size(), "private_p")}, {"gqa_params", gqa_params_str}, {"kernel", options.kernel_name}}); - return compile_hip_code_object(src, options); + return compile_hip_code_object(ctx, src, options); } compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const diff --git a/src/targets/gpu/jit/gather.cpp b/src/targets/gpu/jit/gather.cpp index 6409fa0d738..9dc17db0972 100644 --- a/src/targets/gpu/jit/gather.cpp +++ b/src/targets/gpu/jit/gather.cpp @@ -75,7 +75,7 @@ struct gather_compiler : compiler auto src = interpolate_string(gather_kernel, {{"axis", axis}}); - return compile_hip_code_object(src, options); + return compile_hip_code_object(ctx, src, options); } compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const diff --git a/src/targets/gpu/jit/gathernd.cpp b/src/targets/gpu/jit/gathernd.cpp index 1c39c29aa97..05a48f4e9cc 100644 --- a/src/targets/gpu/jit/gathernd.cpp +++ b/src/targets/gpu/jit/gathernd.cpp @@ -77,7 +77,7 @@ struct gathernd_compiler : compiler auto batch_dims = v.at("batch_dims").to(); options.emplace_param("-DBATCH_DIMS=" + std::to_string(batch_dims)); - return compile_hip_code_object(gathernd_kernel, options); + return compile_hip_code_object(ctx, gathernd_kernel, options); } compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const diff --git a/src/targets/gpu/jit/gqa_rotary_embedding.cpp b/src/targets/gpu/jit/gqa_rotary_embedding.cpp index 8aa3bcc6a7a..34061635280 100644 --- a/src/targets/gpu/jit/gqa_rotary_embedding.cpp +++ b/src/targets/gpu/jit/gqa_rotary_embedding.cpp @@ -98,7 +98,7 @@ struct gqa_rotary_embedding_compiler : compiler {"args", enum_params(inputs.size(), "private_p")}, {"gqa_params", gqa_params_str}, {"kernel", options.kernel_name}}); - return compile_hip_code_object(src, options); + return compile_hip_code_object(ctx, src, options); } compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const diff --git a/src/targets/gpu/jit/gqa_softmax.cpp b/src/targets/gpu/jit/gqa_softmax.cpp index e1a0910dddb..c1ff241ddeb 100644 --- a/src/targets/gpu/jit/gqa_softmax.cpp +++ b/src/targets/gpu/jit/gqa_softmax.cpp @@ -97,7 +97,7 @@ struct gqa_softmax_compiler : compiler {"args", enum_params(inputs.size(), "private_p")}, {"gqa_params", gqa_params_str}, {"kernel", options.kernel_name}}); - return compile_hip_code_object(src, options); + return compile_hip_code_object(ctx, src, options); } compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) 
const diff --git a/src/targets/gpu/jit/layernorm.cpp b/src/targets/gpu/jit/layernorm.cpp index 4b55d35529d..09736031bbb 100644 --- a/src/targets/gpu/jit/layernorm.cpp +++ b/src/targets/gpu/jit/layernorm.cpp @@ -101,7 +101,7 @@ struct layernorm_compiler : compiler {"axis", to_string(axis)}, {"eps", to_string(eps)}}); - return compile_hip_code_object(src, options); + return compile_hip_code_object(ctx, src, options); } compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const diff --git a/src/targets/gpu/jit/mlir.cpp b/src/targets/gpu/jit/mlir.cpp index 88b1594bc90..4893743c2bc 100644 --- a/src/targets/gpu/jit/mlir.cpp +++ b/src/targets/gpu/jit/mlir.cpp @@ -125,14 +125,56 @@ struct mlir_compiler : compiler return {std::vector{mco.cop}, [=](module& m, instruction_ref ins, const std::vector& ops) { std::vector inputs = ins->inputs(); + + // Tuple inputs not supported + assert(std::all_of(inputs.begin(), inputs.end() - 1, [](auto i) { + return i->get_shape().sub_shapes().empty(); + })); + + // Multiple output case (allocate ins will give a tuple) + std::vector flat_inputs(inputs); + bool multi_out = not flat_inputs.back()->get_shape().sub_shapes().empty(); + if(multi_out) + { + auto allocs = flat_inputs.back(); + flat_inputs.pop_back(); + auto sub_shape_idx = range(allocs->get_shape().sub_shapes().size()); + std::transform(sub_shape_idx.begin(), + sub_shape_idx.end(), + std::back_inserter(flat_inputs), + [&](int i) { + return m.insert_instruction( + ins, + migraphx::make_op("get_tuple_elem", {{"index", i}}), + allocs); + }); + } + std::vector tuple_replacements; + for(const auto i : range(mco.prefill_indices.size())) { auto prefilled_ins = m.insert_instruction( ins, migraphx::make_op("hip::fill", {{"value", mco.prefill_values[i]}}), - inputs[mco.prefill_indices[i]]); - replace(inputs, inputs[mco.prefill_indices[i]], prefilled_ins); + flat_inputs[mco.prefill_indices[i]]); + if(not multi_out or mco.prefill_indices[i] < inputs.size() - 1) + { + replace(inputs, inputs[mco.prefill_indices[i]], prefilled_ins); + } + else + { + tuple_replacements.push_back(prefilled_ins); + } } + + if(multi_out and not tuple_replacements.empty()) + { + // Add identity to make sure fill operations happen before kernel call + tuple_replacements.insert(tuple_replacements.begin(), inputs.back()); + inputs.back() = m.insert_instruction( + ins, migraphx::make_op("identity"), tuple_replacements); + } + auto mlir = insert_mlir(m, ins, any_cast(ops.front()), inputs); return m.replace_instruction(ins, mlir); }, @@ -212,7 +254,7 @@ struct mlir_compiler : compiler const operation&, bool exhaustive) const { - static const auto mxr_loc = string_value_of(MIGRAPHX_MLIR_DUMP_TO_MXR{}); + static const auto mxr_loc = string_value_of(MIGRAPHX_MLIR_DUMP_TO_MXR{}); static const auto mlir_loc = string_value_of(MIGRAPHX_MLIR_DUMP{}); auto shapes = to_shapes(ins->inputs()); diff --git a/src/targets/gpu/jit/pad.cpp b/src/targets/gpu/jit/pad.cpp index d216cb2dc74..9ea77ee996b 100644 --- a/src/targets/gpu/jit/pad.cpp +++ b/src/targets/gpu/jit/pad.cpp @@ -108,7 +108,7 @@ struct pad_compiler : compiler auto src = interpolate_string( pointwise_kernel, {{"pad_val", to_string(pad_val_string)}, {"offsets", to_string_range(roffsets)}}); - return compile_hip_code_object(src, options); + return compile_hip_code_object(ctx, src, options); } compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const diff --git a/src/targets/gpu/jit/pointwise.cpp b/src/targets/gpu/jit/pointwise.cpp index 
c8f7e5589c8..9e352888e1b 100644 --- a/src/targets/gpu/jit/pointwise.cpp +++ b/src/targets/gpu/jit/pointwise.cpp @@ -97,7 +97,7 @@ struct pointwise_compiler : compiler {"tiled", t.ntiles > 0 ? "true" : "false"}, {"noutputs", std::to_string(noutputs)}, {"preamble", v.get("preamble", std::string{})}}); - return compile_hip_code_object(src, options); + return compile_hip_code_object(ctx, src, options); } compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const diff --git a/src/targets/gpu/jit/pooling.cpp b/src/targets/gpu/jit/pooling.cpp index 2c02247e0bf..f245a226940 100644 --- a/src/targets/gpu/jit/pooling.cpp +++ b/src/targets/gpu/jit/pooling.cpp @@ -179,7 +179,7 @@ struct pooling_compiler : compiler {"stride", to_string_range(stride)}, {"padding", to_string_range(padding)}}); - return compile_hip_code_object(src, options); + return compile_hip_code_object(ctx, src, options); } compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const diff --git a/src/targets/gpu/jit/reduce.cpp b/src/targets/gpu/jit/reduce.cpp index bcfdfe6a198..bdf7313f5f1 100644 --- a/src/targets/gpu/jit/reduce.cpp +++ b/src/targets/gpu/jit/reduce.cpp @@ -265,7 +265,7 @@ struct simple_reduce_compiler : compiler {"transformers", make_transformer_args(vec)}, {"preamble", v.get("preamble", std::string{})}}); options.emplace_param("-Wno-float-equal"); - return compile_hip_code_object(src, options); + return compile_hip_code_object(ctx, src, options); } compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const @@ -389,7 +389,7 @@ struct fused_reduce_compiler : compiler {"noutputs", std::to_string(noutputs)}, {"preamble", v.get("preamble", std::string{})}}); options.emplace_param("-Wno-float-equal"); - return compile_hip_code_object(src, options); + return compile_hip_code_object(ctx, src, options); } compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const diff --git a/src/targets/gpu/jit/roialign.cpp b/src/targets/gpu/jit/roialign.cpp index 456f93e23a6..aeaf7a85898 100644 --- a/src/targets/gpu/jit/roialign.cpp +++ b/src/targets/gpu/jit/roialign.cpp @@ -90,7 +90,7 @@ struct roialign_compiler : compiler // spatial_scale options.emplace_param("-DSPATIAL_SCALE=" + v.at("spatial_scale").to()); - return compile_hip_code_object(roialign_kernel, options); + return compile_hip_code_object(ctx, roialign_kernel, options); } compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const diff --git a/src/targets/gpu/jit/scatter.hpp b/src/targets/gpu/jit/scatter.hpp index 2902aa9c599..6fb95564760 100644 --- a/src/targets/gpu/jit/scatter.hpp +++ b/src/targets/gpu/jit/scatter.hpp @@ -51,7 +51,7 @@ struct scatter_compiler : compiler options.emplace_param("-DMIGRAPHX_ALLOW_ATOMIC_CAS=1"); const auto src = derived().make_interpolated_string(op); - return prepend_copy_data_to_output(compile_hip_code_object(src, options)); + return prepend_copy_data_to_output(compile_hip_code_object(ctx, src, options)); } // ONNX spec states the following for ScatterElements and ScatterND: diff --git a/src/targets/gpu/jit/softmax.cpp b/src/targets/gpu/jit/softmax.cpp index e64d61da033..d2e24a2335c 100644 --- a/src/targets/gpu/jit/softmax.cpp +++ b/src/targets/gpu/jit/softmax.cpp @@ -90,7 +90,7 @@ struct softmax_compiler : compiler softmax_kernel, {{"transformers", make_transformer_args(vec)}, {"axis", to_string(axis)}}); - return compile_hip_code_object(src, options); + return compile_hip_code_object(ctx, src, options); } 
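
    // A sketch of the recurring pattern in these JIT compiler hunks: each
    // per-op kernel builder now threads the compile context into
    // compile_hip_code_object, presumably so the code object is built for the
    // device the context carries. The compile_op name and hip_compile_options
    // type are assumptions; only the interpolate_string and
    // compile_hip_code_object calls are verbatim from the hunks above.
    //
    //     operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
    //     {
    //         hip_compile_options options;
    //         // ... configure inputs, launch sizes, and the kernel name ...
    //         auto src = interpolate_string(softmax_kernel,
    //                                       {{"transformers", make_transformer_args(vec)},
    //                                        {"axis", to_string(axis)}});
    //         return compile_hip_code_object(ctx, src, options);
    //     }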
     compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
diff --git a/src/targets/gpu/jit/unpack_int4.cpp b/src/targets/gpu/jit/unpack_int4.cpp
index f770d56d88f..68f2038b860 100644
--- a/src/targets/gpu/jit/unpack_int4.cpp
+++ b/src/targets/gpu/jit/unpack_int4.cpp
@@ -76,7 +76,7 @@ struct unpack_int4_compiler : compiler<unpack_int4_compiler>
              {"params", enum_params(options.inputs.size(), "void * private_p")},
              {"args", enum_params(options.inputs.size(), "private_p")},
              {"axis", std::to_string(v.at("axis").to())}});
-        return compile_hip_code_object(src, options);
+        return compile_hip_code_object(ctx, src, options);
     }
 
     compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
diff --git a/src/targets/gpu/kernel.cpp b/src/targets/gpu/kernel.cpp
index f43c0979da4..a7c79bded60 100644
--- a/src/targets/gpu/kernel.cpp
+++ b/src/targets/gpu/kernel.cpp
@@ -134,7 +134,7 @@ void kernel::launch(hipStream_t stream,
                     hipEvent_t stop) const
 {
     assert(impl != nullptr);
-    void* kernargs   = args.data();
+    void* kernargs   = reinterpret_cast<void*>(args.data());
     std::size_t size = args.size() * sizeof(void*);
 
     launch_kernel(impl->fun, stream, global, local, kernargs, size, start, stop);
diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp
index c98395bbe10..e559658a004 100644
--- a/src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp
+++ b/src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp
@@ -23,15 +23,20 @@
 #define MIGRAPHX_GUARD_KERNELS_BITCAST_HPP
 
 #include
+#include
 
 namespace migraphx {
+
 template <class To,
           class From,
           MIGRAPHX_REQUIRES(is_trivially_copyable<To>{} and is_trivially_copyable<From>{})>
-inline constexpr To bit_cast(From fr) noexcept
+inline constexpr auto bit_cast(From fr) noexcept
 {
-    static_assert(sizeof(To) == sizeof(From));
-    return __builtin_bit_cast(To, fr);
+    return vec_transform(fr)([](auto x) -> To {
+        static_assert(sizeof(To) == sizeof(decltype(x)));
+        return __builtin_bit_cast(To, x);
+    });
 }
+
 } // namespace migraphx
 #endif // MIGRAPHX_GUARD_KERNELS_BITCAST_HPP
diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
index 1a263e98b2d..3e9d802611f 100644
--- a/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
+++ b/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
@@ -63,11 +63,11 @@ template <class... Fs>
 struct overloaded : Fs...
 {
     using Fs::operator()...;
-    overloaded(Fs... fs) : Fs(fs)... {}
+    constexpr overloaded(Fs... fs) : Fs(fs)... {}
 };
 
 template <class... Fs>
-overloaded<Fs...> overload(Fs... fs)
+constexpr overloaded<Fs...> overload(Fs... fs)
 {
     return {fs...};
 }
diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/index.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
index bf54bae3592..9c43f5d3bfb 100644
--- a/src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
+++ b/src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
@@ -50,8 +50,8 @@ inline __device__ __attribute__((const)) index_int compute_global_size()
 #ifdef MIGRAPHX_NGLOBAL
     return MIGRAPHX_NGLOBAL;
 #else
-    // This actualy works even when global is not divisible by local size.
-    // This doesnt actually do a multiplicatiosn. Instead it calls a device
+    // This actually works even when global is not divisible by local size.
+    // This doesn't actually do a multiplication. Instead it calls a device
     // function to get the global size, which is why it works.
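    // For instance, a kernel launched with global = 1000 and local = 64 runs
    // 16 work-groups (1024 lanes), yet this expression still evaluates to
    // 1000: the compiler lowers blockDim.x * gridDim.x to a device query of
    // the enqueued global size rather than an actual multiply.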
return blockDim.x * gridDim.x; // NOLINT #endif @@ -155,7 +155,7 @@ struct index return max_nlocal() / nlocal_subwave(); } - constexpr index_constant<__AMDGCN_WAVEFRONT_SIZE> nlocal_wave() const { return {}; } + constexpr index_constant nlocal_wave() const { return {}; } constexpr auto local_wave() const { return local % nlocal_wave(); } constexpr auto nwave() const { return max_nlocal() / nlocal_wave(); } constexpr auto wave() const { return local / nlocal_wave(); } diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp index 2fa5c60060a..cadcc05f577 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp @@ -30,23 +30,72 @@ #include #include #include +#include +#include namespace migraphx { namespace math { -constexpr float as_float(migraphx::half x) { return x; } -constexpr float as_float(migraphx::fp8::fp8e4m3fnuz x) { return x; } -constexpr float as_float(migraphx::fp8::fp8e4m3fn x) { return x; } -constexpr float as_float(migraphx::fp8::fp8e5m2 x) { return x; } +template +constexpr auto as_float(T x) +{ + if constexpr(is_integral{}) + return x; + else + return float(x); +} template -constexpr T as_float(T x) +constexpr auto to_native(T x) { return x; } + +constexpr migraphx::half to_native(__half x) { return bit_cast(x); } + +template ())> +__device__ auto wrap(F f, T x, Ts... xs) +{ + if constexpr(is_integral{}) + { + return wrap(f, double(x), double(xs)...); + } + else if constexpr(is_callable{}) + { + return to_native(f(x, xs...)); + } + else + { + T result = f(as_float(x), as_float(xs)...); + return result; + } +} + } // namespace math +// NOLINTNEXTLINE +#define MIGRAPHX_DEVICE_MATH_LIFT_IMPL(type, ...) \ + [](type x, auto... xs) MIGRAPHX_RETURNS((__VA_ARGS__)(x, xs...)) + +// NOLINTNEXTLINE +#define MIGRAPHX_DEVICE_MATH_LIFT(...) MIGRAPHX_DEVICE_MATH_LIFT_IMPL(__VA_ARGS__) + +// NOLINTNEXTLINE +#define MIGRAPHX_DEVICE_MATH_PARSE(x) x, + +// NOLINTNEXTLINE +#define MIGRAPHX_DEVICE_MATH_EACH(f) MIGRAPHX_DEVICE_MATH_LIFT(MIGRAPHX_DEVICE_MATH_PARSE f) + +// NOLINTNEXTLINE +#define MIGRAPHX_DEVICE_MATH_WRAP(name, ...) \ + namespace math { \ + inline static constexpr auto wrap_##name = \ + overload(MIGRAPHX_PP_TRANSFORM_ARGS(MIGRAPHX_DEVICE_MATH_EACH, __VA_ARGS__)); \ + } \ + template \ + auto __device__ name(Ts... xs) MIGRAPHX_RETURNS(math::wrap(math::wrap_##name, xs...)) + // NOLINTNEXTLINE #define MIGRAPHX_DEVICE_MATH(name, fname) \ template ())> \ @@ -72,165 +121,47 @@ constexpr T as_float(T x) #define MIGRAPHX_DEVICE_MATH_BINARY_FOR(type, name, fname) \ inline auto __device__ name(type x, type y) -> type { return fname(x, y); } -// NOLINTNEXTLINE -#define MIGRAPHX_DEVICE_MATH_HALF(name, fname) \ - template ())> \ - auto __device__ name(migraphx::half x, Ts... xs) \ - MIGRAPHX_RETURNS(fname(math::as_float(x), math::as_float(xs)...)) - -// NOLINTNEXTLINE -#define MIGRAPHX_DEVICE_MATH_FP8(name, fname) \ - template ())> \ - auto __device__ name(migraphx::fp8::fp8e4m3fnuz x, Ts... xs) MIGRAPHX_RETURNS( \ - migraphx::fp8::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(xs)...))) \ - \ - template ())> \ - auto __device__ name(migraphx::fp8::fp8e4m3fn x, Ts... xs) MIGRAPHX_RETURNS( \ - migraphx::fp8::fp8e4m3fn(fname(math::as_float(x), math::as_float(xs)...))) \ - \ - template ())> \ - auto __device__ name(migraphx::fp8::fp8e5m2 x, Ts... 
xs) MIGRAPHX_RETURNS( \ - migraphx::fp8::fp8e5m2(fname(math::as_float(x), math::as_float(xs)...))) - // Template with two overloads for math functions, one for half2 type and one for more generic // vectorization where N is 4 or another even number. - // NOLINTNEXTLINE -#define MIGRAPHX_DEVICE_MATH_HALF2(name, fname) \ - template \ - auto __device__ name(migraphx::vec x, Ts... xs) \ - MIGRAPHX_RETURNS(migraphx::vec{fname(x, xs...)}); \ - template 2))> \ - auto __device__ name(migraphx::vec x, Ts... xs) \ - { \ - return vec_packed_transform<2>(x, xs...)( \ - [](auto... ys) -> migraphx::vec { return fname(ys...); }); \ +#define MIGRAPHX_DEVICE_MATH_VEC2(type, name, fname) \ + template \ + auto __device__ name(migraphx::vec x, Ts... xs) \ + MIGRAPHX_RETURNS(migraphx::vec{fname(x, xs...)}); \ + template 2))> \ + auto __device__ name(migraphx::vec x, Ts... xs) \ + { \ + return vec_packed_transform<2>(x, xs...)( \ + [](auto... ys) -> migraphx::vec { return fname(ys...); }); \ } -MIGRAPHX_DEVICE_MATH(abs, ::abs) -MIGRAPHX_DEVICE_MATH(acos, ::acos) -MIGRAPHX_DEVICE_MATH(acosh, ::acosh) -MIGRAPHX_DEVICE_MATH(asin, ::asin) -MIGRAPHX_DEVICE_MATH(asinh, ::asinh) -MIGRAPHX_DEVICE_MATH(atan, ::atan) -MIGRAPHX_DEVICE_MATH(atanh, ::atanh) -MIGRAPHX_DEVICE_MATH(ceil, ::ceil) -MIGRAPHX_DEVICE_MATH(cos, ::cos) -MIGRAPHX_DEVICE_MATH(cosh, ::cosh) -MIGRAPHX_DEVICE_MATH(erf, ::erf) -MIGRAPHX_DEVICE_MATH(exp, ::exp) -MIGRAPHX_DEVICE_MATH(floor, ::floor) -MIGRAPHX_DEVICE_MATH(isnan, ::isnan) -MIGRAPHX_DEVICE_MATH(isinf, ::isinf) -MIGRAPHX_DEVICE_MATH(log, ::log) -MIGRAPHX_DEVICE_MATH(log2, ::log2) -MIGRAPHX_DEVICE_MATH(nearbyint, ::nearbyint) -MIGRAPHX_DEVICE_MATH(pow, ::pow) -MIGRAPHX_DEVICE_MATH(remainder, ::remainder) -MIGRAPHX_DEVICE_MATH(round, ::round) -MIGRAPHX_DEVICE_MATH(rsqrt, ::rsqrt) -MIGRAPHX_DEVICE_MATH(sin, ::sin) -MIGRAPHX_DEVICE_MATH(sinh, ::sinh) -MIGRAPHX_DEVICE_MATH(sqrt, ::sqrt) -MIGRAPHX_DEVICE_MATH(tan, ::tan) -MIGRAPHX_DEVICE_MATH(tanh, ::tanh) -MIGRAPHX_DEVICE_MATH(fmod, ::fmod) - -// Float overloads -MIGRAPHX_DEVICE_MATH_FOR(float, acos, ::acosf) -MIGRAPHX_DEVICE_MATH_FOR(float, acosh, ::acoshf) -MIGRAPHX_DEVICE_MATH_FOR(float, asin, ::asinf) -MIGRAPHX_DEVICE_MATH_FOR(float, asinh, ::asinhf) -MIGRAPHX_DEVICE_MATH_FOR(float, atan, ::atanf) -MIGRAPHX_DEVICE_MATH_FOR(float, atanh, ::atanhf) -MIGRAPHX_DEVICE_MATH_FOR(float, cos, ::cosf) -MIGRAPHX_DEVICE_MATH_FOR(float, cosh, ::coshf) -MIGRAPHX_DEVICE_MATH_FOR(float, rsqrt, ::rsqrtf) -MIGRAPHX_DEVICE_MATH_FOR(float, sin, ::sinf) -MIGRAPHX_DEVICE_MATH_FOR(float, sinh, ::sinhf) -MIGRAPHX_DEVICE_MATH_FOR(float, tan, ::tanf) -MIGRAPHX_DEVICE_MATH_FOR(float, tanh, ::tanhf) -MIGRAPHX_DEVICE_MATH_FOR(float, fmod, ::fmodf) - -// Builtin half functions -MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, abs, ::__habs) -MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, ceil, ::hceil) -MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, cos, ::hcos) -MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, exp, ::hexp) -MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, floor, ::hfloor) -MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, isinf, ::__hisinf) -MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, isnan, ::__hisnan) -MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, log, ::hlog) -MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, log2, ::hlog2) -MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, rsqrt, ::hrsqrt) -MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, sin, ::hsin) -MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, sqrt, ::hsqrt) - -// Use float to compute half overload -MIGRAPHX_DEVICE_MATH_HALF(acos, ::acos) -MIGRAPHX_DEVICE_MATH_HALF(acosh, ::acosh) 
-MIGRAPHX_DEVICE_MATH_HALF(asin, ::asin) -MIGRAPHX_DEVICE_MATH_HALF(asinh, ::asinh) -MIGRAPHX_DEVICE_MATH_HALF(atan, ::atan) -MIGRAPHX_DEVICE_MATH_HALF(atanh, ::atanh) -MIGRAPHX_DEVICE_MATH_HALF(cosh, ::cosh) -MIGRAPHX_DEVICE_MATH_HALF(erf, ::erf) -MIGRAPHX_DEVICE_MATH_HALF(nearbyint, ::nearbyint) -MIGRAPHX_DEVICE_MATH_HALF(pow, ::pow) -MIGRAPHX_DEVICE_MATH_HALF(remainder, ::remainder) -MIGRAPHX_DEVICE_MATH_HALF(round, ::round) -MIGRAPHX_DEVICE_MATH_HALF(sinh, ::sinh) -MIGRAPHX_DEVICE_MATH_HALF(tan, ::tan) -MIGRAPHX_DEVICE_MATH_HALF(tanh, ::tanh) -MIGRAPHX_DEVICE_MATH_HALF(fmod, ::fmod) - -// use float to compute fp8 overload -MIGRAPHX_DEVICE_MATH_FP8(abs, ::abs) -MIGRAPHX_DEVICE_MATH_FP8(acos, ::acos) -MIGRAPHX_DEVICE_MATH_FP8(acosh, ::acosh) -MIGRAPHX_DEVICE_MATH_FP8(asin, ::asin) -MIGRAPHX_DEVICE_MATH_FP8(asinh, ::asinh) -MIGRAPHX_DEVICE_MATH_FP8(atan, ::atan) -MIGRAPHX_DEVICE_MATH_FP8(atanh, ::atanh) -MIGRAPHX_DEVICE_MATH_FP8(ceil, ::ceil) -MIGRAPHX_DEVICE_MATH_FP8(cos, ::cos) -MIGRAPHX_DEVICE_MATH_FP8(cosh, ::cosh) -MIGRAPHX_DEVICE_MATH_FP8(erf, ::erf) -MIGRAPHX_DEVICE_MATH_FP8(exp, ::exp) -MIGRAPHX_DEVICE_MATH_FP8(floor, ::floor) -MIGRAPHX_DEVICE_MATH_FP8(isnan, ::isnan) -MIGRAPHX_DEVICE_MATH_FP8(log, ::log) -MIGRAPHX_DEVICE_MATH_FP8(log2, ::log2) -MIGRAPHX_DEVICE_MATH_FP8(pow, ::pow) -MIGRAPHX_DEVICE_MATH_FP8(remainder, ::remainder) -MIGRAPHX_DEVICE_MATH_FP8(round, ::round) -MIGRAPHX_DEVICE_MATH_FP8(rsqrt, ::rsqrt) -MIGRAPHX_DEVICE_MATH_FP8(sin, ::sin) -MIGRAPHX_DEVICE_MATH_FP8(sinh, ::sinh) -MIGRAPHX_DEVICE_MATH_FP8(sqrt, ::sqrt) -MIGRAPHX_DEVICE_MATH_FP8(tan, ::tan) -MIGRAPHX_DEVICE_MATH_FP8(tanh, ::tanh) -MIGRAPHX_DEVICE_MATH_FP8(fmod, ::fmod) - -// Map math functions to hip half2 functions -// The half2 type is defined in include/hip/amd_detail/hip_fp16_gcc.h and is 2 16-bit floats -// packed into a 32-bit number. See include/hip/amd_detail/hip_fp16_math_fwd.h for the HIP names -// Most but not all of these math ops have operators of the same names. 
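
// For reference, a sketch of what the MIGRAPHX_DEVICE_MATH_WRAP macro above
// expands to, e.g. MIGRAPHX_DEVICE_MATH_WRAP(cos, (double)::cos, (float)::cosf, (half)::hcos):
//
//     namespace math {
//     inline static constexpr auto wrap_cos =
//         overload([](double x, auto... xs) MIGRAPHX_RETURNS((::cos)(x, xs...)),
//                  [](float x, auto... xs) MIGRAPHX_RETURNS((::cosf)(x, xs...)),
//                  [](half x, auto... xs) MIGRAPHX_RETURNS((::hcos)(x, xs...)));
//     } // namespace math
//     template <class... Ts>
//     auto __device__ cos(Ts... xs) MIGRAPHX_RETURNS(math::wrap(math::wrap_cos, xs...));
//
// math::wrap promotes integral arguments to double, calls a matching overload
// directly when one exists, and otherwise computes in float via as_float and
// converts the result back, which replaces the hand-written per-type HALF and
// FP8 macros removed here.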
-MIGRAPHX_DEVICE_MATH_HALF2(abs, ::__habs2) -MIGRAPHX_DEVICE_MATH_HALF2(ceil, ::h2ceil) -MIGRAPHX_DEVICE_MATH_HALF2(cos, ::h2cos) -MIGRAPHX_DEVICE_MATH_HALF2(exp, ::h2exp) -MIGRAPHX_DEVICE_MATH_HALF2(exp10, ::h2exp10) -MIGRAPHX_DEVICE_MATH_HALF2(exp2, ::h2exp2) -MIGRAPHX_DEVICE_MATH_HALF2(floor, ::h2floor) -MIGRAPHX_DEVICE_MATH_HALF2(isinf, ::__hisinf2) -MIGRAPHX_DEVICE_MATH_HALF2(isnan, ::__hisnan2) -MIGRAPHX_DEVICE_MATH_HALF2(log, ::h2log) -MIGRAPHX_DEVICE_MATH_HALF2(log10, ::h2log10) -MIGRAPHX_DEVICE_MATH_HALF2(log2, ::h2log2) -MIGRAPHX_DEVICE_MATH_HALF2(rsqrt, ::h2rsqrt) -MIGRAPHX_DEVICE_MATH_HALF2(sin, ::h2sin) -MIGRAPHX_DEVICE_MATH_HALF2(sqrt, ::h2sqrt) +MIGRAPHX_DEVICE_MATH_WRAP(acos, (double)::acos, (float)::acosf); +MIGRAPHX_DEVICE_MATH_WRAP(acosh, (double)::acosh, (float)::acoshf); +MIGRAPHX_DEVICE_MATH_WRAP(asin, (double)::asin, (float)::asinf); +MIGRAPHX_DEVICE_MATH_WRAP(asinh, (double)::asinh, (float)::asinh); +MIGRAPHX_DEVICE_MATH_WRAP(atan, (double)::atan, (float)::atan); +MIGRAPHX_DEVICE_MATH_WRAP(atanh, (double)::atanh, (float)::atanh); +MIGRAPHX_DEVICE_MATH_WRAP(ceil, (double)::ceil, (float)::ceilf, (half)::hceil); +MIGRAPHX_DEVICE_MATH_WRAP(cos, (double)::cos, (float)::cosf, (half)::hcos); +MIGRAPHX_DEVICE_MATH_WRAP(cosh, (double)::cosh, (float)::coshf); +MIGRAPHX_DEVICE_MATH_WRAP(erf, (double)::erf, (float)::erff); +MIGRAPHX_DEVICE_MATH_WRAP(exp, (double)::exp, (float)::expf, (half)::hexp); +MIGRAPHX_DEVICE_MATH_WRAP(floor, (double)::floor, (float)::floorf, (half)::hfloor); +MIGRAPHX_DEVICE_MATH_WRAP(isnan, (double)::isnan, (float)::isnan, (half)::__hisnan); +MIGRAPHX_DEVICE_MATH_WRAP(isinf, (double)::isinf, (float)::isinf, (half)::__hisinf); +MIGRAPHX_DEVICE_MATH_WRAP(log, (double)::log, (float)::logf, (half)::hlog); +MIGRAPHX_DEVICE_MATH_WRAP(log2, (double)::log2, (float)::log2f, (half)::hlog2); +MIGRAPHX_DEVICE_MATH_WRAP(nearbyint, (double)::nearbyint, (float)::nearbyintf); +MIGRAPHX_DEVICE_MATH_WRAP(pow, (double)::pow, (float)::powf); +MIGRAPHX_DEVICE_MATH_WRAP(remainder, (double)::remainder, (float)::remainderf); +MIGRAPHX_DEVICE_MATH_WRAP(round, (double)::round, (float)::roundf); +MIGRAPHX_DEVICE_MATH_WRAP(rsqrt, (double)::rsqrt, (float)::rsqrtf, (half)::hrsqrt); +MIGRAPHX_DEVICE_MATH_WRAP(sin, (double)::sin, (float)::sinf, (half)::hsin); +MIGRAPHX_DEVICE_MATH_WRAP(sinh, (double)::sinh, (float)::sinhf); +MIGRAPHX_DEVICE_MATH_WRAP(sqrt, (double)::sqrt, (float)::sqrtf, (half)::hsqrt); +MIGRAPHX_DEVICE_MATH_WRAP(tan, (double)::tan, (float)::tanf); +MIGRAPHX_DEVICE_MATH_WRAP(tanh, (double)::tanh, (float)::tanhf); +MIGRAPHX_DEVICE_MATH_WRAP(fmod, (double)::fmod, (float)::fmodf); template constexpr auto where(bool cond, const T& a, const U& b) @@ -238,13 +169,22 @@ constexpr auto where(bool cond, const T& a, const U& b) return cond ? 
a : b; } -MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, max, ::max) -MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, min, ::min) +MIGRAPHX_DEVICE_MATH_FOR(float, abs, ::abs) +MIGRAPHX_DEVICE_MATH_FOR(double, abs, ::abs) +MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, abs, ::__habs) +MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, max, ::fmaxf) +MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, min, ::fminf) MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, max, ::max) MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, min, ::min) MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, max, ::__hmax) MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, min, ::__hmin) +template () and is_integral{})> +constexpr auto abs(const T& a) +{ + return where(a < 0, -a, a); +} + template ())> constexpr auto max(const T& a, const T& b) { @@ -317,6 +257,26 @@ MIGRAPHX_DEVICE_MATH_VEC(tan) MIGRAPHX_DEVICE_MATH_VEC(tanh) MIGRAPHX_DEVICE_MATH_VEC(where) +// Map math functions to hip half2 functions +// The half2 type is defined in include/hip/amd_detail/hip_fp16_gcc.h and is 2 16-bit floats +// packed into a 32-bit number. See include/hip/amd_detail/hip_fp16_math_fwd.h for the HIP names +// Most but not all of these math ops have operators of the same names. +MIGRAPHX_DEVICE_MATH_VEC2(migraphx::half, abs, ::__habs2) +MIGRAPHX_DEVICE_MATH_VEC2(migraphx::half, ceil, ::h2ceil) +MIGRAPHX_DEVICE_MATH_VEC2(migraphx::half, cos, ::h2cos) +MIGRAPHX_DEVICE_MATH_VEC2(migraphx::half, exp, ::h2exp) +MIGRAPHX_DEVICE_MATH_VEC2(migraphx::half, exp10, ::h2exp10) +MIGRAPHX_DEVICE_MATH_VEC2(migraphx::half, exp2, ::h2exp2) +MIGRAPHX_DEVICE_MATH_VEC2(migraphx::half, floor, ::h2floor) +MIGRAPHX_DEVICE_MATH_VEC2(migraphx::half, isinf, ::__hisinf2) +MIGRAPHX_DEVICE_MATH_VEC2(migraphx::half, isnan, ::__hisnan2) +MIGRAPHX_DEVICE_MATH_VEC2(migraphx::half, log, ::h2log) +MIGRAPHX_DEVICE_MATH_VEC2(migraphx::half, log10, ::h2log10) +MIGRAPHX_DEVICE_MATH_VEC2(migraphx::half, log2, ::h2log2) +MIGRAPHX_DEVICE_MATH_VEC2(migraphx::half, rsqrt, ::h2rsqrt) +MIGRAPHX_DEVICE_MATH_VEC2(migraphx::half, sin, ::h2sin) +MIGRAPHX_DEVICE_MATH_VEC2(migraphx::half, sqrt, ::h2sqrt) + template constexpr auto convert(U v) { diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/pp.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/pp.hpp index 272e0ca0d10..89b38ac247b 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/pp.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/pp.hpp @@ -24,11 +24,41 @@ #ifndef MIGRAPHX_GUARD_KERNELS_PP_HPP #define MIGRAPHX_GUARD_KERNELS_PP_HPP +// NOLINTBEGIN(*-macro-to-enum) + #define MIGRAPHX_PP_PRIMITIVE_CAT(x, y) x##y #define MIGRAPHX_PP_CAT(x, y) MIGRAPHX_PP_PRIMITIVE_CAT(x, y) #define MIGRAPHX_PP_EAT(...) #define MIGRAPHX_PP_EXPAND(...) __VA_ARGS__ +#define MIGRAPHX_PP_COMMA(...) , + +#define MIGRAPHX_PP_IIF(c) MIGRAPHX_PP_PRIMITIVE_CAT(MIGRAPHX_PP_IIF_, c) +#define MIGRAPHX_PP_IIF_0(t, ...) __VA_ARGS__ +#define MIGRAPHX_PP_IIF_1(t, ...) t + +#define MIGRAPHX_PP_COMPL(b) MIGRAPHX_PP_PRIMITIVE_CAT(MIGRAPHX_PP_COMPL_, b) +#define MIGRAPHX_PP_COMPL_0 1 +#define MIGRAPHX_PP_COMPL_1 0 + +#define MIGRAPHX_PP_BITAND(x) MIGRAPHX_PP_PRIMITIVE_CAT(MIGRAPHX_PP_BITAND_, x) +#define MIGRAPHX_PP_BITAND_0(y) 0 +#define MIGRAPHX_PP_BITAND_1(y) y + +#define MIGRAPHX_PP_CHECK(...) MIGRAPHX_PP_CHECK_N(__VA_ARGS__, 0, ) +#define MIGRAPHX_PP_CHECK_N(x, n, ...) n +#define MIGRAPHX_PP_PROBE(x) x, 1, + +#define MIGRAPHX_PP_IS_PAREN(x) MIGRAPHX_PP_CHECK(MIGRAPHX_PP_IS_PAREN_PROBE x) +#define MIGRAPHX_PP_IS_PAREN_PROBE(...) 
MIGRAPHX_PP_PROBE(~) + +#define MIGRAPHX_PP_PRIMITIVE_IS_EMPTY(x) \ + MIGRAPHX_PP_CHECK(MIGRAPHX_PP_PRIMITIVE_IS_EMPTY_PROBE x()) +#define MIGRAPHX_PP_PRIMITIVE_IS_EMPTY_PROBE(...) MIGRAPHX_PP_PROBE(~) + +#define MIGRAPHX_PP_IS_EMPTY_ARG(x) \ + MIGRAPHX_PP_BITAND(MIGRAPHX_PP_COMPL(MIGRAPHX_PP_IS_PAREN(x))) \ + (MIGRAPHX_PP_PRIMITIVE_IS_EMPTY(x)) #define MIGRAPHX_PP_REPEAT0(m, ...) m(0, __VA_ARGS__) #define MIGRAPHX_PP_REPEAT1(m, ...) MIGRAPHX_PP_REPEAT0(m, __VA_ARGS__) m(1, __VA_ARGS__) @@ -45,4 +75,55 @@ #define MIGRAPHX_PP_REPEAT(n, m, ...) \ MIGRAPHX_PP_PRIMITIVE_CAT(MIGRAPHX_PP_REPEAT, n)(m, __VA_ARGS__) +#define MIGRAPHX_PP_RES_ARGS() , , , , , , , , , , , , , , , + +#define MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARGS(...) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARGS_IMPL(__VA_ARGS__) + +#define MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARGS_IMPL( \ + m, delim, x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14, x15, ...) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(m, x0) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(delim, x1) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(m, x1) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(delim, x2) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(m, x2) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(delim, x3) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(m, x3) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(delim, x4) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(m, x4) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(delim, x5) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(m, x5) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(delim, x6) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(m, x6) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(delim, x7) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(m, x7) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(delim, x8) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(m, x8) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(delim, x9) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(m, x9) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(delim, x10) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(m, x10) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(delim, x11) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(m, x11) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(delim, x12) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(m, x12) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(delim, x13) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(m, x13) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(delim, x14) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(m, x14) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(delim, x15) MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(m, x15) + +#define MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(m, x) \ + MIGRAPHX_PP_IIF(MIGRAPHX_PP_IS_EMPTY_ARG(x))(MIGRAPHX_PP_EAT, m)(x) + +#define MIGRAPHX_PP_EACH_ARGS(m, ...) \ + MIGRAPHX_PP_EXPAND(MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARGS( \ + m, MIGRAPHX_PP_EAT, __VA_ARGS__, MIGRAPHX_PP_RES_ARGS())) + +#define MIGRAPHX_PP_TRANSFORM_ARGS(m, ...) 
\
+    MIGRAPHX_PP_EXPAND(MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARGS( \
+        m, MIGRAPHX_PP_COMMA, __VA_ARGS__, MIGRAPHX_PP_RES_ARGS()))
+
+// NOLINTEND(*-macro-to-enum)
+
 #endif // MIGRAPHX_GUARD_KERNELS_PP_HPP
diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
index 0b32227cd0b..76150cbfc54 100644
--- a/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
+++ b/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
@@ -39,7 +39,7 @@ namespace migraphx {
 template <unsigned SubWaveSize, class T, class Op>
 __device__ void dpp_reduce(T& in, Op op)
 {
-    static_assert(SubWaveSize <= __AMDGCN_WAVEFRONT_SIZE, "Too large subwave size");
+    static_assert(SubWaveSize <= MIGRAPHX_WAVEFRONTSIZE, "Too large subwave size");
     static_assert(is_power_of_2(SubWaveSize), "SubWaveSize is not a power of 2");
     if constexpr(SubWaveSize > 1)
     {
@@ -61,7 +61,7 @@ __device__ void dpp_reduce(T& in, Op op)
         auto out = dpp_mov(in);
         in       = op(in, out);
     }
-#if __AMDGCN_WAVEFRONT_SIZE == 32
+#if MIGRAPHX_WAVEFRONTSIZE == 32
     if constexpr(SubWaveSize > 16)
     {
         auto out = dpp_swizzle<0x1e0>(in);
@@ -113,7 +113,7 @@ __device__ void dpp_reduce(T& in, Op op)
         __VA_ARGS__ \
     }
 
-#if __AMDGCN_WAVEFRONT_SIZE == 64
+#if MIGRAPHX_WAVEFRONTSIZE == 64
 #define MIGRAPHX_DPP_REDUCE_SWIZZLE(x, f) (void)f;
 #else
 #define MIGRAPHX_DPP_REDUCE_SWIZZLE(x, f) \
@@ -121,25 +121,25 @@ __device__ void dpp_reduce(T& in, Op op)
     x = f(x, y);
 #endif
 
-#define MIGRAPHX_DPP_REDUCE_ASM_FUN(type, op, ins)                                      \
-    template <unsigned SubWaveSize>                                                     \
-    __device__ inline void dpp_reduce(type& x, op f)                                    \
-    {                                                                                   \
-        if constexpr(SubWaveSize == 2)                                                  \
-            MIGRAPHX_DPP_REDUCE_ASM(0, x, ins, );                                       \
-        if constexpr(SubWaveSize == 4)                                                  \
-            MIGRAPHX_DPP_REDUCE_ASM(1, x, ins, );                                       \
-        if constexpr(SubWaveSize == 8)                                                  \
-            MIGRAPHX_DPP_REDUCE_ASM(2, x, ins, );                                       \
-        if constexpr(SubWaveSize == 16)                                                 \
-            MIGRAPHX_DPP_REDUCE_ASM(3, x, ins, );                                       \
-        if constexpr(SubWaveSize == 32)                                                 \
-            MIGRAPHX_DPP_REDUCE_ASM(MIGRAPHX_DPP_IF_64(__AMDGCN_WAVEFRONT_SIZE)(4, 3),  \
-                                    x,                                                  \
-                                    ins,                                                \
-                                    MIGRAPHX_DPP_REDUCE_SWIZZLE(x, f));                 \
-        MIGRAPHX_DPP_WHEN_64(__AMDGCN_WAVEFRONT_SIZE)                                   \
-        (if constexpr(SubWaveSize == 64) MIGRAPHX_DPP_REDUCE_ASM(5, x, ins, ));         \
+#define MIGRAPHX_DPP_REDUCE_ASM_FUN(type, op, ins)                                 \
+    template <unsigned SubWaveSize>                                                \
+    __device__ inline void dpp_reduce(type& x, op f)                               \
+    {                                                                              \
+        if constexpr(SubWaveSize == 2)                                             \
+            MIGRAPHX_DPP_REDUCE_ASM(0, x, ins, );                                  \
+        if constexpr(SubWaveSize == 4)                                             \
+            MIGRAPHX_DPP_REDUCE_ASM(1, x, ins, );                                  \
+        if constexpr(SubWaveSize == 8)                                             \
+            MIGRAPHX_DPP_REDUCE_ASM(2, x, ins, );                                  \
+        if constexpr(SubWaveSize == 16)                                            \
+            MIGRAPHX_DPP_REDUCE_ASM(3, x, ins, );                                  \
+        if constexpr(SubWaveSize == 32)                                            \
+            MIGRAPHX_DPP_REDUCE_ASM(MIGRAPHX_DPP_IF_64(MIGRAPHX_WAVEFRONTSIZE)(4, 3), \
+                                    x,                                             \
+                                    ins,                                           \
+                                    MIGRAPHX_DPP_REDUCE_SWIZZLE(x, f));            \
+        MIGRAPHX_DPP_WHEN_64(MIGRAPHX_WAVEFRONTSIZE)                               \
+        (if constexpr(SubWaveSize == 64) MIGRAPHX_DPP_REDUCE_ASM(5, x, ins, ));    \
     }
 #endif
 
@@ -170,7 +170,7 @@ MIGRAPHX_DPP_REDUCE(op::min, v_min, _i)
 template <class T, class Op>
 __device__ void dpp_reduce(T& in, Op op)
 {
-    dpp_reduce<__AMDGCN_WAVEFRONT_SIZE>(in, op);
+    dpp_reduce<MIGRAPHX_WAVEFRONTSIZE>(in, op);
 }
 
 template <unsigned SubWaveSize, class Op, class T, class Index, class F>
@@ -188,7 +188,7 @@ __device__ auto subwave_reduce(index idx, Op op, T init, Index n, F f)
 template <class Op, class T, class Index, class F>
 __device__ auto wave_reduce(index idx, Op op, T init, Index n, F f)
 {
-    return subwave_reduce<__AMDGCN_WAVEFRONT_SIZE>(idx, op, init, n, f);
+    return subwave_reduce<MIGRAPHX_WAVEFRONTSIZE>(idx, op, init, n, f);
 }
 
 template <class Op, class T, class Index, class F>
@@ -196,10 +196,10 @@ __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
 {
     MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
 #ifdef MIGRAPHX_HAS_CONST_LOCAL
-    if constexpr(decltype(idx.nlocal()){} == __AMDGCN_WAVEFRONT_SIZE)
+    if constexpr(decltype(idx.nlocal()){} == MIGRAPHX_WAVEFRONTSIZE)
         return wave_reduce(idx, op, init, n, f);
 #endif
-    constexpr index_int lanes_per_thread = __AMDGCN_WAVEFRONT_SIZE;
+    constexpr index_int lanes_per_thread = MIGRAPHX_WAVEFRONTSIZE;
     using type = decltype(index::invoke_loop(f, 0, _c<0>));
     __shared__ type buffer[idx.max_nlocal() / lanes_per_thread];
     auto x = type(init);
@@ -661,7 +661,7 @@ struct subwave
     }
 };
 
-using wave = subwave<__AMDGCN_WAVEFRONT_SIZE>;
+using wave = subwave<MIGRAPHX_WAVEFRONTSIZE>;
 
 struct lane
 {
diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
index 60f68029304..1b0d1343ea2 100644
--- a/src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
+++ b/src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
@@ -29,6 +29,9 @@ namespace migraphx {
 
+template <class...>
+using void_t = void;
+
 template <class T, class U = T&&>
 U private_declval(int);
 
@@ -38,6 +41,19 @@ template <class T>
 T private_declval(long);
 
 template <class T>
 auto declval() noexcept -> decltype(private_declval<T>(0));
 
+template <class V, class F, class... Ts>
+struct is_callable_impl : false_type
+{
+};
+
+template <class F, class... Ts>
+struct is_callable_impl<void_t<decltype(declval<F>()(declval<Ts>()...))>, F, Ts...> : true_type
+{
+};
+
+template <class F, class... Ts>
+using is_callable = is_callable_impl<void, F, Ts...>;
+
 template <class T>
 struct type_identity
 {
diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/types.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
index 27f6303e6de..c88343ce16d 100644
--- a/src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
+++ b/src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
@@ -28,7 +28,7 @@ namespace migraphx {
 
-#if defined(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS) and defined(MIGRAPHX_USE_HIPRTC)
+#if defined(MIGRAPHX_USE_HIPRTC)
 using int8_t  = signed char;
 using uint8_t = unsigned char;
 using int16_t = signed short;
@@ -76,6 +76,7 @@ using vec = T __attribute__((ext_vector_type(N)));
 
 using half  = _Float16;
 using half2 = migraphx::vec<half, 2>;
+using bf16  = __bf16;
 
 } // namespace migraphx
diff --git a/src/targets/gpu/lowering.cpp b/src/targets/gpu/lowering.cpp
index 94df8db1856..adba5466135 100644
--- a/src/targets/gpu/lowering.cpp
+++ b/src/targets/gpu/lowering.cpp
@@ -251,14 +251,6 @@ struct miopen_apply
         apply_map.emplace(name, [=](instruction_ref ins) {
             std::vector<instruction_ref> refs = ins->inputs();
             assert(refs.size() == 2);
-#if MIGRAPHX_USE_HIPBLASLT
-            if(enabled(MIGRAPHX_ENABLE_HIPBLASLT_GEMM{}))
-            {
-                shape workspace_shape{shape::uint8_type, {hipblaslt_workspace_size}};
-                auto workspace = insert_allocation(ins, workspace_shape);
-                refs.push_back(workspace);
-            }
-#endif
             auto output = insert_allocation(ins, ins->get_shape());
             refs.push_back(output);
 #if MIGRAPHX_USE_HIPBLASLT
@@ -269,7 +261,18 @@
                     ins, rocblas_gemm<Op>{Op{}, 1, 0, compute_fp32}, refs);
 #if MIGRAPHX_USE_HIPBLASLT
             }
-            return mod->replace_instruction(ins, hip_gemm<Op>{Op{}, 1, 0}, refs);
+            std::string op_name = "gpu::hip_gemm";
+            if(contains(name, "quant_"))
+            {
+                op_name = "gpu::hip_quant_gemm";
+            }
+            operation gemm_op = make_op(op_name);
+            return mod->replace_instruction(
+                ins,
+                make_op("gpu::hipblaslt_op", {{"op", to_value(gemm_op)}}),
+                ins->inputs().at(0),
+                ins->inputs().at(1),
+                output);
 #endif
         });
     }
@@ -322,7 +325,8 @@ struct miopen_apply
     static bool use_miopen_pooling(instruction_ref ins)
     {
-        if(enabled(MIGRAPHX_DISABLE_MIOPEN_POOLING{}))
+        if(enabled(MIGRAPHX_DISABLE_MIOPEN_POOLING{}) or
+           not contains({shape::float_type, shape::half_type},
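                        // only float and half go through MIOpen pooling;
                        // other data types presumably fall back to the
                        // non-MIOpen GPU pooling kernel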
ins->get_shape().type())) return false; auto&& op = ins->get_operator(); auto op_val = op.to_value(); diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp index e96c6a5d7dd..61e0325ac96 100644 --- a/src/targets/gpu/mlir.cpp +++ b/src/targets/gpu/mlir.cpp @@ -312,8 +312,12 @@ struct mlir_program result = mlirF32TypeGet(ctx.get()); else if(as.type_enum() == shape::half_type) result = mlirF16TypeGet(ctx.get()); + else if(as.type_enum() == shape::bf16_type) + result = mlirBF16TypeGet(ctx.get()); else if(as.type_enum() == shape::fp8e4m3fnuz_type) result = mlirFloat8E4M3FNUZTypeGet(ctx.get()); + else if(as.type_enum() == shape::fp8e5m2fnuz_type) + result = mlirFloat8E5M2FNUZTypeGet(ctx.get()); else if(as.type_enum() == shape::fp8e4m3fn_type) result = mlirFloat8E4M3FNTypeGet(ctx.get()); else if(as.type_enum() == shape::fp8e5m2_type) @@ -322,14 +326,14 @@ struct mlir_program result = mlirF64TypeGet(ctx.get()); else if(as.is_integral()) { - // Note: rocMLIR use signless integer type for tensors types. This - // will translate to signed implementation for current supported - // operations. if(as.is_unsigned()) { - MIGRAPHX_THROW("Unsupported type: " + std::to_string(as.type_enum())); + result = mlirIntegerTypeUnsignedGet(ctx.get(), as.size() * 8); + } + else + { + result = mlirIntegerTypeSignedGet(ctx.get(), as.size() * 8); } - result = mlirIntegerTypeGet(ctx.get(), as.size() * 8); // number of bits } else MIGRAPHX_THROW("Unsupported type: " + std::to_string(as.type_enum())); @@ -442,15 +446,15 @@ struct mlir_program } using attribute_t = std::variant, - MlirType, - MlirAttribute>; + std::uint64_t, + unsigned char, + bool, + double, + std::string, + value, + std::vector, + MlirType, + MlirAttribute>; using named_attribute_t = std::pair; MlirNamedAttribute name_attribute(const named_attribute_t& na) const @@ -718,10 +722,6 @@ struct mlir_program literal r = ins->get_literal(); auto sh = ins->get_shape(); - // mlir works only with signed types. 
change uint4 to (int4 + unsigned-flag) - if(shape::is_unsigned(sh.type()) and ins->outputs()[0]->name() == "unpack_int4") - sh = ins->get_shape().with_type(shape::int8_type); - MlirType shaped_type = make_mlir_shaped(sh); MlirType tensor_type = rocmlirMIXRShapedTypeAsTensor(shaped_type); MlirAttribute mlir_value_attr = @@ -729,13 +729,6 @@ struct mlir_program ops.add_attributes({{"value", mlir_value_attr}}); } - if(ins->name() == "unpack_int4") - { - auto sh = get_shape(ins); - ops.add_attributes( - {{"isUnsigned", shape::is_unsigned(sh.type())}}); // flag for uint4 - } - if(ins->name() == "convolution" or ins->name() == "dot") { pp = @@ -1164,7 +1157,7 @@ mlir_code_object compile_mlir(const context& migraphx_ctx, const std::lock_guard lock(mutex); std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl; } - auto co = mp.compile(solution); + auto co = mp.compile(solution); co.expected_inputs = in_shapes; auto out_shapes = m.get_output_shapes(); @@ -1257,7 +1250,7 @@ void dump_mlir_to_mxr(module m, sizes.insert(sizes.end(), ins->inputs().begin(), ins->inputs().end()); } auto name = compute_dump_name(m, ".mxr"); - auto f = location / name; + auto f = location / name; std::cout << "Dumping MXR file to: " << f << std::endl; save(program{std::move(m)}, f.string()); } diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index 3437651f056..ad98fb680fe 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -35,7 +35,7 @@ #include #include #include -#include +#include #include #include #include @@ -56,6 +56,7 @@ #include #include #include +#include #include #include #include @@ -76,11 +77,11 @@ inline namespace MIGRAPHX_INLINE_NS { namespace gpu { MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS) -MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_SPLIT_REDUCE) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC) #ifndef _WIN32 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK) #endif +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_HIPBLASLT_GEMM) std::vector target::get_passes(migraphx::context& gctx, const compile_options& options) const { @@ -90,6 +91,7 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti std::set unsupported_types(shape::types().begin(), shape::types().end()); unsupported_types.erase(shape::type_t::float_type); unsupported_types.erase(shape::type_t::fp8e4m3fnuz_type); + unsupported_types.erase(shape::type_t::fp8e5m2fnuz_type); unsupported_types.erase(shape::type_t::fp8e4m3fn_type); unsupported_types.erase(shape::type_t::fp8e5m2_type); unsupported_types.erase(shape::type_t::half_type); @@ -98,12 +100,13 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti unsupported_types.erase(shape::type_t::uint8_type); unsupported_types.erase(shape::type_t::int32_type); unsupported_types.erase(shape::type_t::tuple_type); + unsupported_types.erase(shape::type_t::bf16_type); // whiltelist supported Ops for the FP8 types // different between fp8e4m3fnuz and OCP types because rocBLAS only has // support for fp8e4m3fnuz std::set unsupported_fp8e4m3fnuz_ops = {}; - if(not gpu::rocblas_fp8_available()) + if(not enabled(MIGRAPHX_ENABLE_HIPBLASLT_GEMM{}) and not gpu::rocblas_fp8_available()) { unsupported_fp8e4m3fnuz_ops.insert("dot"); unsupported_fp8e4m3fnuz_ops.insert("quant_dot"); @@ -128,10 +131,21 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti unsupported_fp8e4m3fnuz_ops.insert("argmax"); unsupported_fp8e4m3fnuz_ops.insert("argmin"); + std::set unsupported_fp8e5m2fnuz_ops = 
unsupported_fp8e4m3fnuz_ops; + // disable gemm for fp8e5m2fnuz if rocBLAS is being used + if(not enabled(MIGRAPHX_ENABLE_HIPBLASLT_GEMM{})) + { + unsupported_fp8e5m2fnuz_ops.insert("dot"); + unsupported_fp8e5m2fnuz_ops.insert("quant_dot"); + } + std::set unsupported_fp8ocp_ops = {}; - // TODO update with hipBLASLt support - unsupported_fp8ocp_ops.insert("dot"); - unsupported_fp8ocp_ops.insert("quant_dot"); + // TODO: remove this when the flag is removed + if(not enabled(MIGRAPHX_ENABLE_HIPBLASLT_GEMM{})) + { + unsupported_fp8ocp_ops.insert("dot"); + unsupported_fp8ocp_ops.insert("quant_dot"); + } #if MIGRAPHX_USE_MIOPEN // MIOpen doesn't have support for fp8 pooling yet. unsupported_fp8ocp_ops.insert("pooling"); @@ -140,6 +154,8 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti { unsupported_fp8ocp_ops.insert("convolution"); unsupported_fp8ocp_ops.insert("quant_convolution"); + unsupported_fp8ocp_ops.insert("dot"); + unsupported_fp8ocp_ops.insert("quant_dot"); } // add all device kernels unsupported_fp8ocp_ops.insert("logsoftmax"); @@ -182,11 +198,12 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti dead_code_elimination{}, rewrite_gelu{options.fast_math}, optimize_module{}, - enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), layout_nhwc{}), + layout_convolution{.channels_last = enabled(MIGRAPHX_ENABLE_NHWC{})}, dead_code_elimination{}, prefuse_ops{}, dead_code_elimination{}, eliminate_data_type{{migraphx::shape::fp8e4m3fnuz_type}, shape::float_type, unsupported_fp8e4m3fnuz_ops}, + eliminate_data_type{{migraphx::shape::fp8e5m2fnuz_type}, shape::float_type, unsupported_fp8e5m2fnuz_ops}, eliminate_data_type{{migraphx::shape::fp8e4m3fn_type, migraphx::shape::fp8e5m2_type}, shape::float_type, unsupported_fp8ocp_ops}, dead_code_elimination{}, rewrite_reduce{}, @@ -194,7 +211,6 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti dead_code_elimination{}, optimize_module{}, fuse_pointwise_reduce{}, - enable_pass(enabled(MIGRAPHX_ENABLE_SPLIT_REDUCE{}), split_reduce{}), dead_code_elimination{}, #ifndef _WIN32 enable_pass(enabled(MIGRAPHX_ENABLE_CK{}), fuse_ck{}), @@ -215,9 +231,12 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti compile_miopen{&gctx}, dead_code_elimination{}, #endif - dead_code_elimination{}, fuse_ops{&ctx, options.fast_math}, dead_code_elimination{}, +#if MIGRAPHX_USE_HIPBLASLT + compile_hipblaslt{&gctx}, + dead_code_elimination{}, +#endif replace_allocate{gpu_allocation_model{}, options.offload_copy}, dead_code_elimination{}, adjust_allocation{gpu_allocation_model{}}, diff --git a/src/quantize_fp16.cpp b/src/truncate_float.cpp similarity index 90% rename from src/quantize_fp16.cpp rename to src/truncate_float.cpp index 2e7e9f00a9e..15f807684d3 100644 --- a/src/quantize_fp16.cpp +++ b/src/truncate_float.cpp @@ -23,7 +23,7 @@ */ #include #include -#include +#include #include #include #include @@ -35,7 +35,8 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -static void quantize_module(module& m, const std::vector& ins_names) +static void +quantize_module(module& m, const std::vector& ins_names, shape::type_t float_type) { for(auto ins : iterator_for(m)) { @@ -52,14 +53,14 @@ static void quantize_module(module& m, const std::vector& ins_names auto mod_inputs = ins->module_inputs(); auto s = ins->get_shape(); - // Convert each of the inputs that are floating point to fp16 + // Convert each of the inputs that are floating point to float type auto inputs = 
ins->inputs(); std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) { auto input_type = input->get_shape().type(); if(input_type != shape::float_type and input_type != shape::double_type) return input; return m.insert_instruction( - ins, make_op("convert", {{"target_type", shape::half_type}}), input); + ins, make_op("convert", {{"target_type", float_type}}), input); }); // Insert quantized ins @@ -71,13 +72,13 @@ static void quantize_module(module& m, const std::vector& ins_names auto outputs = ins->outputs(); std::transform( outputs.begin(), outputs.end(), outputs.begin(), [&](const auto gte_ins) { - auto gte_ins_half = + auto gte_ins_float_type = m.insert_instruction(ins, gte_ins->get_operator(), converted_ins); // Convert back to output type after quantizing auto gte_converted = m.insert_instruction( ins, make_op("convert", {{"target_type", gte_ins->get_shape().type()}}), - gte_ins_half); + gte_ins_float_type); // Replace output instruction return m.replace_instruction(gte_ins, gte_converted); }); @@ -96,7 +97,7 @@ static void quantize_module(module& m, const std::vector& ins_names } } -void quantize_fp16_pass::apply(module& m) const { quantize_module(m, ins_names); } +void truncate_float_pass::apply(module& m) const { quantize_module(m, ins_names, float_type); } } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/test/autocast_fp8.cpp b/test/autocast_fp8.cpp index 0ee20723b3f..c9a46ecd39c 100644 --- a/test/autocast_fp8.cpp +++ b/test/autocast_fp8.cpp @@ -63,6 +63,7 @@ void autocast_fp8_1() EXPECT(m1 == m2); } TEST_CASE_REGISTER(autocast_fp8_1); +TEST_CASE_REGISTER(autocast_fp8_1); TEST_CASE_REGISTER(autocast_fp8_1); TEST_CASE_REGISTER(autocast_fp8_1); @@ -91,6 +92,7 @@ void autocast_fp8_2() EXPECT(m1 == m2); } TEST_CASE_REGISTER(autocast_fp8_2); +TEST_CASE_REGISTER(autocast_fp8_2); TEST_CASE_REGISTER(autocast_fp8_2); TEST_CASE_REGISTER(autocast_fp8_2); @@ -127,6 +129,7 @@ void autocast_fp8_3() EXPECT(m1 == m2); } TEST_CASE_REGISTER(autocast_fp8_3); +TEST_CASE_REGISTER(autocast_fp8_3); TEST_CASE_REGISTER(autocast_fp8_3); TEST_CASE_REGISTER(autocast_fp8_3); @@ -166,6 +169,7 @@ void autocast_fp8_4() EXPECT(m1 == m2); } TEST_CASE_REGISTER(autocast_fp8_4); +TEST_CASE_REGISTER(autocast_fp8_4); TEST_CASE_REGISTER(autocast_fp8_4); TEST_CASE_REGISTER(autocast_fp8_4); diff --git a/test/base64_test.cpp b/test/base64_test.cpp new file mode 100644 index 00000000000..65b8c3a8800 --- /dev/null +++ b/test/base64_test.cpp @@ -0,0 +1,140 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include +#include "test.hpp" + +TEST_CASE(base64_encoding) +{ + EXPECT(migraphx::base64_encode("abc") == "YWJj"); + + EXPECT(migraphx::base64_encode("abcd") == "YWJjZA=="); + + EXPECT(migraphx::base64_encode("convolution") == "Y29udm9sdXRpb24="); + + EXPECT(migraphx::base64_encode("https://www.amd.com/en/products/software/rocm.html") == + "aHR0cHM6Ly93d3cuYW1kLmNvbS9lbi9wcm9kdWN0cy9zb2Z0d2FyZS9yb2NtLmh0bWw="); + + EXPECT(migraphx::base64_encode("{1, 3, 7, 9}") == "ezEsIDMsIDcsIDl9"); +} + +TEST_CASE(base64_RFC_test_vectors) +{ + EXPECT(migraphx::base64_encode("") == ""); + + EXPECT(migraphx::base64_encode("f") == "Zg=="); + + EXPECT(migraphx::base64_encode("fo") == "Zm8="); + + EXPECT(migraphx::base64_encode("foo") == "Zm9v"); + + EXPECT(migraphx::base64_encode("foob") == "Zm9vYg=="); + + EXPECT(migraphx::base64_encode("fooba") == "Zm9vYmE="); + + EXPECT(migraphx::base64_encode("foobar") == "Zm9vYmFy"); +} + +// Following tests altered from +// https://github.com/tobiaslocker/base64/blob/master/test/base64_tests.cpp +TEST_CASE(base64_encodes_three_bytes_zeros) +{ + std::array const input{0x00, 0x00, 0x00}; + std::string expected{"AAAA"}; + std::string actual{migraphx::base64_encode({input.begin(), input.end()})}; + EXPECT(expected == actual); +} + +TEST_CASE(base64_encodes_three_bytes_random) +{ + std::array const input{0xFE, 0xE9, 0x72}; + std::string const expected{"/uly"}; + std::string const actual{migraphx::base64_encode({input.begin(), input.end()})}; + EXPECT(expected == actual); +} + +TEST_CASE(base64_encodes_two_bytes) +{ + std::array const input{0x00, 0x00}; + std::string expected{"AAA="}; + std::string actual{migraphx::base64_encode({input.begin(), input.end()})}; + EXPECT(expected == actual); +} + +TEST_CASE(base64_encodes_one_byte) +{ + std::array const input{0x00}; + std::string expected{"AA=="}; + std::string actual{migraphx::base64_encode({input.begin(), input.end()})}; + EXPECT(expected == actual); +} + +TEST_CASE(base64_encodes_four_bytes) +{ + std::array const input{0x74, 0x68, 0x65, 0x20}; + std::string expected{"dGhlIA=="}; + std::string actual{migraphx::base64_encode({input.begin(), input.end()})}; + EXPECT(expected == actual); +} + +TEST_CASE(base64_encodes_five_bytes) +{ + std::array const input{0x20, 0x62, 0x72, 0x6f, 0x77}; + std::string expected{"IGJyb3c="}; + std::string actual{migraphx::base64_encode({input.begin(), input.end()})}; + EXPECT(expected == actual); +} + +TEST_CASE(base64_encodes_six_bytes) +{ + std::array const input{0x20, 0x6a, 0x75, 0x6d, 0x70, 0x73}; + std::string expected{"IGp1bXBz"}; + std::string actual{migraphx::base64_encode({input.begin(), input.end()})}; + EXPECT(expected == actual); +} + +TEST_CASE(base64_encodes_BrownFox) +{ + std::array const input{ + 0x74, 0x68, 0x65, 0x20, 0x71, 0x75, 0x69, 0x63, 0x6b, 0x20, 0x62, 0x72, 0x6f, 0x77, 0x6e, + 0x20, 0x66, 0x6f, 0x78, 0x20, 0x6a, 0x75, 0x6d, 0x70, 0x73, 0x20, 0x6f, 0x76, 0x65, 0x72, + 0x20, 0x74, 0x68, 0x65, 0x20, 0x6c, 0x61, 0x7a, 0x79, 0x20, 0x64, 0x6f, 0x67}; + + std::string expected{"dGhlIHF1aWNrIGJyb3duIGZveCBqdW1wcyBvdmVyIHRoZSBsYXp5IGRvZw=="}; + std::string actual{migraphx::base64_encode({input.begin(), input.end()})}; + EXPECT(expected == actual); +} + 
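
// Worked example of the 3-byte to 4-character mapping exercised above:
// "abc" = 0x61 0x62 0x63 = 01100001 01100010 01100011; regrouped into four
// 6-bit indices, 011000 010110 001001 100011 = 24, 22, 9, 35, which select
// 'Y', 'W', 'J', 'j' from the base64 alphabet, giving "YWJj". When the input
// is not a multiple of 3 bytes, the tail is zero-padded and '=' appended:
// "f" = 01100110 -> 011001 10(0000) = 25, 32 -> "Zg" plus "==", i.e. "Zg==".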
+TEST_CASE(base64_encodes_EncodesBrownFastFoxNullInMiddle) +{ + std::array const input{ + 0x74, 0x68, 0x65, 0x20, 0x71, 0x75, 0x69, 0x63, 0x6b, 0x21, 0x20, 0x62, 0x72, 0x6f, 0x77, + 0x6e, 0x20, 0x66, 0x6f, 0x78, 0x20, 0x6a, 0x75, 0x6d, 0x70, 0x73, 0x20, 0x6f, 0x76, 0x65, + 0x72, 0x20, 0x74, 0x68, 0x65, 0x00, 0x20, 0x6c, 0x61, 0x7a, 0x79, 0x20, 0x64, 0x6f, 0x67}; + + std::string expected{"dGhlIHF1aWNrISBicm93biBmb3gganVtcHMgb3ZlciB0aGUAIGxhenkgZG9n"}; + std::string actual{migraphx::base64_encode({input.begin(), input.end()})}; + EXPECT(expected == actual); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/bf16.cpp b/test/bf16.cpp new file mode 100644 index 00000000000..9ba04d083a2 --- /dev/null +++ b/test/bf16.cpp @@ -0,0 +1,1245 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ +#include <migraphx/bf16.hpp> +#include <migraphx/bit_cast.hpp> +#include <migraphx/float_equal.hpp> +#include <migraphx/ranges.hpp> +#include "test.hpp" + +#include <array> +#include <cmath> +#include <cstdint> +#include <limits> +#include <map> +#include <sstream> + +template <class T, class U> +bool bit_equal(const T& x, const U& y) +{ + static_assert(sizeof(T) == sizeof(U)); + using type = std::array<char, sizeof(T)>; + return migraphx::bit_cast<type>(x) == migraphx::bit_cast<type>(y); +} + +TEST_CASE(check_numeric_limits) +{ + CHECK(bit_equal(std::numeric_limits<migraphx::bf16>::min(), uint16_t{0x0080})); + CHECK(bit_equal(std::numeric_limits<migraphx::bf16>::lowest(), uint16_t{0xff7f})); + CHECK(bit_equal(std::numeric_limits<migraphx::bf16>::max(), uint16_t{0x7f7f})); + CHECK(bit_equal(std::numeric_limits<migraphx::bf16>::epsilon(), uint16_t{0x3c00})); + CHECK(bit_equal(std::numeric_limits<migraphx::bf16>::denorm_min(), uint16_t{0x0001})); + CHECK(bit_equal(std::numeric_limits<migraphx::bf16>::infinity(), uint16_t{0x7f80})); + CHECK(bit_equal(std::numeric_limits<migraphx::bf16>::quiet_NaN(), uint16_t{0x7fc0})); + CHECK(bit_equal(std::numeric_limits<migraphx::bf16>::signaling_NaN(), uint16_t{0x7fa0})); +} + +const std::map<uint16_t, float>& bf16_lut() // NOLINT(readability-function-size) +{ + static const std::map<uint16_t, float> result = { + {0x0000, 0.0}, + {0x002d, 0.00000000000000000000000000000000000000413259732711}, + {0x004e, 0.00000000000000000000000000000000000000716316870032}, + {0x005b, 0.00000000000000000000000000000000000000835703015038}, + {0x00cd, 0.00000000000000000000000000000000000001882627671239}, + {0x00ce, 0.00000000000000000000000000000000000001891811220855}, + {0x0170, 0.00000000000000000000000000000000000004408103815584}, + {0x01be, 0.00000000000000000000000000000000000006979497708007}, + {0x01fe, 0.00000000000000000000000000000000000009330486409652}, + {0x0211, 0.00000000000000000000000000000000000010652917554327}, + {0x028f, 0.00000000000000000000000000000000000021011961520948}, + {0x02bc, 0.00000000000000000000000000000000000027624117244324}, + {0x02e8, 0.00000000000000000000000000000000000034089336173846}, + {0x039c, 0.00000000000000000000000000000000000091688559364138}, + {0x03a9, 0.00000000000000000000000000000000000099329272644483}, + {0x03cf, 0.00000000000000000000000000000000000121663665310107}, + {0x03da, 0.00000000000000000000000000000000000128128884239629}, + {0x03f3, 0.00000000000000000000000000000000000142822563624908}, + {0x0427, 0.00000000000000000000000000000000000196307556587322}, + {0x044c, 0.00000000000000000000000000000000000239800847567747}, + {0x0485, 0.00000000000000000000000000000000000312681497318728}, + {0x0498, 0.00000000000000000000000000000000000357350282649975}, + {0x04da, 0.00000000000000000000000000000000000512515536958517}, + {0x051b, 0.00000000000000000000000000000000000728806497509818}, + {0x0533, 0.00000000000000000000000000000000000841653955188758}, + {0x0536, 0.00000000000000000000000000000000000855759887398625}, + {0x0577, 0.00000000000000000000000000000000001161388418612420}, + {0x0587, 0.00000000000000000000000000000000001269533898888071}, + {0x065a, 0.00000000000000000000000000000000004100124295668139}, + {0x0735, 0.00000000000000000000000000000000013616926559925378}, + {0x075e, 0.00000000000000000000000000000000016701423736483061}, + {0x0765, 0.00000000000000000000000000000000017228045205651446}, + {0x076b, 0.00000000000000000000000000000000017679435036367204}, + {0x07ba, 0.00000000000000000000000000000000027986169504377021}, + {0x07d1, 0.00000000000000000000000000000000031446824873197835}, + {0x07db, 0.00000000000000000000000000000000032951457642250363}, + {0x07eb, 0.00000000000000000000000000000000035358870072734408}, + {0x080d, 0.00000000000000000000000000000000042430644087281290}, + {0x0841, 0.00000000000000000000000000000000058078824885427581}, +
{0x08ff, 0.00000000000000000000000000000000153472542443357857}, + {0x09e1, 0.00000000000000000000000000000000541667796858910084}, + {0x0a0b, 0.00000000000000000000000000000000669260655674564459}, + {0x0a0f, 0.00000000000000000000000000000000688519955118436817}, + {0x0a12, 0.00000000000000000000000000000000702964429701341086}, + {0x0a1d, 0.00000000000000000000000000000000755927503171990072}, + {0x0aa2, 0.00000000000000000000000000000001560003254953661041}, + {0x0aaf, 0.00000000000000000000000000000001685188701338831371}, + {0x0ab4, 0.00000000000000000000000000000001733336949948512268}, + {0x0ade, 0.00000000000000000000000000000002137782238269831797}, + {0x0b6e, 0.00000000000000000000000000000004583713267641621330}, + {0x0bb6, 0.00000000000000000000000000000007010384997569538505}, + {0x0c8c, 0.00000000000000000000000000000021570415377137041554}, + {0x0cf7, 0.00000000000000000000000000000038056375701091780456}, + {0x0d06, 0.00000000000000000000000000000041291938007662336690}, + {0x0d20, 0.00000000000000000000000000000049303806576313237838}, + {0x0d5e, 0.00000000000000000000000000000068409031624634617501}, + {0x0d60, 0.00000000000000000000000000000069025329206838532974}, + {0x0d6a, 0.00000000000000000000000000000072106817117858110338}, + {0x0d76, 0.00000000000000000000000000000075804602611081603176}, + {0x0d7e, 0.00000000000000000000000000000078269792939897265068}, + {0x0dcc, 0.00000000000000000000000000000125724706769598756487}, + {0x0dfc, 0.00000000000000000000000000000155306990715386699190}, + {0x0e09, 0.00000000000000000000000000000168865537523872839596}, + {0x0e3c, 0.00000000000000000000000000000231727890908672217840}, + {0x0e69, 0.00000000000000000000000000000287194673307024610408}, + {0x0f42, 0.00000000000000000000000000000956493847580476814062}, + {0x0f8f, 0.00000000000000000000000000001410088868082558602173}, + {0x0fa2, 0.00000000000000000000000000001597443333072548905959}, + {0x1007, 0.00000000000000000000000000002662405555120914843265}, + {0x1015, 0.00000000000000000000000000002938506871948268975159}, + {0x1030, 0.00000000000000000000000000003470987982972451943812}, + {0x1042, 0.00000000000000000000000000003825975390321907256247}, + {0x1056, 0.00000000000000000000000000004220405842932413158953}, + {0x1166, 0.00000000000000000000000000018143800820083271524470}, + {0x1182, 0.00000000000000000000000000020510383535746306940705}, + {0x128a, 0.00000000000000000000000000087090243936399703317455}, + {0x1295, 0.00000000000000000000000000094032219902344607205078}, + {0x129e, 0.00000000000000000000000000099712018419935892204042}, + {0x12b0, 0.00000000000000000000000000111071615455118462201971}, + {0x12bf, 0.00000000000000000000000000120537946317770603866912}, + {0x1315, 0.00000000000000000000000000188064439804689214410156}, + {0x1343, 0.00000000000000000000000000246124602428955683288459}, + {0x13e2, 0.00000000000000000000000000570504206655835737673762}, + {0x13ee, 0.00000000000000000000000000600796465416322591001572}, + {0x1445, 0.00000000000000000000000000994595829302651684263107}, + {0x1469, 0.00000000000000000000000001176349381865572804229970}, + {0x14bd, 0.00000000000000000000000001908412301910671759652054}, + {0x1541, 0.00000000000000000000000003897603960515975128178268}, + {0x1546, 0.00000000000000000000000003998578156384264639270970}, + {0x1569, 0.00000000000000000000000004705397527462291216919879}, + {0x16a1, 0.00000000000000000000000026010952855671378057479844}, + {0x16b6, 0.00000000000000000000000029403685836845905630194606}, + {0x16f4, 
0.00000000000000000000000039420326066980225130590570}, + {0x1703, 0.00000000000000000000000042328382907986963050060367}, + {0x1720, 0.00000000000000000000000051698788284564229679463043}, + {0x178e, 0.00000000000000000000000091765349205101507681046902}, + {0x17f6, 0.00000000000000000000000158973773975035006264348858}, + {0x18e8, 0.00000000000000000000000599705944100945064281771302}, + {0x18f1, 0.00000000000000000000000622970398828998967637529671}, + {0x1927, 0.00000000000000000000000863369764352222635647032822}, + {0x1a39, 0.00000000000000000000003825710333057752996280265201}, + {0x1a60, 0.00000000000000000000004632211430296954979279888676}, + {0x1a69, 0.00000000000000000000004818327068121386206125955631}, + {0x1b1f, 0.00000000000000000000013152171739593140030455398204}, + {0x1b43, 0.00000000000000000000016130021944784039659992469495}, + {0x1ba0, 0.00000000000000000000026469779601696885595885078146}, + {0x1bc0, 0.00000000000000000000031763735522036262715062093775}, + {0x1bea, 0.00000000000000000000038712052667481695183981926789}, + {0x1c25, 0.00000000000000000000054593920428499826541512973677}, + {0x1c32, 0.00000000000000000000058895259613775570450844298875}, + {0x1cbe, 0.00000000000000000000125731453108060206580454121195}, + {0x1ccd, 0.00000000000000000000135657620458696538678911025499}, + {0x1cd6, 0.00000000000000000000141613320869078337937985168082}, + {0x1d23, 0.00000000000000000000215728703753829617606463386892}, + {0x1d38, 0.00000000000000000000243521972335611347482142718945}, + {0x1dce, 0.00000000000000000000545277459794955843275232609813}, + {0x1e7c, 0.00000000000000000001334076891925523034032607938570}, + {0x1e9e, 0.00000000000000000001672890070827243169659936938842}, + {0x1eb1, 0.00000000000000000001874060395800139500188663532754}, + {0x1f33, 0.00000000000000000003790472438962994017330743190541}, + {0x1f56, 0.00000000000000000004531626267810506814015525378636}, + {0x1fa8, 0.00000000000000000007115076756936122848173909005709}, + {0x1fed, 0.00000000000000000010037340424963459017959621633054}, + {0x2001, 0.00000000000000000010926725019580474373981360258767}, + {0x213f, 0.00000000000000000064713317170228545904819839051925}, + {0x2154, 0.00000000000000000071828393927164668752993748057634}, + {0x216b, 0.00000000000000000079621097041904231872422315063886}, + {0x219f, 0.00000000000000000107742590890747003129490622086450}, + {0x2225, 0.00000000000000000223616698075135289514037140179425}, + {0x2229, 0.00000000000000000229037708937562811684074404183775}, + {0x2274, 0.00000000000000000330681662608078852372273104265332}, + {0x227b, 0.00000000000000000340168431617327016169838316272944}, + {0x2294, 0.00000000000000000401154803819636640582757536321878}, + {0x22b7, 0.00000000000000000496022493912118278558409656397998}, + {0x2379, 0.00000000000000001349831704744453020339278737083077}, + {0x2389, 0.00000000000000001485356976305141074590210337191820}, + {0x2464, 0.00000000000000004943961906533900219073984771966934}, + {0x2475, 0.00000000000000005312590645178971726636518724262714}, + {0x2491, 0.00000000000000006288372600415925717243226245045662}, + {0x24f6, 0.00000000000000010668549377257363630633335560560226}, + {0x2567, 0.00000000000000020036056147532121940457727760076523}, + {0x256a, 0.00000000000000020296264668928643004619516432285309}, + {0x257f, 0.00000000000000022117724318704290453752037137746811}, + {0x25f9, 0.00000000000000043194614551822496650856919586658478}, + {0x263b, 0.00000000000000064878658001532585331005975604057312}, + {0x270c, 0.00000000000000194289029309402394574135541915893555}, + 
{0x272c, 0.00000000000000238697950294408656191080808639526367}, + {0x2745, 0.00000000000000273392419813944798079319298267364502}, + {0x2791, 0.00000000000000402455846426619245903566479682922363}, + {0x2792, 0.00000000000000405231403988182137254625558853149414}, + {0x27a0, 0.00000000000000444089209850062616169452667236328125}, + {0x27a3, 0.00000000000000452415882534751290222629904747009277}, + {0x27d4, 0.00000000000000588418203051332966424524784088134766}, + {0x27dd, 0.00000000000000613398221105398988584056496620178223}, + {0x2821, 0.00000000000000893729534823251015041023492813110352}, + {0x28d9, 0.00000000000002409183963436589692719280719757080078}, + {0x2981, 0.00000000000005728750807065807748585939407348632812}, + {0x29ca, 0.00000000000008970602038971264846622943878173828125}, + {0x2a5b, 0.00000000000019451107391432742588222026824951171875}, + {0x2aa8, 0.00000000000029842794901924207806587219238281250}, + {0x2ac4, 0.000000000000348165940522449091076850891113281250}, + {0x2ae7, 0.00000000000041033842990145785734057426452636718750}, + {0x2af1, 0.00000000000042810199829546036198735237121582031250}, + {0x2afe, 0.0000000000004511946372076636180281639099121093750}, + {0x2b24, 0.00000000000058264504332328215241432189941406250}, + {0x2b4c, 0.00000000000072475359047530218958854675292968750}, + {0x2b85, 0.000000000000945021838560933247208595275878906250}, + {0x2bed, 0.000000000001683986283751437440514564514160156250}, + {0x2c18, 0.00000000000216004991671070456504821777343750}, + {0x2cdf, 0.0000000000063380412029800936579704284667968750}, + {0x2d6b, 0.000000000013358203432289883494377136230468750}, + {0x2d96, 0.0000000000170530256582424044609069824218750}, + {0x2da2, 0.0000000000184172677109017968177795410156250}, + {0x2db8, 0.00000000002091837814077734947204589843750}, + {0x2de1, 0.00000000002557953848736360669136047363281250}, + {0x2e90, 0.00000000006548361852765083312988281250}, + {0x2ea3, 0.000000000074123818194493651390075683593750}, + {0x2ef0, 0.00000000010913936421275138854980468750}, + {0x2f09, 0.00000000012460077414289116859436035156250}, + {0x2f6b, 0.00000000021373125491663813591003417968750}, + {0x303b, 0.000000000680302036926150321960449218750}, + {0x308f, 0.00000000104046193882822990417480468750}, + {0x309c, 0.000000001135049387812614440917968750}, + {0x30a9, 0.00000000122963683679699897766113281250}, + {0x312a, 0.000000002473825588822364807128906250}, + {0x313d, 0.0000000027503119781613349914550781250}, + {0x3159, 0.0000000031577656045556068420410156250}, + {0x31c6, 0.00000000576255843043327331542968750}, + {0x3212, 0.0000000084983184933662414550781250}, + {0x3245, 0.00000001146690919995307922363281250}, + {0x329b, 0.0000000180443748831748962402343750}, + {0x32ba, 0.000000021653249859809875488281250}, + {0x32cc, 0.00000002374872565269470214843750}, + {0x3332, 0.00000004144385457038879394531250}, + {0x33c4, 0.000000091269612312316894531250}, + {0x3424, 0.00000015273690223693847656250}, + {0x3589, 0.0000010207295417785644531250}, + {0x3594, 0.00000110268592834472656250}, + {0x368b, 0.00000414252281188964843750}, + {0x36a0, 0.000004768371582031250}, + {0x36e9, 0.00000694394111633300781250}, + {0x36ed, 0.00000706315040588378906250}, + {0x3750, 0.000012397766113281250}, + {0x375f, 0.0000132918357849121093750}, + {0x37ce, 0.00002455711364746093750}, + {0x37d2, 0.00002503395080566406250}, + {0x37f0, 0.00002861022949218750}, + {0x380e, 0.0000338554382324218750}, + {0x3826, 0.0000395774841308593750}, + {0x387f, 0.00006079673767089843750}, + {0x38e1, 0.0001072883605957031250}, + {0x38e9, 
0.0001111030578613281250}, + {0x3964, 0.0002174377441406250}, + {0x3994, 0.000282287597656250}, + {0x3a26, 0.000633239746093750}, + {0x3a2c, 0.00065612792968750}, + {0x3a6a, 0.000892639160156250}, + {0x3a85, 0.001014709472656250}, + {0x3ab9, 0.001411437988281250}, + {0x3aba, 0.00141906738281250}, + {0x3af7, 0.001884460449218750}, + {0x3b03, 0.00199890136718750}, + {0x3bb3, 0.0054626464843750}, + {0x3bbf, 0.0058288574218750}, + {0x3be8, 0.0070800781250}, + {0x3c06, 0.00817871093750}, + {0x3c29, 0.010314941406250}, + {0x3c3f, 0.011657714843750}, + {0x3c73, 0.014831542968750}, + {0x3ce9, 0.02844238281250}, + {0x3cfa, 0.0305175781250}, + {0x3cfb, 0.03063964843750}, + {0x3d0f, 0.0349121093750}, + {0x3d2a, 0.041503906250}, + {0x3d43, 0.0476074218750}, + {0x3dd0, 0.10156250}, + {0x3dd9, 0.105957031250}, + {0x3de9, 0.113769531250}, + {0x3df9, 0.121582031250}, + {0x3e1e, 0.1542968750}, + {0x3e77, 0.24121093750}, + {0x3e95, 0.2910156250}, + {0x3f38, 0.718750}, + {0x3fb3, 1.39843750}, + {0x3fc5, 1.53906250}, + {0x3fd3, 1.64843750}, + {0x3fd7, 1.67968750}, + {0x400b, 2.1718750}, + {0x40bf, 5.968750}, + {0x40c7, 6.218750}, + {0x4123, 10.18750}, + {0x412b, 10.68750}, + {0x41bf, 23.8750}, + {0x41ca, 25.250}, + {0x421b, 38.750}, + {0x4226, 41.50}, + {0x42a7, 83.50}, + {0x42b7, 91.50}, + {0x4311, 145.0}, + {0x431f, 159.0}, + {0x4334, 180.0}, + {0x434f, 207.0}, + {0x43b1, 354.0}, + {0x43e5, 458.0}, + {0x4476, 984.0}, + {0x4496, 1200.0}, + {0x44a4, 1312.0}, + {0x458b, 4448.0}, + {0x45a9, 5408.0}, + {0x45df, 7136.0}, + {0x45f6, 7872.0}, + {0x45fa, 8000.0}, + {0x4602, 8320.0}, + {0x4640, 12288.0}, + {0x4648, 12800.0}, + {0x46a7, 21376.0}, + {0x46b1, 22656.0}, + {0x4742, 49664.0}, + {0x4744, 50176.0}, + {0x475e, 56832.0}, + {0x477a, 64000.0}, + {0x4837, 187392.0}, + {0x488a, 282624.0}, + {0x488f, 292864.0}, + {0x48ea, 479232.0}, + {0x495c, 901120.0}, + {0x49aa, 1392640.0}, + {0x49b9, 1515520.0}, + {0x4a1e, 2588672.0}, + {0x4a2b, 2801664.0}, + {0x4a4c, 3342336.0}, + {0x4ab6, 5963776.0}, + {0x4b34, 11796480.0}, + {0x4b73, 15925248.0}, + {0x4b7b, 16449536.0}, + {0x4bcd, 26869760.0}, + {0x4bd0, 27262976.0}, + {0x4c07, 35389440.0}, + {0x4c17, 39583744.0}, + {0x4c53, 55312384.0}, + {0x4cad, 90701824.0}, + {0x4d1c, 163577856.0}, + {0x4dc0, 402653184.0}, + {0x4dde, 465567744.0}, + {0x4eef, 2004877312.0}, + {0x4efc, 2113929216.0}, + {0x4f12, 2449473536.0}, + {0x4f2f, 2936012800.0}, + {0x4f92, 4898947072.0}, + {0x4fad, 5804916736.0}, + {0x4fdc, 7381975040.0}, + {0x4feb, 7885291520.0}, + {0x5076, 16508780544.0}, + {0x5083, 17582522368.0}, + {0x5215, 159987531776.0}, + {0x52a9, 362924736512.0}, + {0x5394, 1271310319616.0}, + {0x53a0, 1374389534720.0}, + {0x53b7, 1571958030336.0}, + {0x540e, 2439541424128.0}, + {0x542f, 3006477107200.0}, + {0x5465, 3934190043136.0}, + {0x5529, 11613591568384.0}, + {0x554c, 14018773254144.0}, + {0x5596, 20615843020800.0}, + {0x55ae, 23914377904128.0}, + {0x55be, 26113401159680.0}, + {0x55da, 29961691856896.0}, + {0x568d, 77515569758208.0}, + {0x5690, 79164837199872.0}, + {0x56ad, 95107755802624.0}, + {0x5718, 167125767421952.0}, + {0x571c, 171523813933056.0}, + {0x571d, 172623325560832.0}, + {0x5826, 730075720843264.0}, + {0x587c, 1108307720798208.0}, + {0x5890, 1266637395197952.0}, + {0x58a8, 1477743627730944.0}, + {0x58f6, 2163838883463168.0}, + {0x5966, 4046202790215680.0}, + {0x5985, 4679521487814656.0}, + {0x59ad, 6086896371367936.0}, + {0x59b0, 6192449487634432.0}, + {0x59bc, 6614661952700416.0}, + {0x5a93, 20688410788233216.0}, + {0x5ab0, 24769797950537728.0}, + {0x5ab6, 
25614222880669696.0}, + {0x5ae3, 31947409856659456.0}, + {0x5af5, 34480684647055360.0}, + {0x5afd, 35606584553897984.0}, + {0x5bb6, 102456891522678784.0}, + {0x5c3d, 212795082393255936.0}, + {0x5d57, 968273919884656640.0}, + {0x5d76, 1107885508333142016.0}, + {0x5d86, 1206964700135292928.0}, + {0x5db2, 1603281467343896576.0}, + {0x5dc1, 1738389456165011456.0}, + {0x5dc5, 1774418253183975424.0}, + {0x5e5d, 3981182070595518464.0}, + {0x5fa4, 23634890844440363008.0}, + {0x5fa8, 24211351596743786496.0}, + {0x5ff8, 35740566642812256256.0}, + {0x6006, 38622870404329373696.0}, + {0x6051, 60240148615707754496.0}, + {0x60ed, 136621198295911366656.0}, + {0x610b, 160256089140351729664.0}, + {0x6114, 170632382681813352448.0}, + {0x613b, 215596321361480384512.0}, + {0x6148, 230584300921369395200.0}, + {0x61ad, 398910840593969053696.0}, + {0x61fd, 583378281331064569856.0}, + {0x629f, 1466516153859909353472.0}, + {0x62a4, 1512633014044183232512.0}, + {0x62fb, 2315066381250548727808.0}, + {0x634b, 3744689046963038978048.0}, + {0x635a, 4021390208068682252288.0}, + {0x635f, 4113623928437230010368.0}, + {0x637c, 4648579506574807007232.0}, + {0x638d, 5201981828786093555712.0}, + {0x63a9, 6234999496913828446208.0}, + {0x6469, 17192365476697302106112.0}, + {0x64c0, 28334198897217871282176.0}, + {0x64d1, 30842956091242370301952.0}, + {0x64dd, 32613843522318487257088.0}, + {0x64ee, 35122600716342986276864.0}, + {0x64ef, 35270174668932662689792.0}, + {0x64fb, 37041062100008779644928.0}, + {0x6510, 42501298345826806923264.0}, + {0x6581, 76148159536273029070848.0}, + {0x65d4, 125142711796045598162944.0}, + {0x6612, 172366376624742050299904.0}, + {0x661c, 184172292831916163334144.0}, + {0x66c6, 467514281804094876155904.0}, + {0x66ed, 559600428220052957822976.0}, + {0x66f9, 587934627117270829105152.0}, + {0x6703, 618630009255923522994176.0}, + {0x6752, 991696961402625494876160.0}, + {0x6797, 1426154677826632854536192.0}, + {0x679c, 1473378342655329306673152.0}, + {0x67e3, 2143954383222818927017984.0}, + {0x6862, 4269019300514159273181184.0}, + {0x692f, 13222626152035006598348800.0}, + {0x693d, 14280436244197807126216704.0}, + {0x6943, 14733783426553293066731520.0}, + {0x695e, 16773845747152979799048192.0}, + {0x69ec, 35663311678631560653832192.0}, + {0x69f4, 36872237498246189828538368.0}, + {0x6a5c, 66490920078804604608839680.0}, + {0x6a7a, 75557863725914323419136000.0}, + {0x6ac3, 117870267412426344533852160.0}, + {0x6ad8, 130563988518379950868267008.0}, + {0x6ae3, 137213080526260411329150976.0}, + {0x6b37, 221233424989477138971230208.0}, + {0x6b77, 298604677444813406152425472.0}, + {0x6bd7, 519838102434290545123655680.0}, + {0x6be7, 558523728661958678714253312.0}, + {0x6c15, 720519788490318988124880896.0}, + {0x6c33, 865590886844074489089622016.0}, + {0x6c43, 942962139299410756270817280.0}, + {0x6c45, 952633545856327789668466688.0}, + {0x6c48, 967140655691703339764940800.0}, + {0x6da6, 6421813953792910176039206912.0}, + {0x6dba, 7195526478346272847851159552.0}, + {0x6def, 9245864668412683928152834048.0}, + {0x6e24, 12688885402675147817716023296.0}, + {0x6e28, 12998370412496492886440804352.0}, + {0x6e3d, 14623166714058554497245904896.0}, + {0x6ea0, 24758800785707605497982484480.0}, + {0x6eef, 36983458673650735712611336192.0}, + {0x6ef9, 38530883722757461056235241472.0}, + {0x6f55, 65920307091946499638378364928.0}, + {0x6f5c, 68086702160695915119451832320.0}, + {0x6f65, 70872067249088020737974861824.0}, + {0x6f9e, 97797263103545041717030813696.0}, + {0x6fdc, 136173404321391830238903664640.0}, + {0x70ab, 
423375493435600054015500484608.0}, + {0x70dc, 544693617285567320955614658560.0}, + {0x714a, 1000255551742587262118492372992.0}, + {0x71ba, 1842054778456645849049896845312.0}, + {0x71d3, 2089642786313721904029721690112.0}, + {0x71ee, 2357037834799364043407932522496.0}, + {0x71f6, 2436265997313628381001476472832.0}, + {0x7251, 4139671491370311639262671405056.0}, + {0x72b8, 7288990951312319058606043430912.0}, + {0x731f, 12597277839768029677373488103424.0}, + {0x7328, 13310331302396408715715383656448.0}, + {0x7356, 16954826778052568245018405371904.0}, + {0x7358, 17113283103081096920205493272576.0}, + {0x7375, 19410899815994762710418267832320.0}, + {0x737c, 19965496953594613073573075484672.0}, + {0x7467, 73206822163180247936434610110464.0}, + {0x74a4, 103947349218714810922729662840832.0}, + {0x74b1, 112187078120198302032458233675776.0}, + {0x7530, 223106505640168374663419764146176.0}, + {0x7563, 287756686251808074139751627620352.0}, + {0x756f, 302968493454546826957712066084864.0}, + {0x7571, 305503794655003285760705472495616.0}, + {0x7626, 841719998551544322593810928369664.0}, + {0x7660, 1135814937804493543741046072016896.0}, + {0x7675, 1242297588223664813466769141268480.0}, + {0x76b2, 1805134454724998667731305364455424.0}, + {0x772c, 3488574451828087312918927221194752.0}, + {0x77b2, 7220537818899994670925221457821696.0}, + {0x783a, 15090112745116842795416754956795904.0}, + {0x795d, 71718600358512306619077480547352576.0}, + {0x796e, 77235415770705560974391132897148928.0}, + {0x798c, 90865195024359483499283685761351680.0}, + {0x79af, 113581493780449354374104607201689600.0}, + {0x7aa5, 428364490829123279353765947160657920.0}, + {0x7acf, 537402724858354659552906370074279936.0}, + {0x7adc, 571152654438831039138354596214210560.0}, + {0x7b07, 700960075902201729851617004444712960.0}, + {0x7b0f, 742498450770480350879860975078473728.0}, + {0x7b12, 758075341346084833765452464066134016.0}, + {0x7b30, 913844247102129662621367353942736896.0}, + {0x7bc2, 2014611181111513119869832575737397248.0}, + {0x7bd1, 2170380086867557948725747465614000128.0}, + {0x7bd5, 2211918461735836569753991436247760896.0}, + {0x7cf1, 10010748343255147667806796922736345088.0}, + {0x7d2f, 14538431203897517359885389721816268800.0}, + {0x7da0, 26584559915698317458076141205606891520.0}, + {0x7e58, 71778311772385457136805581255138607104.0}, + {0x7e81, 85735205728127073802295555388082225152.0}, + {0x7f09, 182104235422533474587821567258407206912.0}, + {0x7f24, 217993391308726203156224357885976510464.0}, + {0x7f86, std::numeric_limits<float>::quiet_NaN()}, + {0x7f88, std::numeric_limits<float>::quiet_NaN()}, + {0x7f8f, std::numeric_limits<float>::quiet_NaN()}, + {0x7fa0, std::numeric_limits<float>::quiet_NaN()}, + {0x7fcd, std::numeric_limits<float>::quiet_NaN()}, + {0x8023, -0.00000000000000000000000000000000000000321424236553}, + {0x8074, -0.00000000000000000000000000000000000001065291755433}, + {0x8080, -0.00000000000000000000000000000000000001175494350822}, + {0x80a5, -0.00000000000000000000000000000000000001515285686607}, + {0x80d2, -0.00000000000000000000000000000000000001928545419318}, + {0x80fd, -0.00000000000000000000000000000000000002323438052797}, + {0x810a, -0.00000000000000000000000000000000000002534659693961}, + {0x8124, -0.00000000000000000000000000000000000003012204273982}, + {0x81e3, -0.00000000000000000000000000000000000008338663051146}, + {0x81f1, -0.00000000000000000000000000000000000008852941829630}, + {0x8285, -0.00000000000000000000000000000000000019542593582421}, + {0x828b, -0.00000000000000000000000000000000000020424214345537}, + {0x829f,
-0.00000000000000000000000000000000000023362950222593}, + {0x82bc, -0.00000000000000000000000000000000000027624117244324}, + {0x82ee, -0.00000000000000000000000000000000000034970956936963}, + {0x82f9, -0.00000000000000000000000000000000000036587261669344}, + {0x8393, -0.00000000000000000000000000000000000086398834785438}, + {0x8394, -0.00000000000000000000000000000000000086986581960849}, + {0x843e, -0.00000000000000000000000000000000000223343926656235}, + {0x8451, -0.00000000000000000000000000000000000245678319321858}, + {0x847f, -0.00000000000000000000000000000000000299751059459683}, + {0x8483, -0.00000000000000000000000000000000000307979519915439}, + {0x84a0, -0.00000000000000000000000000000000000376158192263132}, + {0x84aa, -0.00000000000000000000000000000000000399668079279578}, + {0x84ba, -0.00000000000000000000000000000000000437283898505891}, + {0x84e4, -0.00000000000000000000000000000000000536025423974963}, + {0x8510, -0.00000000000000000000000000000000000677084746073638}, + {0x854d, -0.00000000000000000000000000000000000963905367674276}, + {0x8557, -0.00000000000000000000000000000000001010925141707167}, + {0x8584, -0.00000000000000000000000000000000001241322034468336}, + {0x85b7, -0.00000000000000000000000000000000001720923729603829}, + {0x8656, -0.00000000000000000000000000000000004024892657215512}, + {0x87c7, -0.00000000000000000000000000000000029942192104145307}, + {0x88d9, -0.00000000000000000000000000000000130602124353759431}, + {0x896c, -0.00000000000000000000000000000000284074666797117288}, + {0x89a0, -0.00000000000000000000000000000000385185988877447171}, + {0x89b2, -0.00000000000000000000000000000000428519412626159977}, + {0x89fe, -0.00000000000000000000000000000000611482757342947383}, + {0x8a31, -0.00000000000000000000000000000000852224000391351865}, + {0x8a55, -0.00000000000000000000000000000001025557695386203092}, + {0x8a68, -0.00000000000000000000000000000001117039367744596795}, + {0x8b4a, -0.00000000000000000000000000000003890378487662216423}, + {0x8b4c, -0.00000000000000000000000000000003928897086549961140}, + {0x8b5b, -0.00000000000000000000000000000004217786578208046518}, + {0x8bbf, -0.00000000000000000000000000000007357052387559240959}, + {0x8bff, -0.00000000000000000000000000000009822242716374902851}, + {0x8c43, -0.00000000000000000000000000000015022253566220439654}, + {0x8c6b, -0.00000000000000000000000000000018103741477240017019}, + {0x8d1d, -0.00000000000000000000000000000048379360203007364629}, + {0x8d23, -0.00000000000000000000000000000050228252949619111048}, + {0x8d30, -0.00000000000000000000000000000054234187233944561622}, + {0x8dee, -0.00000000000000000000000000000146678824564531882569}, + {0x8e54, -0.00000000000000000000000000000261310174854460160543}, + {0x8e68, -0.00000000000000000000000000000285962078142616779462}, + {0x8eb0, -0.00000000000000000000000000000433873497871556492976}, + {0x8efe, -0.00000000000000000000000000000626158343519178120546}, + {0x8f3d, -0.00000000000000000000000000000931841944292320195143}, + {0x8f95, -0.00000000000000000000000000001469253435974134487579}, + {0x8fa3, -0.00000000000000000000000000001607304094387811553526}, + {0x8fbe, -0.00000000000000000000000000001873544649899903037853}, + {0x9009, -0.00000000000000000000000000002701848600381965433535}, + {0x9017, -0.00000000000000000000000000002977949917209319565429}, + {0x90a8, -0.00000000000000000000000000006626431603856499165459}, + {0x90d2, -0.00000000000000000000000000008283039504820623956823}, + {0x9260, -0.00000000000000000000000000070681937107802657764891}, + {0x92d0, 
-0.00000000000000000000000000131266454628776364420512}, + {0x92ed, -0.00000000000000000000000000149568027629903838306064}, + {0x9356, -0.00000000000000000000000000270105973947674442172976}, + {0x94af, -0.00000000000000000000000001767048427695066444122272}, + {0x9503, -0.00000000000000000000000002645523931749185190628773}, + {0x953c, -0.00000000000000000000000003796629764647685617085567}, + {0x9556, -0.00000000000000000000000004321695583162791074767614}, + {0x9598, -0.00000000000000000000000006139231108792002274436236}, + {0x9599, -0.00000000000000000000000006179620787139318078873317}, + {0x95cf, -0.00000000000000000000000008360663417894371518475664}, + {0x95df, -0.00000000000000000000000009006898271451424389468952}, + {0x95f4, -0.00000000000000000000000009855081516745056282647643}, + {0x960f, -0.00000000000000000000000011551448007332320069005024}, + {0x9615, -0.00000000000000000000000012036124147500109722249990}, + {0x966a, -0.00000000000000000000000018902369466543796476553675}, + {0x968c, -0.00000000000000000000000022618219874496850484765081}, + {0x96b3, -0.00000000000000000000000028919009696678115976949640}, + {0x96d7, -0.00000000000000000000000034735123378691591815889232}, + {0x9719, -0.00000000000000000000000049436966297114544630986535}, + {0x9761, -0.00000000000000000000000072701421025168447986744905}, + {0x97f0, -0.00000000000000000000000155096364853692689038389130}, + {0x97fd, -0.00000000000000000000000163497417949934376361301874}, + {0x983c, -0.00000000000000000000000242984304937451879493476303}, + {0x987d, -0.00000000000000000000000326994835899868752722603749}, + {0x9895, -0.00000000000000000000000385155972720003511111999672}, + {0x98b3, -0.00000000000000000000000462704155146849855631194237}, + {0x98fa, -0.00000000000000000000000646234853557052870993288041}, + {0x998f, -0.00000000000000000000001478585344938536968832643037}, + {0x999b, -0.00000000000000000000001602662436821491120063354341}, + {0x99e4, -0.00000000000000000000002357464745776128873383514772}, + {0x9a8b, -0.00000000000000000000005748905257243542340356290410}, + {0x9a95, -0.00000000000000000000006162495563520056177791994756}, + {0x9a9f, -0.00000000000000000000006576085869796570015227699102}, + {0x9bb2, -0.00000000000000000000029447629806887785225422149438}, + {0x9bc4, -0.00000000000000000000032425480012078684854959220729}, + {0x9c09, -0.00000000000000000000045329497567905916582953196325}, + {0x9c79, -0.00000000000000000000082387189010281556417192305730}, + {0x9cca, -0.00000000000000000000133672386988569272259219644639}, + {0x9d21, -0.00000000000000000000213081725793659929046874879077}, + {0x9d2d, -0.00000000000000000000228963593554678060404405925965}, + {0x9d53, -0.00000000000000000000279256174797902143036587574443}, + {0x9d59, -0.00000000000000000000287197108678411208715353097887}, + {0x9d9b, -0.00000000000000000000410281583826301726736218711267}, + {0x9dc9, -0.00000000000000000000532042569994107400477290070739}, + {0x9e3a, -0.00000000000000000000984675801183124144166924907040}, + {0x9eb4, -0.00000000000000000001905824131322175762903725626529}, + {0x9eb7, -0.00000000000000000001937587866844212025618787720305}, + {0x9ec3, -0.00000000000000000002064642808932357076479036095407}, + {0x9ee0, -0.00000000000000000002371692252312040949391303001903}, + {0x9f0a, -0.00000000000000000002922263668027336169785712627345}, + {0x9f5b, -0.00000000000000000004637505386217294356399065691221}, + {0x9fb7, -0.00000000000000000007750351467376848102475150881219}, + {0xa007, -0.00000000000000000011434944787933054577422353759175}, + {0xa06c, 
-0.00000000000000000019989977555201488002012411016040}, + {0xa07a, -0.00000000000000000021175823681357508476708062516991}, + {0xa0c0, -0.00000000000000000032526065174565133020223584026098}, + {0xa0e5, -0.00000000000000000038794108984246955529329170531128}, + {0xa108, -0.00000000000000000046078592330633938445316744036973}, + {0xa128, -0.00000000000000000056920614055488982785391272045672}, + {0xa178, -0.00000000000000000084025668367626593635577592067420}, + {0xa1f6, -0.00000000000000000166696084019646306728645868133754}, + {0xa305, -0.00000000000000000720994444702860448614956112578511}, + {0xa312, -0.00000000000000000791467585914418236825440544635057}, + {0xa3f2, -0.00000000000000002623769257414920730298035778105259}, + {0xa3f3, -0.00000000000000002634611279139775774638110306113958}, + {0xa40a, -0.00000000000000002992397996059992237860569730401039}, + {0xa412, -0.00000000000000003165870343657672947301762178540230}, + {0xa413, -0.00000000000000003187554387107383035981911234557629}, + {0xa421, -0.00000000000000003491130995403324277503998018801212}, + {0xa450, -0.00000000000000004510281037539698445471003651618958}, + {0xa456, -0.00000000000000004640385298237958977551897987723351}, + {0xa4a4, -0.00000000000000007112366251504909087088890373706818}, + {0xa4c9, -0.00000000000000008716985466783455649419920518994331}, + {0xa4ed, -0.00000000000000010278236595162582034390652552247047}, + {0xa522, -0.00000000000000014051260155412137464736588299274445}, + {0xa53a, -0.00000000000000016132928326584305978030897676944733}, + {0xa541, -0.0000000000000001674008154317618846107507124543190}, + {0xa6f7, -0.00000000000000171390679426508540927898138761520386}, + {0xa76c, -0.00000000000000327515792264421179424971342086791992}, + {0xa7a3, -0.00000000000000452415882534751290222629904747009277}, + {0xa7ae, -0.00000000000000482947015711943095084279775619506836}, + {0xa7bc, -0.00000000000000521804821573823573999106884002685547}, + {0xa837, -0.00000000000001015854067532018234487622976303100586}, + {0xa83e, -0.00000000000001054711873393898713402450084686279297}, + {0xa881, -0.00000000000001432187701766451937146484851837158203}, + {0xa8b9, -0.00000000000002053912595556539599783718585968017578}, + {0xa8c2, -0.00000000000002153832667772803688421845436096191406}, + {0xa8ca, -0.00000000000002242650509742816211655735969543457031}, + {0xa8d6, -0.00000000000002375877272697834996506571769714355469}, + {0xa8e4, -0.00000000000002531308496145356912165880203247070312}, + {0xa92d, -0.00000000000003841371665203041629865765571594238281}, + {0xa933, -0.00000000000003974598428158060414716601371765136719}, + {0xa937, -0.00000000000004063416270128072937950491905212402344}, + {0xa950, -0.0000000000000461852778244065120816230773925781250}, + {0xa956, -0.00000000000004751754545395669993013143539428710938}, + {0xa9b8, -0.0000000000000817124146124115213751792907714843750}, + {0xa9c6, -0.00000000000008792966355031239800155162811279296875}, + {0xa9f0, -0.000000000000106581410364015027880668640136718750}, + {0xaa2d, -0.00000000000015365486660812166519463062286376953125}, + {0xaae6, -0.0000000000004085620730620576068758964538574218750}, + {0xaaff, -0.00000000000045297099404706386849284172058105468750}, + {0xab12, -0.000000000000518696197104873135685920715332031250}, + {0xab42, -0.000000000000689226453687297180294990539550781250}, + {0xabfa, -0.00000000000177635683940025046467781066894531250}, + {0xac3a, -0.0000000000026432189770275726914405822753906250}, + {0xadbb, -0.00000000002125943865394219756126403808593750}, + {0xadc3, 
-0.00000000002216893335571512579917907714843750}, + {0xadda, -0.0000000000247837306233122944831848144531250}, + {0xaddf, -0.00000000002535216481192037463188171386718750}, + {0xae16, -0.000000000034106051316484808921813964843750}, + {0xae77, -0.0000000000561612978344783186912536621093750}, + {0xaee2, -0.00000000010277290130034089088439941406250}, + {0xb03e, -0.00000000069121597334742546081542968750}, + {0xb050, -0.00000000075669959187507629394531250}, + {0xb075, -0.000000000891304807737469673156738281250}, + {0xb11d, -0.0000000022846506908535957336425781250}, + {0xb125, -0.0000000024010660126805305480957031250}, + {0xb139, -0.0000000026921043172478675842285156250}, + {0xb155, -0.0000000030995579436421394348144531250}, + {0xb18d, -0.000000004103640094399452209472656250}, + {0xb23c, -0.000000010943040251731872558593750}, + {0xb2a8, -0.0000000195577740669250488281250}, + {0xb341, -0.000000044936314225196838378906250}, + {0xb369, -0.000000054249539971351623535156250}, + {0xb37b, -0.000000058440491557121276855468750}, + {0xb3c6, -0.0000000922009348869323730468750}, + {0xb3c9, -0.00000009359791874885559082031250}, + {0xb3dc, -0.000000102445483207702636718750}, + {0xb3e2, -0.0000001052394509315490722656250}, + {0xb404, -0.00000012293457984924316406250}, + {0xb42d, -0.0000001611188054084777832031250}, + {0xb487, -0.000000251457095146179199218750}, + {0xb499, -0.000000284984707832336425781250}, + {0xb49b, -0.000000288709998130798339843750}, + {0xb4be, -0.00000035390257835388183593750}, + {0xb599, -0.0000011399388313293457031250}, + {0xb5be, -0.000001415610313415527343750}, + {0xb661, -0.000003352761268615722656250}, + {0xb67f, -0.000003799796104431152343750}, + {0xb6f4, -0.000007271766662597656250}, + {0xb70f, -0.0000085234642028808593750}, + {0xb729, -0.0000100731849670410156250}, + {0xb731, -0.0000105500221252441406250}, + {0xb735, -0.0000107884407043457031250}, + {0xb76f, -0.0000142455101013183593750}, + {0xb770, -0.000014305114746093750}, + {0xb7a4, -0.0000195503234863281250}, + {0xb7b1, -0.000021100044250488281250}, + {0xb829, -0.00004029273986816406250}, + {0xb882, -0.000061988830566406250}, + {0xb9a6, -0.0003166198730468750}, + {0xb9c5, -0.00037574768066406250}, + {0xb9cc, -0.000389099121093750}, + {0xb9d3, -0.00040245056152343750}, + {0xb9dd, -0.00042152404785156250}, + {0xbb04, -0.002014160156250}, + {0xbb14, -0.002258300781250}, + {0xbb19, -0.00233459472656250}, + {0xbb33, -0.00273132324218750}, + {0xbb66, -0.0035095214843750}, + {0xbbc5, -0.0060119628906250}, + {0xbc0d, -0.008605957031250}, + {0xbcb0, -0.0214843750}, + {0xbcc8, -0.02441406250}, + {0xbce0, -0.027343750}, + {0xbce8, -0.02832031250}, + {0xbd06, -0.032714843750}, + {0xbd77, -0.0603027343750}, + {0xbe31, -0.17285156250}, + {0xbe3a, -0.1816406250}, + {0xbe5d, -0.21582031250}, + {0xbe85, -0.2597656250}, + {0xbe9a, -0.300781250}, + {0xbea5, -0.3222656250}, + {0xbeb0, -0.343750}, + {0xbebf, -0.3730468750}, + {0xbeee, -0.464843750}, + {0xbf2b, -0.667968750}, + {0xbfac, -1.343750}, + {0xc022, -2.531250}, + {0xc026, -2.593750}, + {0xc05e, -3.468750}, + {0xc07e, -3.968750}, + {0xc07f, -3.9843750}, + {0xc086, -4.18750}, + {0xc0ae, -5.43750}, + {0xc0c2, -6.06250}, + {0xc0e6, -7.18750}, + {0xc13e, -11.8750}, + {0xc198, -19.0}, + {0xc1be, -23.750}, + {0xc1c1, -24.1250}, + {0xc1eb, -29.3750}, + {0xc225, -41.250}, + {0xc276, -61.50}, + {0xc27f, -63.750}, + {0xc29f, -79.50}, + {0xc313, -147.0}, + {0xc31b, -155.0}, + {0xc324, -164.0}, + {0xc35b, -219.0}, + {0xc394, -296.0}, + {0xc39d, -314.0}, + {0xc3b5, -362.0}, + {0xc3be, -380.0}, + 
{0xc429, -676.0}, + {0xc444, -784.0}, + {0xc44b, -812.0}, + {0xc4b5, -1448.0}, + {0xc4eb, -1880.0}, + {0xc523, -2608.0}, + {0xc557, -3440.0}, + {0xc55e, -3552.0}, + {0xc56d, -3792.0}, + {0xc58b, -4448.0}, + {0xc64d, -13120.0}, + {0xc6b8, -23552.0}, + {0xc6ca, -25856.0}, + {0xc777, -63232.0}, + {0xc7d6, -109568.0}, + {0xc868, -237568.0}, + {0xc8ca, -413696.0}, + {0xc910, -589824.0}, + {0xc9c5, -1613824.0}, + {0xc9c8, -1638400.0}, + {0xc9df, -1826816.0}, + {0xca3a, -3047424.0}, + {0xca42, -3178496.0}, + {0xca6b, -3850240.0}, + {0xcaa0, -5242880.0}, + {0xcaa2, -5308416.0}, + {0xcaac, -5636096.0}, + {0xcb3a, -12189696.0}, + {0xcb84, -17301504.0}, + {0xcc50, -54525952.0}, + {0xcc89, -71827456.0}, + {0xcc94, -77594624.0}, + {0xccaf, -91750400.0}, + {0xcce0, -117440512.0}, + {0xcce1, -117964800.0}, + {0xcd6d, -248512512.0}, + {0xcda8, -352321536.0}, + {0xcdba, -390070272.0}, + {0xcdd0, -436207616.0}, + {0xcde5, -480247808.0}, + {0xcdf7, -517996544.0}, + {0xce30, -738197504.0}, + {0xcec2, -1627389952.0}, + {0xcf03, -2197815296.0}, + {0xcf25, -2768240640.0}, + {0xcf57, -3607101440.0}, + {0xd036, -12213813248.0}, + {0xd09e, -21206401024.0}, + {0xd103, -35165044736.0}, + {0xd104, -35433480192.0}, + {0xd11f, -42681237504.0}, + {0xd125, -44291850240.0}, + {0xd19c, -83751862272.0}, + {0xd1c7, -106837311488.0}, + {0xd1cf, -111132278784.0}, + {0xd1d8, -115964116992.0}, + {0xd231, -190052302848.0}, + {0xd28a, -296352743424.0}, + {0xd294, -317827579904.0}, + {0xd2be, -408021893120.0}, + {0xd2c1, -414464344064.0}, + {0xd2c6, -425201762304.0}, + {0xd2db, -470298918912.0}, + {0xd334, -773094113280.0}, + {0xd36f, -1026497183744.0}, + {0xd375, -1052266987520.0}, + {0xd3c3, -1675037245440.0}, + {0xd3d5, -1829656068096.0}, + {0xd3f2, -2078764171264.0}, + {0xd44c, -3504693313536.0}, + {0xd49b, -5325759447040.0}, + {0xd4cd, -7043746365440.0}, + {0xd4e8, -7971459301376.0}, + {0xd538, -12644383719424.0}, + {0xd54c, -14018773254144.0}, + {0xd554, -14568529068032.0}, + {0xd5a3, -22402549415936.0}, + {0xd5bf, -26250840113152.0}, + {0xd64a, -55525337202688.0}, + {0xd74f, -227598906949632.0}, + {0xd75f, -245191092994048.0}, + {0xd762, -248489627877376.0}, + {0xd7d6, -470590976688128.0}, + {0xd7db, -481586092965888.0}, + {0xd819, -672901116198912.0}, + {0xd82d, -760862046420992.0}, + {0xd85b, -963172185931776.0}, + {0xd936, -3201777860083712.0}, + {0xd967, -4063794976260096.0}, + {0xd976, -4327677766926336.0}, + {0xd985, -4679521487814656.0}, + {0xd98c, -4925812092436480.0}, + {0xd9c9, -7072058789855232.0}, + {0xda1b, -10907155347537920.0}, + {0xda5a, -15340386230730752.0}, + {0xdab6, -25614222880669696.0}, + {0xdacc, -28710447624486912.0}, + {0xdb00, -36028797018963968.0}, + {0xdb0d, -39687971716202496.0}, + {0xdb24, -46161896180547584.0}, + {0xdb4c, -57420895248973824.0}, + {0xdbbd, -106397541196627968.0}, + {0xdbbf, -107523441103470592.0}, + {0xdbcc, -114841790497947648.0}, + {0xdbf4, -137359788634800128.0}, + {0xdc88, -306244774661193728.0}, + {0xdc9c, -351280770934898688.0}, + {0xdd24, -738590338888761344.0}, + {0xdd71, -1085367510196289536.0}, + {0xdd86, -1206964700135292928.0}, + {0xdd8f, -1288029493427961856.0}, + {0xdea9, -6088866696204910592.0}, + {0xded6, -7710162562058289152.0}, + {0xdedf, -8034421735228964864.0}, + {0xdef6, -8863084066665136128.0}, + {0xdf62, -16285016252571713536.0}, + {0xdf6a, -16861477004875137024.0}, + {0xdf71, -17365880163140632576.0}, + {0xdfc9, -28967152803247030272.0}, + {0xdff0, -34587645138205409280.0}, + {0xdff5, -35308221078584688640.0}, + {0xe037, -52746158835763249152.0}, + 
{0xe07e, -73210515542534782976.0}, + {0xe09a, -88774955854727217152.0}, + {0xe09d, -90504338111637487616.0}, + {0xe0cc, -117597993469898391552.0}, + {0xe113, -169479461177206505472.0}, + {0xe18f, -329735550317558235136.0}, + {0xe1aa, -391993311566327971840.0}, + {0xe21a, -710199646837817737216.0}, + {0xe233, -825491797298502434816.0}, + {0xe238, -848550227390639374336.0}, + {0xe247, -917725517667050192896.0}, + {0xe262, -1042241040164589666304.0}, + {0xe26a, -1079134528312008769536.0}, + {0xe271, -1111416330441000484864.0}, + {0xe28b, -1282048713122813837312.0}, + {0xe290, -1328165573307087716352.0}, + {0xe2b2, -1641760222560150093824.0}, + {0xe2eb, -2167492428660872314880.0}, + {0xe2f5, -2259726149029420072960.0}, + {0xe34b, -3744689046963038978048.0}, + {0xe363, -4187410904732068216832.0}, + {0xe388, -5017514388048998039552.0}, + {0xe3a0, -5902958103587056517120.0}, + {0xe3bd, -6972869259862210510848.0}, + {0xe3d6, -7895206463547688091648.0}, + {0xe3e4, -8411715297611555536896.0}, + {0xe406, -9887454823508319666176.0}, + {0xe5a3, -96218217088469021229056.0}, + {0xe5ae, -102711471002414783397888.0}, + {0xe5d7, -126913599227121715118080.0}, + {0xe5d9, -128094190847839126421504.0}, + {0xe644, -231395957660612615471104.0}, + {0xe66e, -280980805730743890214912.0}, + {0xe6b8, -434457716424007359660032.0}, + {0xe7f4, -2304514843640386864283648.0}, + {0xe824, -3097872412762487260184576.0}, + {0xe827, -3154540810556923002748928.0}, + {0xe880, -4835703278458516698824704.0}, + {0xe8f2, -9142501510835633133715456.0}, + {0xe928, -12693721105953606334414848.0}, + {0xe980, -19342813113834066795298816.0}, + {0xe9d1, -31583187037432187189198848.0}, + {0xe9fa, -37778931862957161709568000.0}, + {0xea15, -45032486780644936757805056.0}, + {0xea77, -74651169361203351538106368.0}, + {0xeaca, -122101507781077546645323776.0}, + {0xeacd, -123914896510499490407383040.0}, + {0xead9, -131168451428187265455620096.0}, + {0xeada, -131772914337994580042973184.0}, + {0xeb79, -301022529084042664501837824.0}, + {0xeba3, -394109817194369110954213376.0}, + {0xebf3, -587537948332709778907201536.0}, + {0xec09, -662491349148816787738984448.0}, + {0xec13, -710848381933401954727231488.0}, + {0xec48, -967140655691703339764940800.0}, + {0xece9, -2253437727761668781652312064.0}, + {0xed1f, -3075507285099616620452511744.0}, + {0xed24, -3172221350668786954429005824.0}, + {0xed9f, -6151014570199233240905023488.0}, + {0xede6, -8897694032363670725837455360.0}, + {0xedfe, -9826149061827705932011798528.0}, + {0xeea3, -25223028300439623101069656064.0}, + {0xeea5, -25532513310260968169794437120.0}, + {0xeec4, -30329530962491816735028543488.0}, + {0xeee0, -34662321099990647697175478272.0}, + {0xef23, -50446056600879246202139312128.0}, + {0xefa0, -99035203142830421991929937920.0}, + {0xefab, -105843873358900013503875121152.0}, + {0xf020, -198070406285660843983859875840.0}, + {0xf099, -378809652021326364119132012544.0}, + {0xf0a9, -418423733278458532915903987712.0}, + {0xf0cc, -505079536028435152158842683392.0}, + {0xf0e6, -569452418071274926453597143040.0}, + {0xf16b, -1163663636928257458405176770560.0}, + {0xf19b, -1535045648713871540874914037760.0}, + {0xf1bc, -1861861819085211933448282832896.0}, + {0xf1da, -2158967428513703199424072646656.0}, + {0xf231, -3505846191256196938514319802368.0}, + {0xf253, -4179285572627443808059443380224.0}, + {0xf287, -5347900969712842787564216647680.0}, + {0xf296, -5942112188569825319515796275200.0}, + {0xf2f3, -9626221745483117017615589965824.0}, + {0xf339, -14657210065138902454805630812160.0}, + {0xf36e, 
-18856302678394912347263460179968.0}, + {0xf3b3, -28363682180106632858488734220288.0}, + {0xf3c0, -30423614405477505635920876929024.0}, + {0xf3da, -34543478856219251190785162346496.0}, + {0xf3eb, -37237236381704238668965656657920.0}, + {0xf407, -42783207757702742300513733181440.0}, + {0xf418, -48170722808672717256874721804288.0}, + {0xf425, -52290587259414462811739007221760.0}, + {0xf473, -77009773963864936140924719726592.0}, + {0xf515, -188879939434006180823008777601024.0}, + {0xf57a, -316912650057057350374175801344000.0}, + {0xf594, -375224577667555902843024148791296.0}, + {0xf61c, -791013974542415146533942800154624.0}, + {0xf634, -912708432164325169077626307870720.0}, + {0xf69f, -1612451563490307798703806477238272.0}, + {0xf6b2, -1805134454724998667731305364455424.0}, + {0xf6f0, -2433889152438200450873670154321920.0}, + {0xf71e, -3204620717376963926983665703190528.0}, + {0xf735, -3671116138260952346734452482768896.0}, + {0xf748, -4056481920730334084789450257203200.0}, + {0xf75f, -4522977341614322504540237036781568.0}, + {0xf796, -6084722881095501127184175385804800.0}, + {0xf8a3, -26448262123161778232827215676964864.0}, + {0xf8d0, -33749929580476379585448226139930624.0}, + {0xf8ef, -38779967162181993850587144458862592.0}, + {0xf9a3, -105793048492647112931308862707859456.0}, + {0xfa0d, -183028464263352673905699995605008384.0}, + {0xfa18, -197307280624323449884158860510363648.0}, + {0xfa31, -229759135990166122562474462567989248.0}, + {0xfaa9, -438749084546192934610826939819098112.0}, + {0xfab5, -469902865697401900382009917794418688.0}, + {0xfacd, -532210427999819831924375873745059840.0}, + {0xfb84, -1370766370653194493932051030914105344.0}, + {0xfba8, -1744611744467702083186246766617952256.0}, + {0xfbfa, -2596148429267413814265248164610048000.0}, + {0xfc20, -3323069989462289682259517650700861440.0}, + {0xfc76, -5109220108798270386474008387952574464.0}, + {0xfc93, -6106141105636957291151863683162832896.0}, + {0xfccf, -8598443597733674552846501921188478976.0}, + {0xfcdc, -9138442471021296626213673539427368960.0}, + {0xfcf0, -9969209968386869046778552952102584320.0}, + {0xfd5b, -18193808192306036010370859137587216384.0}, + {0xfda9, -28079941410956347815092924148422279168.0}, + {0xfdca, -33563006893569125790821128272078700544.0}, + {0xfe1d, -52172198834557948011474427116003524608.0}, + {0xfe24, -54498347827181550789056089471494127616.0}, + {0xfe4b, -67458320786084480549868208309227487232.0}, + {0xfe4f, -68787548781869396422772015369507831808.0}, + {0xfe64, -75765995759740204755517002435979640832.0}, + {0xfea2, -107667467658578185705208371882707910656.0}, + {0xfec2, -128935115591136839671669284847193423872.0}, + {0xfee3, -150867377521587951574582101341819109376.0}, + {0xfee9, -154855061508942699193293522522660143104.0}, + {0xff57, -285784019093756912674318517960274083840.0}, + {0xff60, -297747071055821155530452781502797185024.0}, + {0xff8e, std::numeric_limits<float>::quiet_NaN()}, + {0xffb0, std::numeric_limits<float>::quiet_NaN()}, + {0xfffa, std::numeric_limits<float>::quiet_NaN()}, + }; + return result; +} + +TEST_CASE(check_bf16_values) +{ + for(auto [x, f] : bf16_lut()) + { + auto h = migraphx::bit_cast<migraphx::bf16>(x); + if(std::isnan(f)) + { + CHECK(std::isnan(h)); + } + else if(std::isinf(f)) + { + CHECK(std::isinf(h)); + CHECK((h < 0) == (f < 0)); + CHECK(bit_equal(x, migraphx::bf16(f))); + } + else + { + CHECK(bit_equal(x, migraphx::bf16(f))); + CHECK(migraphx::float_equal(float(h), f)); + } + } +} + +TEST_CASE(check_flows) +{ + // check positive underflow + CHECK(bit_equal(std::numeric_limits<migraphx::bf16>::min() * +
std::numeric_limits<migraphx::bf16>::min(), + migraphx::bf16(0))); + + // check overflow + CHECK(bit_equal(std::numeric_limits<migraphx::bf16>::infinity() + + std::numeric_limits<migraphx::bf16>::infinity(), + std::numeric_limits<migraphx::bf16>::infinity())); + CHECK(bit_equal(std::numeric_limits<migraphx::bf16>::max() + + std::numeric_limits<migraphx::bf16>::max(), + std::numeric_limits<migraphx::bf16>::infinity())); + CHECK(bit_equal(std::numeric_limits<migraphx::bf16>::max() / + std::numeric_limits<migraphx::bf16>::epsilon(), + std::numeric_limits<migraphx::bf16>::infinity())); + CHECK(bit_equal(std::numeric_limits<migraphx::bf16>::max() + + std::numeric_limits<migraphx::bf16>::min(), + std::numeric_limits<migraphx::bf16>::max())); + + // check negative overflow + CHECK(bit_equal(std::numeric_limits<migraphx::bf16>::lowest() + + std::numeric_limits<migraphx::bf16>::lowest(), + -std::numeric_limits<migraphx::bf16>::infinity())); + CHECK(bit_equal(-std::numeric_limits<migraphx::bf16>::infinity() - + std::numeric_limits<migraphx::bf16>::infinity(), + -std::numeric_limits<migraphx::bf16>::infinity())); + CHECK(bit_equal(std::numeric_limits<migraphx::bf16>::lowest() - + std::numeric_limits<migraphx::bf16>::min(), + std::numeric_limits<migraphx::bf16>::lowest())); +} + +TEST_CASE(test_nan) +{ + float f_qnan = std::numeric_limits<float>::quiet_NaN(); + migraphx::bf16 bf16_qnan(f_qnan); + EXPECT(bf16_qnan.is_nan()); + EXPECT(std::isnan(bf16_qnan)); + + float f_snan = std::numeric_limits<float>::signaling_NaN(); + migraphx::bf16 bf16_snan(f_snan); + EXPECT(bf16_snan.is_nan()); + EXPECT(std::isnan(bf16_snan)); +} + +TEST_CASE(test_bool) +{ + float zero = 0.0; + float two = 2.0; + float other = -0.375; + migraphx::bf16 bf16_zero(zero); + migraphx::bf16 bf16_two(two); + migraphx::bf16 bf16_other(other); + EXPECT(not static_cast<bool>(bf16_zero)); + EXPECT(static_cast<bool>(bf16_two)); + EXPECT(static_cast<bool>(bf16_other)); +} + +TEST_CASE(test_pos_infinity) +{ + float finf = std::numeric_limits<float>::infinity(); + migraphx::bf16 bf16_inf_1(finf); + CHECK(bit_equal(bf16_inf_1, std::numeric_limits<migraphx::bf16>::infinity())); +} + +TEST_CASE(test_neg_infinity) +{ + float finf = -1.0 * std::numeric_limits<float>::infinity(); + migraphx::bf16 bf16_neginf_1(finf); + CHECK(bit_equal(bf16_neginf_1, -std::numeric_limits<migraphx::bf16>::infinity())); +} + +TEST_CASE(test_numeric_max_1) +{ + float fmax = std::numeric_limits<float>::max(); // fp32 max converts to bf16 max + migraphx::bf16 bf16_max(fmax); + CHECK(bit_equal(bf16_max, std::numeric_limits<migraphx::bf16>::max())); +} + +TEST_CASE(test_numeric_lowest_1) +{ + float flowest = std::numeric_limits<float>::lowest(); + migraphx::bf16 bf16_lowest(flowest); + CHECK(bit_equal(bf16_lowest, std::numeric_limits<migraphx::bf16>::lowest())); +} + +TEST_CASE(test_max_eq_lowest) +{ + EXPECT(migraphx::float_equal(std::numeric_limits<migraphx::bf16>::lowest(), + -1 * std::numeric_limits<migraphx::bf16>::max())); +} + +TEST_CASE(test_isfinite) +{ + EXPECT(std::isfinite(migraphx::bf16(0.0))); + EXPECT(std::isfinite(migraphx::bf16(-0.0))); + EXPECT(not std::isfinite(migraphx::bf16(std::numeric_limits<float>::quiet_NaN()))); +} + +TEST_CASE(test_binary_ops) +{ + auto a = migraphx::bf16(-1.0); + auto b = migraphx::bf16(1.0); + auto c = migraphx::bf16(0.0); + auto d = migraphx::bf16(-0.0); + EXPECT(migraphx::float_equal((c + d), c)); + EXPECT(migraphx::float_equal((c + d), d)); + EXPECT(migraphx::float_equal((a + b), c)); + EXPECT(migraphx::float_equal((a + b), d)); + + auto e = migraphx::bf16(10.0); + auto f = migraphx::bf16(-10.0); + EXPECT(e > f); + EXPECT(f < e); + EXPECT(f <= e); + EXPECT(e >= f); + EXPECT(e <= e); + EXPECT(f >= f); + EXPECT(not migraphx::float_equal(f, e)); +} + +TEST_CASE(test_stream_op) +{ + auto a = migraphx::bf16(-1.0); + std::stringstream ss; + ss << a; + EXPECT(std::string("-1") == ss.str()); + ss = std::stringstream(); + auto b = std::numeric_limits<migraphx::bf16>::quiet_NaN(); + ss << b; + EXPECT(std::string("nan") == ss.str()); +} + +int main(int argc, const
char* argv[]) { test::run(argc, argv); } diff --git a/test/common_dims.cpp b/test/common_dims.cpp index 1458822d45f..912b0f6ad33 100644 --- a/test/common_dims.cpp +++ b/test/common_dims.cpp @@ -63,9 +63,26 @@ TEST_CASE(common2) EXPECT(cd.dims == std::vector<std::size_t>{2, 32, 40, 8, 8}); EXPECT(cd.axes_map1 == axes_map{{0}, {1, 2}, {3}, {4}}); EXPECT(cd.axes_map2 == axes_map{{0}, {1}, {2, 3, 4}}); + verify_common(cd); } +TEST_CASE(common3) +{ + auto cd = migraphx::common_dims::compute({2, 32, 4096}, {4, 16, 64, 64}); + EXPECT(cd.dims == std::vector<std::size_t>{2, 2, 16, 64, 64}); + EXPECT(cd.axes_map1 == axes_map{{0}, {1, 2}, {3, 4}}); + EXPECT(cd.axes_map2 == axes_map{{0, 1}, {2}, {3}, {4}}); +} + +TEST_CASE(common4) +{ + auto cd = migraphx::common_dims::compute({4, 16, 64, 64}, {2, 32, 4096}); + EXPECT(cd.dims == std::vector<std::size_t>{2, 2, 16, 64, 64}); + EXPECT(cd.axes_map1 == axes_map{{0, 1}, {2}, {3}, {4}}); + EXPECT(cd.axes_map2 == axes_map{{0}, {1, 2}, {3, 4}}); +} + TEST_CASE(common_same_dims) { auto cd = migraphx::common_dims::compute({{2, 32, 4}}, {64, 2, 2}); diff --git a/test/eliminate_contiguous_test.cpp b/test/eliminate_contiguous_test.cpp index e8ff253137f..a5c7f0a344e 100644 --- a/test/eliminate_contiguous_test.cpp +++ b/test/eliminate_contiguous_test.cpp @@ -167,7 +167,7 @@ TEST_CASE(non_standard_flatten_op) m.add_instruction(migraphx::make_op("flatten"), c); auto count = std::distance(m.begin(), m.end()); run_pass(m); - EXPECT(std::distance(m.begin(), m.end()) == count); + EXPECT(std::distance(m.begin(), m.end()) == (count - 1)); } TEST_CASE(standard_flatten_op) diff --git a/test/float32.cpp b/test/float32.cpp new file mode 100644 index 00000000000..cf6ad1f12ad --- /dev/null +++ b/test/float32.cpp @@ -0,0 +1,63 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE.
+ */ +#include <migraphx/bit_cast.hpp> +#include <migraphx/float_equal.hpp> +#include <migraphx/generic_float.hpp> +#include <migraphx/ranges.hpp> +#include "test.hpp" +#include <array> + +#include <limits> + +using fp32 = migraphx::generic_float<23, 8>; + +template <class T, class U> +bool bit_equal(const T& x, const U& y) +{ + static_assert(sizeof(T) == sizeof(U)); + using type = std::array<char, sizeof(T)>; + return migraphx::bit_cast<type>(x) == migraphx::bit_cast<type>(y); +} +// NOLINTNEXTLINE +#define MIGRAPHX_CHECK_FLOAT(x, y) \ + CHECK(bit_equal(x, y)); \ + CHECK(bit_equal(x, y.to_float())); \ + CHECK(bit_equal(fp32{x}, y)); \ + CHECK(bit_equal(fp32{x}.to_float(), y.to_float())) + +TEST_CASE(fp32_values_working) +{ + MIGRAPHX_CHECK_FLOAT(1.0f, fp32{1.0f}); + MIGRAPHX_CHECK_FLOAT(-1.0f, fp32{-1.0f}); + MIGRAPHX_CHECK_FLOAT(std::numeric_limits<float>::min(), fp32::min()); + MIGRAPHX_CHECK_FLOAT(std::numeric_limits<float>::lowest(), fp32::lowest()); + MIGRAPHX_CHECK_FLOAT(std::numeric_limits<float>::max(), fp32::max()); + MIGRAPHX_CHECK_FLOAT(std::numeric_limits<float>::epsilon(), fp32::epsilon()); + MIGRAPHX_CHECK_FLOAT(std::numeric_limits<float>::denorm_min(), fp32::denorm_min()); + MIGRAPHX_CHECK_FLOAT(std::numeric_limits<float>::infinity(), fp32::infinity()); + MIGRAPHX_CHECK_FLOAT(std::numeric_limits<float>::quiet_NaN(), fp32::qnan()); + MIGRAPHX_CHECK_FLOAT(std::numeric_limits<float>::signaling_NaN(), fp32::snan()); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/float_equal.cpp b/test/float_equal.cpp index 68045632a19..4cd100e5725 100644 --- a/test/float_equal.cpp +++ b/test/float_equal.cpp @@ -73,17 +73,21 @@ TEST_CASE_REGISTER(test_equality); TEST_CASE_REGISTER(test_equality); TEST_CASE_REGISTER(test_equality); TEST_CASE_REGISTER(test_equality); +TEST_CASE_REGISTER(test_equality); TEST_CASE_REGISTER(test_equality); TEST_CASE_REGISTER(test_equality); TEST_CASE_REGISTER(test_equality); TEST_CASE_REGISTER(test_equality); +TEST_CASE_REGISTER(test_equality); TEST_CASE_REGISTER(test_equality); TEST_CASE_REGISTER(test_equality); TEST_CASE_REGISTER(test_equality); TEST_CASE_REGISTER(test_equality); +TEST_CASE_REGISTER(test_equality); TEST_CASE_REGISTER(test_equality); TEST_CASE_REGISTER(test_equality); TEST_CASE_REGISTER(test_equality); +TEST_CASE_REGISTER(test_equality); TEST_CASE_REGISTER(test_equality); TEST_CASE_REGISTER(test_equality); @@ -124,17 +128,21 @@ TEST_CASE_REGISTER(test_limits); TEST_CASE_REGISTER(test_limits); TEST_CASE_REGISTER(test_limits); TEST_CASE_REGISTER(test_limits); +TEST_CASE_REGISTER(test_limits); TEST_CASE_REGISTER(test_limits); TEST_CASE_REGISTER(test_limits); TEST_CASE_REGISTER(test_limits); TEST_CASE_REGISTER(test_limits); +TEST_CASE_REGISTER(test_limits); TEST_CASE_REGISTER(test_limits); TEST_CASE_REGISTER(test_limits); TEST_CASE_REGISTER(test_limits); TEST_CASE_REGISTER(test_limits); +TEST_CASE_REGISTER(test_limits); TEST_CASE_REGISTER(test_limits); TEST_CASE_REGISTER(test_limits); TEST_CASE_REGISTER(test_limits); +TEST_CASE_REGISTER(test_limits); TEST_CASE_REGISTER(test_limits); TEST_CASE_REGISTER(test_limits); diff --git a/test/fuse_pointwise.cpp b/test/fuse_pointwise.cpp index 93881036efa..69a698ddec0 100644 --- a/test/fuse_pointwise.cpp +++ b/test/fuse_pointwise.cpp @@ -427,6 +427,53 @@ TEST_CASE(add_reshape_add) EXPECT(p1.sort() == p2.sort()); } +TEST_CASE(add_transpose_reshape_add) +{ + migraphx::shape s1{migraphx::shape::float_type, {3, 16, 10}}; + migraphx::shape s2{migraphx::shape::float_type, {3, 40, 2, 2}}; + migraphx::shape s3{migraphx::shape::float_type, {3, 10, 4, 2, 2}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s1); + auto y = mm->add_parameter("y", s1); +
auto z = mm->add_parameter("z", s2); + auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y); + auto transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), add1); + auto reshape = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), transpose); + auto add2 = mm->add_instruction(migraphx::make_op("add"), reshape, z); + mm->add_return({add2}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s1); + auto y = mm->add_parameter("y", s1); + auto z = mm->add_parameter("z", s2); + auto x2 = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, 4, 2, 2, 10}}}), x); + auto x3 = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 4, 1, 2, 3}}}), x2); + auto y2 = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, 4, 2, 2, 10}}}), y); + auto y3 = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 4, 1, 2, 3}}}), y2); + auto z2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), z); + auto fadd = + add_pointwise(p2, "main:pointwise0", {x3, y3, z2}, [=](auto* pm, const auto& inputs) { + auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); + return pm->add_instruction(migraphx::make_op("add"), add1, inputs[2]); + }); + auto reshape = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), fadd); + mm->add_return({reshape}); + } + EXPECT(p1.sort() == p2.sort()); +} + TEST_CASE(add_contiguous_reshape_add) { auto s1 = @@ -531,8 +578,8 @@ TEST_CASE(add_unsqueeze_add_nonstandard) auto x = mm->add_parameter("x", s1); auto y = mm->add_parameter("y", s1); auto z = mm->add_parameter("z", s2); - auto x2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), x); - auto y2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), y); + auto x2 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), x); + auto y2 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), y); auto fadd = add_pointwise(p2, "main:pointwise0", {x2, y2, z}, [=](auto* pm, const auto& inputs) { auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); diff --git a/test/fuse_reduce.cpp b/test/fuse_reduce.cpp index 55fad66e5ac..90fd3413369 100644 --- a/test/fuse_reduce.cpp +++ b/test/fuse_reduce.cpp @@ -97,6 +97,44 @@ TEST_CASE(pointwise_reduce) EXPECT(p1 == p2); } +TEST_CASE(pointwise_reduce_unfusable_broadcast) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 1, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto add = add_pointwise(p1, "main:pointwise0", {x, y}, single_pointwise("add")); + auto addb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 3}}}), add); + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), addb); + mm->add_return({rsum}); + } + run_pass(p1); + + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto add = add_pointwise(p2, "main:pointwise0", {x, y}, single_pointwise("add")); + auto addb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 3}}}), add); + auto rsum = + add_reduce(p2, + "main:reduce_sum0", + {addb}, + {2}, + [&](auto* rm, const auto& inputs, const auto& axes) 
{ + return rm->add_instruction( + migraphx::make_op("reduce_sum", {{"axes", axes}}), inputs[0]); + }); + mm->add_return({rsum}); + } + EXPECT(p1 == p2); +} + TEST_CASE(scalar_multibroadcast) { // Matches the find_pointwise_reduce matcher, but input x has a (scalar) shape @@ -283,6 +321,43 @@ TEST_CASE(reduce_pointwise) EXPECT(p1 == p2); } +TEST_CASE(reduce_pointwise_unfusable_broadcast) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 1, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), x); + auto rsumb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 3}}}), rsum); + auto yb = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 3}}}), y); + auto add = add_pointwise(p1, "main:pointwise0", {rsumb, yb}, single_pointwise("add")); + mm->add_return({add}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto rsum = add_reduce( + p2, "main:reduce_sum0", {x}, {2}, [&](auto* rm, const auto& inputs, const auto& axes) { + return rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), + inputs[0]); + }); + auto rsumb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 3}}}), rsum); + auto yb = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 3}}}), y); + auto add = add_pointwise(p2, "main:pointwise0", {rsumb, yb}, single_pointwise("add")); + mm->add_return({add}); + } + EXPECT(p1 == p2); +} + TEST_CASE(reduce_reduce) { migraphx::shape s{migraphx::shape::float_type, {2, 3}}; @@ -325,6 +400,57 @@ TEST_CASE(reduce_reduce) EXPECT(p1 == p2); } +TEST_CASE(reduce_reduce_unfusable_broadcast) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 1, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), x); + auto rsumb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 3}}}), rsum); + auto xb = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 3}}}), x); + auto rsumdiff = add_pointwise(p1, "main:pointwise0", {rsumb, xb}, single_pointwise("sub")); + auto rsum2 = + mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), rsumdiff); + auto sqrt = add_pointwise(p1, "main:pointwise1", {rsum2}, single_pointwise("sqrt")); + mm->add_return({sqrt}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s); + auto rsum = add_reduce( + p2, "main:reduce_sum0", {x}, {2}, [&](auto* rm, const auto& inputs, const auto& axes) { + return rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), + inputs[0]); + }); + + auto rsumb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 3}}}), rsum); + auto xb = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 3}}}), x); + + auto sqrt = add_reduce( + p2, + "main:pointwise0:main:reduce_sum1:main:pointwise1", + {rsumb, xb}, + {2}, + [&](auto* rm, const auto& inputs, const auto& axes) { + auto rsumdiff = add_pointwise( + p2, rm, "main:pointwise0", {inputs[0], inputs[1]}, single_pointwise("sub")); + auto rsum2 = 
rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), + rsumdiff); + return add_pointwise(p2, rm, "main:pointwise1", {rsum2}, single_pointwise("sqrt")); + }); + mm->add_return({sqrt}); + } + EXPECT(p1 == p2); +} + TEST_CASE(parallel_reduce_reduce) { migraphx::shape s{migraphx::shape::float_type, {2, 3}}; @@ -954,7 +1080,7 @@ TEST_CASE(reshape_reduce_reduce_reduce_diff_axes) auto reduce0 = add_reduce( p2, "main:pointwise0:main:pointwise1:main:reduce_sum1:main:pointwise2:main:reduce_sum0:" - "main:pointwise3:main:pointwise4:main:pointwise5:main:pointwise6_reshape_reshape", + "main:pointwise3:main:pointwise4:main:pointwise5:main:pointwise6_reshape", {l2_mb, x1, x2, l1_mb}, {2}, [&](auto* rm, const auto& inputs, const auto& axes) { @@ -982,7 +1108,7 @@ TEST_CASE(reshape_reduce_reduce_reduce_diff_axes) auto reduce1 = add_reduce(p2, - "main:reduce_sum2_reshape", + "main:reduce_sum2", {reduce0}, {1}, [&](auto* rm, const auto& inputs, const auto& axes) { diff --git a/test/gpu/compile_hipblaslt.cpp b/test/gpu/compile_hipblaslt.cpp new file mode 100644 index 00000000000..e1ec7dc35ef --- /dev/null +++ b/test/gpu/compile_hipblaslt.cpp @@ -0,0 +1,80 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */
+
+#include <migraphx/gpu/context.hpp>
+#include <migraphx/gpu/hip_gemm.hpp>
+#include <migraphx/gpu/hipblaslt.hpp>
+#include <migraphx/gpu/lowering.hpp>
+#include <migraphx/make_op.hpp>
+#include <migraphx/op/allocate.hpp>
+#include <migraphx/op/dot.hpp>
+#include <migraphx/pass_manager.hpp>
+#include <migraphx/env.hpp>
+#include <test.hpp>
+
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_HIPBLASLT_GEMM);
+
+void run_lowering(migraphx::module& m, bool offload_copy = false)
+{
+    auto ctx = migraphx::gpu::context{};
+    migraphx::run_passes(m, {migraphx::gpu::lowering{&ctx, offload_copy}});
+}
+
+#if MIGRAPHX_USE_HIPBLASLT
+TEST_CASE(hipblaslt_op)
+{
+    if(migraphx::enabled(MIGRAPHX_ENABLE_HIPBLASLT_GEMM{}) and migraphx::gpu::hipblaslt_supported())
+    {
+        migraphx::module m1;
+        {
+            migraphx::shape sa{migraphx::shape::float_type, {4, 2}};
+            migraphx::shape sb{migraphx::shape::float_type, {2, 3}};
+            migraphx::shape s_output{migraphx::shape::float_type, {4, 3}};
+            auto a                     = m1.add_parameter("a", sa);
+            auto b                     = m1.add_parameter("b", sb);
+            migraphx::operation dot_op = migraphx::make_op("dot");
+            m1.add_instruction(dot_op, a, b);
+        }
+
+        run_lowering(m1);
+        migraphx::module m2;
+        {
+            auto a = m2.add_parameter("a", {migraphx::shape::float_type, {4, 2}});
+            auto b = m2.add_parameter("b", {migraphx::shape::float_type, {2, 3}});
+
+            migraphx::shape output_shape{migraphx::shape::float_type, {4, 3}, {3, 1}};
+
+            // Add an allocate instruction for the output
+            auto output = m2.add_instruction(migraphx::op::allocate{output_shape, std::nullopt});
+
+            migraphx::op::dot dot_instance;
+            migraphx::gpu::hipblaslt_op hipblaslt_operator;
+            hipblaslt_operator.op = migraphx::gpu::hip_gemm{dot_instance, 1, 0};
+            m2.add_instruction(hipblaslt_operator, a, b, output);
+        }
+        EXPECT(m1 == m2);
+    }
+}
+#endif
+
+int main(int argc, const char* argv[]) { test::run(argc, argv); }
diff --git a/test/gpu/fuse_gemm.cpp b/test/gpu/fuse_gemm.cpp
new file mode 100644
index 00000000000..cecf8d95990
--- /dev/null
+++ b/test/gpu/fuse_gemm.cpp
@@ -0,0 +1,106 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+#include <migraphx/auto_contiguous.hpp>
+#include <migraphx/dead_code_elimination.hpp>
+#include <migraphx/env.hpp>
+#include <migraphx/gpu/context.hpp>
+#include <migraphx/gpu/fuse_ops.hpp>
+#include <migraphx/gpu/hip_gemm.hpp>
+#include <migraphx/gpu/hipblaslt.hpp>
+#include <migraphx/gpu/lowering.hpp>
+#include <migraphx/instruction.hpp>
+#include <migraphx/make_op.hpp>
+#include <migraphx/op/allocate.hpp>
+#include <migraphx/op/dot.hpp>
+#include <migraphx/pass_manager.hpp>
+#include <migraphx/program.hpp>
+#include <migraphx/shape.hpp>
+#include <optional>
+#include <pointwise.hpp>
+#include <test.hpp>
+
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_HIPBLASLT_GEMM)
+
+void run_lowering(migraphx::program& p, bool offload_copy = false)
+{
+    auto ctx = migraphx::gpu::context{};
+    migraphx::run_passes(
+        *p.get_main_module(),
+        {migraphx::auto_contiguous{}, migraphx::gpu::lowering{&ctx, offload_copy}});
+}
+
+void run_fuse_ops(migraphx::program& p)
+{
+    migraphx::run_passes(p, {migraphx::gpu::fuse_ops{}, migraphx::dead_code_elimination{}});
+}
+
+#if MIGRAPHX_USE_HIPBLASLT
+TEST_CASE(gemm_pointwise_add)
+{
+    migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}};
+    migraphx::program p1;
+    {
+        auto* mm = p1.get_main_module();
+        auto a   = mm->add_parameter("a", s);
+        auto b   = mm->add_parameter("b", s);
+        auto x   = mm->add_parameter("x", s);
+        auto dot = mm->add_instruction(migraphx::make_op("dot"), a, b);
+        auto add = add_pointwise(p1, "main:pointwise0", {dot, x}, single_pointwise("add"));
+        mm->add_return({add});
+    }
+    run_lowering(p1);
+    run_fuse_ops(p1);
+
+    migraphx::program p2;
+    {
+        auto* mm = p2.get_main_module();
+        auto a   = mm->add_parameter("a", s);
+        auto b   = mm->add_parameter("b", s);
+        auto x   = mm->add_parameter("x", s);
+
+        auto output = mm->add_instruction(migraphx::op::allocate{s, std::nullopt});
+
+        if(migraphx::enabled(MIGRAPHX_ENABLE_HIPBLASLT_GEMM{}) and
+           migraphx::gpu::hipblaslt_supported())
+        {
+            migraphx::op::dot dot_instance;
+            migraphx::gpu::hipblaslt_op hipblaslt_operator;
+            hipblaslt_operator.op = migraphx::gpu::hip_gemm{dot_instance, 1, 1};
+            auto add              = mm->add_instruction(hipblaslt_operator, a, b, x, output);
+            mm->add_return({add});
+        }
+        else
+        {
+            auto gemm_oper =
+                migraphx::make_op("gpu::gemm", {{"alpha", 1}, {"beta", 1}, {"compute_fp32", 1}});
+            auto add = mm->add_instruction(gemm_oper, a, b, x, output);
+            mm->add_return({add});
+        }
+    }
+    EXPECT(p1.sort() == p2.sort());
+}
+#endif
+
+int main(int argc, const char* argv[]) { test::run(argc, argv); }
diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp
index 06034f77976..7532cc5f914 100644
--- a/test/gpu/fuse_mlir.cpp
+++ b/test/gpu/fuse_mlir.cpp
@@ -445,6 +445,41 @@ TEST_CASE(relu_dot)
     EXPECT(p1.sort() == p2.sort());
 }
 
+TEST_CASE(relu_relu_dot)
+{
+    migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}};
+    migraphx::program p1;
+    {
+        auto* mm   = p1.get_main_module();
+        auto x     = mm->add_parameter("x", s);
+        auto y     = mm->add_parameter("y", s);
+        auto relux = add_pointwise(p1, "main:pointwise0", {x}, single_pointwise("relu"));
+        auto reluy = add_pointwise(p1, "main:pointwise1", {y}, single_pointwise("relu"));
+        auto dot   = mm->add_instruction(migraphx::make_op("dot"), relux, reluy);
+        mm->add_return({dot});
+    }
+    run_pass(p1);
+    migraphx::program p2;
+    {
+        auto* mm = p2.get_main_module();
+        auto x   = mm->add_parameter("x", s);
+        auto y   = mm->add_parameter("y", s);
+        auto fused =
+            add_mlir(p2,
+                     "main:pointwise0:main:pointwise1:mlir_dot0",
+                     {x, y},
+                     {"x0", "x1"},
+                     [=](auto* pm, const auto& inputs) {
+                         auto relux = pm->add_instruction(migraphx::make_op("relu"), inputs[0]);
+                         auto reluy = pm->add_instruction(migraphx::make_op("relu"), inputs[1]);
+                         auto dot = pm->add_instruction(migraphx::make_op("dot"), relux, reluy);
+                         return std::make_tuple(dot->get_operator(), dot);
+                     });
+        mm->add_return({fused});
+    }
+    EXPECT(p1.sort() == p2.sort());
+}
+
 TEST_CASE(dequantizelinear_dot)
 {
     migraphx::program p1;
@@ -528,6 +563,89 @@
TEST_CASE(dequantizelinear_dot) EXPECT(p1.sort() == p2.sort()); } +TEST_CASE(unsigned_dequantizelinear_dot) +{ + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 3, 5}}); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::uint8_type, {2, 5, 2}}); + auto scalelit = + mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {2, 2, 2}})); + auto zplit = + mm->add_literal(migraphx::generate_literal({migraphx::shape::uint8_type, {2, 2, 2}})); + + auto unsqueeze1 = + mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), scalelit); + auto broadcast1 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3, 2}}}), unsqueeze1); + auto reshape1 = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 6, 2}}}), broadcast1); + auto scale = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {5}}}), reshape1); + + auto unsqueeze2 = + mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), zplit); + auto broadcast2 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3, 2}}}), unsqueeze2); + auto reshape2 = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 6, 2}}}), broadcast2); + auto zp = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {5}}}), reshape2); + + auto dq = add_pointwise( + p1, "main:pointwise0", {y, scale, zp}, single_pointwise("dequantizelinear")); + auto dot = mm->add_instruction(migraphx::make_op("dot"), x, dq); + mm->add_return({dot}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 3, 5}}); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::uint8_type, {2, 5, 2}}); + auto scalelit = + mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {2, 2, 2}})); + auto zplit = + mm->add_literal(migraphx::generate_literal({migraphx::shape::uint8_type, {2, 2, 2}})); + + auto fused = add_mlir( + p2, + "main:pointwise0:mlir_dot0", + {y, scalelit, zplit, x}, + {"x0", "x1", "x2", "x3"}, + [=](auto* pm, const auto& inputs) { + auto unsqueeze1 = + pm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), inputs[1]); + auto broadcast1 = pm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3, 2}}}), unsqueeze1); + auto reshape1 = pm->add_instruction( + migraphx::make_op("reshape", {{"dims", {2, 6, 2}}}), broadcast1); + auto scale = pm->add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {5}}}), + reshape1); + + auto unsqueeze2 = + pm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), inputs[2]); + auto broadcast2 = pm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3, 2}}}), unsqueeze2); + auto reshape2 = pm->add_instruction( + migraphx::make_op("reshape", {{"dims", {2, 6, 2}}}), broadcast2); + auto zp = pm->add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {5}}}), + reshape2); + + auto dq = pm->add_instruction( + migraphx::make_op("dequantizelinear"), inputs[0], scale, zp); + auto dot = pm->add_instruction(migraphx::make_op("dot"), inputs[3], dq); + return std::make_tuple(dot->get_operator(), dot); + }); + mm->add_return({fused}); + } + EXPECT(p1.sort() == p2.sort()); +} + 
 TEST_CASE(unpack_int4_dot)
 {
     migraphx::program p1;
diff --git a/test/gpu/fuse_ops.cpp b/test/gpu/fuse_ops.cpp
index 30c994d3ad1..fc827de0947 100644
--- a/test/gpu/fuse_ops.cpp
+++ b/test/gpu/fuse_ops.cpp
@@ -103,8 +103,7 @@ TEST_CASE(layernorm_pointwise)
     {
         migraphx::program p1 = create_program(false);
         run_pass(p1);
-        migraphx::program p2 = create_fused_program();
-        EXPECT(p1 == p2);
+        EXPECT(p1 == create_program(false));
     }
 }
diff --git a/test/gpu/jit.cpp b/test/gpu/jit.cpp
index 9da7fbdd10d..65abfc03020 100644
--- a/test/gpu/jit.cpp
+++ b/test/gpu/jit.cpp
@@ -141,12 +141,47 @@ const std::string math_template = R"__migraphx__(
 #include <migraphx/kernels/math.hpp>
 namespace migraphx {
+
+template <class T>
+struct test_implicit_conversion_op
+{
+    T x;
+
+    template <class U, index_int N>
+    constexpr operator vec<U, N>() const
+    {
+        if constexpr(vec_size<T>() == 0)
+        {
+            return x;
+        }
+        else
+        {
+            static_assert(vec_size<T>() == N, "Vector mismatch size");
+            return __builtin_convertvector(x, vec<U, N>);
+        }
+    }
+
+    template <class U>
+    constexpr operator U() const
+    {
+        static_assert(is_same<T, U>{} or not is_same<U, bool>{} or is_same<T, bool>{});
+        return static_cast<U>(x);
+    }
+};
+
+template <class T>
+constexpr test_implicit_conversion_op<T> test_implicit_conversion(T x)
+{
+    return {x};
+}
+
+
 extern "C" {
 __global__ void kernel(${type}* p)
 {
     auto x = *p;
-    *p     = migraphx::implicit_conversion(migraphx::${invoke});
-
+    *p     = migraphx::test_implicit_conversion(migraphx::${invoke});
+    (void)(1.f + migraphx::vec_at(migraphx::${invoke}, 0));
 }
 }
 }
@@ -209,11 +244,8 @@ TEST_CASE(compile_warnings)
     EXPECT(not compile({"-Wunused-parameter", "-Wno-error"}).empty());
     EXPECT(not compile({"-Wno-unused-parameter", "-Werror"}).empty());
 #ifdef MIGRAPHX_USE_HIPRTC
-    if(not migraphx::enabled(migraphx::gpu::MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS{}))
-    {
-        EXPECT(test::throws([&] { compile({"-Werror=unused-parameter"}); }));
-        EXPECT(test::throws([&] { compile({"-Wunused-parameter", "-Werror"}); }));
-    }
+    EXPECT(test::throws([&] { compile({"-Werror=unused-parameter"}); }));
+    EXPECT(test::throws([&] { compile({"-Wunused-parameter", "-Werror"}); }));
 #else
     EXPECT(test::throws([&] { compile({"-Werror=unused-parameter"}); }));
     EXPECT(test::throws([&] { compile({"-Wunused-parameter", "-Werror"}); }));
@@ -266,12 +298,13 @@ TEST_CASE(compile_code_object_hip)
 {
     migraphx::shape input{migraphx::shape::float_type, {5, 2}};
     migraphx::gpu::hip_compile_options options;
+    migraphx::gpu::context ctx;
     options.global = 256 * 1024;
     options.local  = 1024;
     options.inputs = {input, input};
     options.output = input;
 
-    auto co = migraphx::gpu::compile_hip_code_object(simple_pointwise_increment, options);
+    auto co = migraphx::gpu::compile_hip_code_object(ctx, simple_pointwise_increment, options);
 
     migraphx::program p;
     auto* mm = p.get_main_module();
@@ -350,7 +383,10 @@ TEST_CASE(compile_math)
     auto vec_sizes = {2, 4, 6};
     for(auto&& t : migraphx::shape::types())
     {
-        if(contains({migraphx::shape::bool_type, migraphx::shape::tuple_type}, t))
+        if(contains({migraphx::shape::bool_type,
+                     migraphx::shape::tuple_type,
+                     migraphx::shape::bf16_type},
+                    t))
             continue;
         auto name = migraphx::shape::cpp_type(t);
         if(t == migraphx::shape::half_type)
@@ -358,8 +394,9 @@ TEST_CASE(compile_math)
             data_types.push_back(name);
         // fp8 doesn't have vectorization support yet, therefore skip it for now.
         std::set<migraphx::shape::type_t> fp8_types = {migraphx::shape::fp8e4m3fnuz_type,
-                                                       migraphx::shape::fp8e5m2_type,
-                                                       migraphx::shape::fp8e4m3fn_type};
+                                                       migraphx::shape::fp8e5m2fnuz_type,
+                                                       migraphx::shape::fp8e4m3fn_type,
+                                                       migraphx::shape::fp8e5m2_type};
         if(not contains(fp8_types, t))
         {
             migraphx::transform(vec_sizes, std::back_inserter(data_types), [&](auto i) {
@@ -369,6 +406,7 @@ TEST_CASE(compile_math)
     }
     migraphx::shape input{migraphx::shape::float_type, {5, 2}};
     migraphx::gpu::hip_compile_options options;
+    migraphx::gpu::context ctx;
     options.global = 1024;
     options.local  = 1024;
     options.inputs = {input};
@@ -377,7 +415,7 @@ TEST_CASE(compile_math)
         const auto& t      = data_types[i % data_types.size()];
         const auto& invoke = math_invoke[i / data_types.size()];
         auto src = migraphx::interpolate_string(math_template, {{"type", t}, {"invoke", invoke}});
-        auto co  = migraphx::gpu::compile_hip_code_object(src, options);
+        auto co  = migraphx::gpu::compile_hip_code_object(ctx, src, options);
         (void)co;
     });
 }
@@ -403,11 +441,12 @@ TEST_CASE(assert_type_min_max)
 {
     std::vector<std::string> data_types;
     migraphx::gpu::hip_compile_options options;
+    migraphx::gpu::context ctx;
     for(auto&& t : migraphx::shape::types())
     {
         if(contains({migraphx::shape::bool_type,
-                     migraphx::shape::fp8e4m3fnuz_type,
-                     migraphx::shape::tuple_type},
+                     migraphx::shape::tuple_type,
+                     migraphx::shape::bf16_type},
                     t))
             continue;
         auto name = migraphx::shape::cpp_type(t);
@@ -444,7 +483,7 @@ TEST_CASE(assert_type_min_max)
     options.output = input;
     options.emplace_param("-Wno-float-equal");
 
-    auto co = migraphx::gpu::compile_hip_code_object(src, options);
+    auto co = migraphx::gpu::compile_hip_code_object(ctx, src, options);
     });
 }
 }
diff --git a/test/gpu/mlir.cpp b/test/gpu/mlir.cpp
index 5453f60004a..108bb941fd8 100644
--- a/test/gpu/mlir.cpp
+++ b/test/gpu/mlir.cpp
@@ -308,10 +308,10 @@ TEST_CASE(quant_dot_add)
 {
     std::string mlir_output = R"__migraphx__(
 module {
-  func.func @mlir_quant_dot_add(%arg0: !migraphx.shaped<1x5x4xi8, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xi8, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xi32, 15x3x1>) -> !migraphx.shaped<1x5x3xi32, 15x3x1> attributes ${attrs} {
-    %0 = migraphx.quant_dot %arg0, %arg1 : <1x5x4xi8, 20x4x1>, <1x4x3xi8, 12x3x1> -> <1x5x3xi32, 15x3x1>
-    %1 = migraphx.add %0, %arg2 : <1x5x3xi32, 15x3x1>, <1x5x3xi32, 15x3x1> -> <1x5x3xi32, 15x3x1>
-    return %1 : !migraphx.shaped<1x5x3xi32, 15x3x1>
+  func.func @mlir_quant_dot_add(%arg0: !migraphx.shaped<1x5x4xsi8, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xsi8, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xsi32, 15x3x1>) -> !migraphx.shaped<1x5x3xsi32, 15x3x1> attributes ${attrs} {
+    %0 = migraphx.quant_dot %arg0, %arg1 : <1x5x4xsi8, 20x4x1>, <1x4x3xsi8, 12x3x1> -> <1x5x3xsi32, 15x3x1>
+    %1 = migraphx.add %0, %arg2 : <1x5x3xsi32, 15x3x1>, <1x5x3xsi32, 15x3x1> -> <1x5x3xsi32, 15x3x1>
+    return %1 : !migraphx.shaped<1x5x3xsi32, 15x3x1>
   }
 }
 )__migraphx__";
@@ -395,11 +395,11 @@ TEST_CASE(conv_int8_dequantize_quantize)
 {
     std::string mlir_output = R"__migraphx__(
 module {
-  func.func @mlir_quant_convolution_dequantizelinear_quantizelinear(%arg0: !migraphx.shaped<2x8x3x3xi8, 72x9x3x1>, %arg1: !migraphx.shaped<1x8x4x4xi8, 128x16x4x1>, %arg2: !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>, %arg3: !migraphx.shaped<1x2x2x2xi32, 8x4x2x1>) -> !migraphx.shaped<1x2x2x2xi32, 8x4x2x1> attributes ${attrs} {
-    %0 = migraphx.quant_convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xi8, 128x16x4x1>, <2x8x3x3xi8, 72x9x3x1> -> <1x2x2x2xi32, 8x4x2x1>
-
%1 = migraphx.dequantizelinear %0, %arg2, %arg3 : <1x2x2x2xi32, 8x4x2x1>, <1x2x2x2xf32, 8x4x2x1>, !migraphx.shaped<1x2x2x2xi32, 8x4x2x1> -> <1x2x2x2xf32, 8x4x2x1> - %2 = migraphx.quantizelinear %1, %arg2, %arg3 : <1x2x2x2xf32, 8x4x2x1>, <1x2x2x2xf32, 8x4x2x1>, !migraphx.shaped<1x2x2x2xi32, 8x4x2x1> -> <1x2x2x2xi32, 8x4x2x1> - return %2 : !migraphx.shaped<1x2x2x2xi32, 8x4x2x1> + func.func @mlir_quant_convolution_dequantizelinear_quantizelinear(%arg0: !migraphx.shaped<2x8x3x3xsi8, 72x9x3x1>, %arg1: !migraphx.shaped<1x8x4x4xsi8, 128x16x4x1>, %arg2: !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>, %arg3: !migraphx.shaped<1x2x2x2xsi32, 8x4x2x1>) -> !migraphx.shaped<1x2x2x2xsi32, 8x4x2x1> attributes ${attrs} { + %0 = migraphx.quant_convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xsi8, 128x16x4x1>, <2x8x3x3xsi8, 72x9x3x1> -> <1x2x2x2xsi32, 8x4x2x1> + %1 = migraphx.dequantizelinear %0, %arg2, %arg3 : <1x2x2x2xsi32, 8x4x2x1>, <1x2x2x2xf32, 8x4x2x1>, !migraphx.shaped<1x2x2x2xsi32, 8x4x2x1> -> <1x2x2x2xf32, 8x4x2x1> + %2 = migraphx.quantizelinear %1, %arg2, %arg3 : <1x2x2x2xf32, 8x4x2x1>, <1x2x2x2xf32, 8x4x2x1>, !migraphx.shaped<1x2x2x2xsi32, 8x4x2x1> -> <1x2x2x2xsi32, 8x4x2x1> + return %2 : !migraphx.shaped<1x2x2x2xsi32, 8x4x2x1> } } )__migraphx__"; @@ -458,9 +458,9 @@ TEST_CASE(dot_where) { std::string mlir_output = R"__migraphx__( module { - func.func @mlir_dot_where(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xi8, 15x3x1>, %arg3: !migraphx.shaped<1x5x3xf32, 15x3x1>) -> !migraphx.shaped<1x5x3xf32, 15x3x1> attributes ${attrs} { + func.func @mlir_dot_where(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xsi8, 15x3x1>, %arg3: !migraphx.shaped<1x5x3xf32, 15x3x1>) -> !migraphx.shaped<1x5x3xf32, 15x3x1> attributes ${attrs} { %0 = migraphx.dot %arg0, %arg1 : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1> - %1 = migraphx.where %arg2, %0, %arg3 : <1x5x3xi8, 15x3x1>, <1x5x3xf32, 15x3x1>, <1x5x3xf32, 15x3x1> -> <1x5x3xf32, 15x3x1> + %1 = migraphx.where %arg2, %0, %arg3 : <1x5x3xsi8, 15x3x1>, <1x5x3xf32, 15x3x1>, <1x5x3xf32, 15x3x1> -> <1x5x3xf32, 15x3x1> return %1 : !migraphx.shaped<1x5x3xf32, 15x3x1> } } @@ -487,11 +487,11 @@ module { TEST_CASE(int4_unpack_ir) { std::string mlir_output = R"__migraphx__( -module { - func.func @mlir_unpack_int4(%arg0: !migraphx.shaped<2x1xi8, 1x1>) -> !migraphx.shaped<2x2xi8, 2x1> attributes ${attrs} { - %0 = migraphx.unpack %arg0 {axis = 1 : i64, isUnsigned = false} : <2x1xi8, 1x1> -> <2x2xi8, 2x1> - return %0 : !migraphx.shaped<2x2xi8, 2x1> - } +module { + func.func @mlir_unpack_int4(%arg0: !migraphx.shaped<2x1xsi8, 1x1>) -> !migraphx.shaped<2x2xsi8, 2x1> attributes ${attrs} { + %0 = migraphx.unpack %arg0 {axis = 1 : i64} : <2x1xsi8, 1x1> -> <2x2xsi8, 2x1> + return %0 : !migraphx.shaped<2x2xsi8, 2x1> + } } )__migraphx__"; migraphx::module m; @@ -513,12 +513,12 @@ module { TEST_CASE(int4_unpack_conv) { std::string mlir_output = R"__migraphx__( - module { - func.func @mlir_unpack_int4_quant_convolution(%arg0: !migraphx.shaped<2x8x2x1xi8, 16x2x1x1>, %arg1: !migraphx.shaped<1x8x4x4xi8, 128x16x4x1>) -> !migraphx.shaped<1x2x3x3xi32, 18x9x3x1> attributes ${attrs} { - %0 = migraphx.unpack %arg0 {axis = 3 : i64, isUnsigned = false} : <2x8x2x1xi8, 16x2x1x1> -> <2x8x2x2xi8, 32x4x2x1> - %1 = migraphx.quant_convolution %arg1, %0 {dilation = [1, 1], 
group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xi8, 128x16x4x1>, <2x8x2x2xi8, 32x4x2x1> -> <1x2x3x3xi32, 18x9x3x1> - return %1 : !migraphx.shaped<1x2x3x3xi32, 18x9x3x1> - } +module { + func.func @mlir_unpack_int4_quant_convolution(%arg0: !migraphx.shaped<2x8x2x1xsi8, 16x2x1x1>, %arg1: !migraphx.shaped<1x8x4x4xsi8, 128x16x4x1>) -> !migraphx.shaped<1x2x3x3xsi32, 18x9x3x1> attributes ${attrs} { + %0 = migraphx.unpack %arg0 {axis = 3 : i64} : <2x8x2x1xsi8, 16x2x1x1> -> <2x8x2x2xsi8, 32x4x2x1> + %1 = migraphx.quant_convolution %arg1, %0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xsi8, 128x16x4x1>, <2x8x2x2xsi8, 32x4x2x1> -> <1x2x3x3xsi32, 18x9x3x1> + return %1 : !migraphx.shaped<1x2x3x3xsi32, 18x9x3x1> + } } )__migraphx__"; migraphx::module m; @@ -537,4 +537,116 @@ TEST_CASE(int4_unpack_conv) EXPECT(verify_mlir(m)); } +TEST_CASE(int4_unpack_dequantizelinear) +{ + std::string mlir_output = R"__migraphx__( +module { + func.func @mlir_unsqueeze_reshape_slice_unsqueeze_reshape_slice_unpack_int4_dequantizelinear_dot(%arg0: !migraphx.shaped<2x3x5xf32, 15x5x1>, %arg1: !migraphx.shaped<2x5x1xsi8, 5x1x1>, %arg2: !migraphx.shaped<2x2x2xf32, 4x2x1>, %arg3: !migraphx.shaped<2x2x2xsi8, 4x2x1>) -> !migraphx.shaped<2x3x2xf32, 6x2x1> attributes ${attrs} { + %0 = migraphx.reshape %arg2 {dims = [2, 2, 1, 2]} : <2x2x2xf32, 4x2x1> -> <2x2x1x2xf32, 4x2x2x1> + %1 = migraphx.multibroadcast %0 {out_dyn_dims = [], out_lens = [2, 2, 3, 2]} : <2x2x1x2xf32, 4x2x2x1> -> <2x2x3x2xf32, 4x2x0x1> + %2 = migraphx.reshape %1 {dims = [2, 6, 2]} : <2x2x3x2xf32, 4x2x0x1> -> <2x6x2xf32, 12x2x1> + %3 = migraphx.slice %2 {axes = [1], ends = [5], starts = [0]} : <2x6x2xf32, 12x2x1> -> <2x5x2xf32, 12x2x1> + %4 = migraphx.reshape %arg3 {dims = [2, 2, 1, 2]} : <2x2x2xsi8, 4x2x1> -> <2x2x1x2xsi8, 4x2x2x1> + %5 = migraphx.multibroadcast %4 {out_dyn_dims = [], out_lens = [2, 2, 3, 2]} : <2x2x1x2xsi8, 4x2x2x1> -> <2x2x3x2xsi8, 4x2x0x1> + %6 = migraphx.reshape %5 {dims = [2, 6, 2]} : <2x2x3x2xsi8, 4x2x0x1> -> <2x6x2xsi8, 12x2x1> + %7 = migraphx.slice %6 {axes = [1], ends = [5], starts = [0]} : <2x6x2xsi8, 12x2x1> -> <2x5x2xsi8, 12x2x1> + %8 = migraphx.unpack %arg1 {axis = 2 : i64} : <2x5x1xsi8, 5x1x1> -> <2x5x2xsi8, 10x2x1> + %9 = migraphx.dequantizelinear %8, %3, %7 : <2x5x2xsi8, 10x2x1>, <2x5x2xf32, 12x2x1>, !migraphx.shaped<2x5x2xsi8, 12x2x1> -> <2x5x2xf32, 10x2x1> + %10 = migraphx.dot %arg0, %9 : <2x3x5xf32, 15x5x1>, <2x5x2xf32, 10x2x1> -> <2x3x2xf32, 6x2x1> + return %10 : !migraphx.shaped<2x3x2xf32, 6x2x1> + } +} +)__migraphx__"; + migraphx::module m; + auto x0 = m.add_parameter("x0", migraphx::shape{migraphx::shape::float_type, {2, 3, 5}}); + auto x1 = m.add_parameter("x1", migraphx::shape{migraphx::shape::int8_type, {2, 5, 1}}); + auto x2 = m.add_parameter("x2", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}}); + auto x3 = m.add_parameter("x3", migraphx::shape{migraphx::shape::int8_type, {2, 2, 2}}); + + auto unsqueeze1 = m.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), x2); + auto broadcast1 = m.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3, 2}}}), unsqueeze1); + auto reshape1 = + m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 6, 2}}}), broadcast1); + auto scale = m.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {5}}}), reshape1); + + auto unsqueeze2 = m.add_instruction(migraphx::make_op("unsqueeze", {{"axes", 
{2}}}), x3); + auto broadcast2 = m.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3, 2}}}), unsqueeze2); + auto reshape2 = + m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 6, 2}}}), broadcast2); + auto zp = m.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {5}}}), reshape2); + + auto unpack = m.add_instruction(migraphx::make_op("unpack_int4"), x1); + auto dq = m.add_instruction(migraphx::make_op("dequantizelinear"), unpack, scale, zp); + auto dot = m.add_instruction(migraphx::make_op("dot"), x0, dq); + m.add_return({dot}); + auto s = migraphx::gpu::dump_mlir(m); + // Skip test if MLIR is not enabled + if(s.empty()) + return; + auto mlir_output_with_attrs = + migraphx::interpolate_string(mlir_output, {{"attrs", get_attrs()}}); + CHECK(encode(s) == encode(mlir_output_with_attrs)); + EXPECT(verify_mlir(m)); +} + +TEST_CASE(uint4_unpack_dequantizelinear) +{ + std::string mlir_output = R"__migraphx__( +module { + func.func @mlir_unsqueeze_reshape_slice_unsqueeze_reshape_slice_unpack_int4_dequantizelinear_dot(%arg0: !migraphx.shaped<2x3x5xf32, 15x5x1>, %arg1: !migraphx.shaped<2x5x1xui8, 5x1x1>, %arg2: !migraphx.shaped<2x2x2xf32, 4x2x1>, %arg3: !migraphx.shaped<2x2x2xui8, 4x2x1>) -> !migraphx.shaped<2x3x2xf32, 6x2x1> attributes ${attrs} { + %0 = migraphx.reshape %arg2 {dims = [2, 2, 1, 2]} : <2x2x2xf32, 4x2x1> -> <2x2x1x2xf32, 4x2x2x1> + %1 = migraphx.multibroadcast %0 {out_dyn_dims = [], out_lens = [2, 2, 3, 2]} : <2x2x1x2xf32, 4x2x2x1> -> <2x2x3x2xf32, 4x2x0x1> + %2 = migraphx.reshape %1 {dims = [2, 6, 2]} : <2x2x3x2xf32, 4x2x0x1> -> <2x6x2xf32, 12x2x1> + %3 = migraphx.slice %2 {axes = [1], ends = [5], starts = [0]} : <2x6x2xf32, 12x2x1> -> <2x5x2xf32, 12x2x1> + %4 = migraphx.reshape %arg3 {dims = [2, 2, 1, 2]} : <2x2x2xui8, 4x2x1> -> <2x2x1x2xui8, 4x2x2x1> + %5 = migraphx.multibroadcast %4 {out_dyn_dims = [], out_lens = [2, 2, 3, 2]} : <2x2x1x2xui8, 4x2x2x1> -> <2x2x3x2xui8, 4x2x0x1> + %6 = migraphx.reshape %5 {dims = [2, 6, 2]} : <2x2x3x2xui8, 4x2x0x1> -> <2x6x2xui8, 12x2x1> + %7 = migraphx.slice %6 {axes = [1], ends = [5], starts = [0]} : <2x6x2xui8, 12x2x1> -> <2x5x2xui8, 12x2x1> + %8 = migraphx.unpack %arg1 {axis = 2 : i64} : <2x5x1xui8, 5x1x1> -> <2x5x2xui8, 10x2x1> + %9 = migraphx.dequantizelinear %8, %3, %7 : <2x5x2xui8, 10x2x1>, <2x5x2xf32, 12x2x1>, !migraphx.shaped<2x5x2xui8, 12x2x1> -> <2x5x2xf32, 10x2x1> + %10 = migraphx.dot %arg0, %9 : <2x3x5xf32, 15x5x1>, <2x5x2xf32, 10x2x1> -> <2x3x2xf32, 6x2x1> + return %10 : !migraphx.shaped<2x3x2xf32, 6x2x1> + } +} +)__migraphx__"; + migraphx::module m; + auto x0 = m.add_parameter("x0", migraphx::shape{migraphx::shape::float_type, {2, 3, 5}}); + auto x1 = m.add_parameter("x1", migraphx::shape{migraphx::shape::uint8_type, {2, 5, 1}}); + auto x2 = m.add_parameter("x2", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}}); + auto x3 = m.add_parameter("x3", migraphx::shape{migraphx::shape::uint8_type, {2, 2, 2}}); + + auto unsqueeze1 = m.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), x2); + auto broadcast1 = m.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3, 2}}}), unsqueeze1); + auto reshape1 = + m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 6, 2}}}), broadcast1); + auto scale = m.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {5}}}), reshape1); + + auto unsqueeze2 = m.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), x3); + auto broadcast2 = 
m.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3, 2}}}), unsqueeze2); + auto reshape2 = + m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 6, 2}}}), broadcast2); + auto zp = m.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {5}}}), reshape2); + + auto unpack = m.add_instruction(migraphx::make_op("unpack_int4"), x1); + auto dq = m.add_instruction(migraphx::make_op("dequantizelinear"), unpack, scale, zp); + auto dot = m.add_instruction(migraphx::make_op("dot"), x0, dq); + m.add_return({dot}); + auto s = migraphx::gpu::dump_mlir(m); + // Skip test if MLIR is not enabled + if(s.empty()) + return; + auto mlir_output_with_attrs = + migraphx::interpolate_string(mlir_output, {{"attrs", get_attrs()}}); + CHECK(encode(s) == encode(mlir_output_with_attrs)); + EXPECT(verify_mlir(m)); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/half.cpp b/test/half.cpp new file mode 100644 index 00000000000..6b0a5f330a4 --- /dev/null +++ b/test/half.cpp @@ -0,0 +1,1243 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */
+#include <migraphx/half.hpp>
+#include <migraphx/bit_cast.hpp>
+#include <array>
+#include <limits>
+#include "test.hpp"
+
+#include <cmath>
+#include <cstdint>
+#include <map>
+
+template <class T, class U>
+bool bit_equal(const T& x, const U& y)
+{
+    static_assert(sizeof(T) == sizeof(U));
+    using type = std::array<char, sizeof(T)>;
+    return migraphx::bit_cast<type>(x) == migraphx::bit_cast<type>(y);
+}
+
+TEST_CASE(check_numeric_limits)
+{
+    CHECK(bit_equal(std::numeric_limits<migraphx::half>::min(), uint16_t{0x0400}));
+    CHECK(bit_equal(std::numeric_limits<migraphx::half>::lowest(), uint16_t{0xfbff}));
+    CHECK(bit_equal(std::numeric_limits<migraphx::half>::max(), uint16_t{0x7bff}));
+    CHECK(bit_equal(std::numeric_limits<migraphx::half>::epsilon(), uint16_t{0x1400}));
+    CHECK(bit_equal(std::numeric_limits<migraphx::half>::denorm_min(), uint16_t{0x0001}));
+    CHECK(bit_equal(std::numeric_limits<migraphx::half>::infinity(), uint16_t{0x7c00}));
+    CHECK(bit_equal(std::numeric_limits<migraphx::half>::quiet_NaN(), uint16_t{0x7e00}));
+    CHECK(bit_equal(std::numeric_limits<migraphx::half>::signaling_NaN(), uint16_t{0x7d00}));
+}
+
+const std::map<uint16_t, double>& half_lut() // NOLINT(readability-function-size)
+{
+    static const std::map<uint16_t, double> result = {
+        {0x0000, 0},
+        {0x0058, 0.0000052452087402},
+        {0x0079, 0.0000072121620178},
+        {0x0097, 0.0000090003013611},
+        {0x009e, 0.0000094175338745},
+        {0x0125, 0.0000174641609192},
+        {0x0167, 0.0000213980674744},
+        {0x0196, 0.0000241994857788},
+        {0x01c4, 0.0000269412994385},
+        {0x01c8, 0.0000271797180176},
+        {0x0236, 0.0000337362289429},
+        {0x029f, 0.0000399947166443},
+        {0x02bf, 0.0000419020652771},
+        {0x02d6, 0.0000432729721069},
+        {0x03a6, 0.0000556707382202},
+        {0x03b7, 0.0000566840171814},
+        {0x03d4, 0.0000584125518799},
+        {0x03d8, 0.000058650970459},
+        {0x03ed, 0.0000599026679993},
+        {0x0427, 0.0000633597373962},
+        {0x0430, 0.0000638961791992},
+        {0x0435, 0.0000641942024231},
+        {0x0454, 0.0000660419464111},
+        {0x047a, 0.0000683069229126},
+        {0x04b6, 0.0000718832015991},
+        {0x056a, 0.0000826120376587},
+        {0x056f, 0.0000829100608826},
+        {0x0584, 0.0000841617584229},
+        {0x05a1, 0.0000858902931213},
+        {0x05a4, 0.0000860691070557},
+        {0x05b8, 0.0000872611999512},
+        {0x05bc, 0.0000874996185303},
+        {0x0635, 0.0000947117805481},
+        {0x0641, 0.0000954270362854},
+        {0x0686, 0.0000995397567749},
+        {0x0694, 0.0001003742218018},
+        {0x06db, 0.0001046061515808},
+        {0x0725, 0.0001090168952942},
+        {0x0777, 0.0001139044761658},
+        {0x07b2, 0.0001174211502075},
+        {0x0812, 0.0001242160797119},
+        {0x082e, 0.0001275539398193},
+        {0x0859, 0.00013267993927},
+        {0x0895, 0.0001398324966431},
+        {0x08af, 0.0001429319381714},
+        {0x08fc, 0.0001521110534668},
+        {0x092e, 0.0001580715179443},
+        {0x0971, 0.0001660585403442},
+        {0x0991, 0.0001698732376099},
+        {0x09ca, 0.0001766681671143},
+        {0x0a63, 0.0001949071884155},
+        {0x0a8e, 0.0002000331878662},
+        {0x0a93, 0.000200629234314},
+        {0x0b2a, 0.0002186298370361},
+        {0x0b3a, 0.0002205371856689},
+        {0x0b3c, 0.000220775604248},
+        {0x0b4e, 0.00022292137146},
+        {0x0bae, 0.0002343654632568},
+        {0x0bff, 0.0002440214157104},
+        {0x0c08, 0.0002460479736328},
+        {0x0c56, 0.0002646446228027},
+        {0x0c61, 0.0002672672271729},
+        {0x0c70, 0.0002708435058594},
+        {0x0c7c, 0.0002737045288086},
+        {0x0cd8, 0.0002956390380859},
+        {0x0cdd, 0.0002968311309814},
+        {0x0d05, 0.0003063678741455},
+        {0x0d61, 0.0003283023834229},
+        {0x0d85, 0.0003368854522705},
+        {0x0d8c, 0.0003385543823242},
+        {0x0d90, 0.0003395080566406},
+        {0x0d9e, 0.000342845916748},
+        {0x0da5, 0.0003445148468018},
+        {0x0dda, 0.0003571510314941},
+        {0x0dde, 0.0003581047058105},
+        {0x0df6, 0.000363826751709},
+        {0x0eec, 0.000422477722168},
+        {0x0f1c, 0.0004339218139648},
+        {0x0f99, 0.0004637241363525},
+        {0x0fac, 0.0004682540893555},
+        {0x0fb0, 0.0004692077636719},
+ {0x0ff5, 0.0004856586456299}, + {0x107f, 0.0005488395690918}, + {0x1096, 0.0005598068237305}, + {0x10c8, 0.0005836486816406}, + {0x10e9, 0.0005993843078613}, + {0x110a, 0.000615119934082}, + {0x118a, 0.000676155090332}, + {0x11b5, 0.0006966590881348}, + {0x1293, 0.0008025169372559}, + {0x133f, 0.0008845329284668}, + {0x1342, 0.0008859634399414}, + {0x1372, 0.0009088516235352}, + {0x13cf, 0.000953197479248}, + {0x140c, 0.0009880065917969}, + {0x1437, 0.0010290145874023}, + {0x14a3, 0.0011320114135742}, + {0x14a6, 0.0011348724365234}, + {0x14b2, 0.0011463165283203}, + {0x14ba, 0.0011539459228516}, + {0x14d9, 0.0011835098266602}, + {0x14da, 0.0011844635009766}, + {0x14e7, 0.0011968612670898}, + {0x14fe, 0.0012187957763672}, + {0x1521, 0.0012521743774414}, + {0x153d, 0.0012788772583008}, + {0x15ad, 0.0013856887817383}, + {0x15fd, 0.0014619827270508}, + {0x1649, 0.0015344619750977}, + {0x1658, 0.0015487670898438}, + {0x168a, 0.0015964508056641}, + {0x169d, 0.0016145706176758}, + {0x16b3, 0.0016355514526367}, + {0x16c9, 0.0016565322875977}, + {0x16d1, 0.0016641616821289}, + {0x16e0, 0.001678466796875}, + {0x170a, 0.0017185211181641}, + {0x176d, 0.0018129348754883}, + {0x185b, 0.0021266937255859}, + {0x185e, 0.0021324157714844}, + {0x187e, 0.0021934509277344}, + {0x18ca, 0.0023384094238281}, + {0x18e9, 0.0023975372314453}, + {0x1901, 0.0024433135986328}, + {0x191e, 0.0024986267089844}, + {0x1963, 0.0026302337646484}, + {0x199f, 0.0027446746826172}, + {0x19b2, 0.0027809143066406}, + {0x19d4, 0.0028457641601562}, + {0x1a31, 0.0030231475830078}, + {0x1a4a, 0.0030708312988281}, + {0x1a7a, 0.0031623840332031}, + {0x1ace, 0.0033226013183594}, + {0x1b03, 0.0034236907958984}, + {0x1b22, 0.0034828186035156}, + {0x1d49, 0.0051612854003906}, + {0x1d5a, 0.0052261352539062}, + {0x1d6c, 0.0052947998046875}, + {0x1e02, 0.0058670043945312}, + {0x1e19, 0.0059547424316406}, + {0x1e4c, 0.0061492919921875}, + {0x1eb3, 0.0065422058105469}, + {0x1f32, 0.0070266723632812}, + {0x1f36, 0.0070419311523438}, + {0x1f41, 0.0070838928222656}, + {0x1f7a, 0.0073013305664062}, + {0x1f8d, 0.0073738098144531}, + {0x200b, 0.0078964233398438}, + {0x205f, 0.0085372924804688}, + {0x2060, 0.008544921875}, + {0x2067, 0.0085983276367188}, + {0x20e2, 0.0095367431640625}, + {0x2164, 0.010528564453125}, + {0x22a4, 0.012969970703125}, + {0x22b4, 0.013092041015625}, + {0x22f2, 0.0135650634765625}, + {0x230c, 0.013763427734375}, + {0x2314, 0.013824462890625}, + {0x2341, 0.0141677856445312}, + {0x2356, 0.0143280029296875}, + {0x236e, 0.0145111083984375}, + {0x2371, 0.0145339965820312}, + {0x23cd, 0.0152359008789062}, + {0x2405, 0.0157012939453125}, + {0x24a2, 0.018096923828125}, + {0x24ba, 0.018463134765625}, + {0x24e7, 0.0191497802734375}, + {0x266c, 0.02508544921875}, + {0x26a2, 0.025909423828125}, + {0x26cc, 0.02655029296875}, + {0x26f0, 0.027099609375}, + {0x271e, 0.027801513671875}, + {0x2798, 0.0296630859375}, + {0x287d, 0.035064697265625}, + {0x28a2, 0.03619384765625}, + {0x28ca, 0.03741455078125}, + {0x2933, 0.040618896484375}, + {0x298d, 0.043365478515625}, + {0x299e, 0.04388427734375}, + {0x29c0, 0.044921875}, + {0x29c2, 0.04498291015625}, + {0x29cf, 0.045379638671875}, + {0x29fa, 0.04669189453125}, + {0x2a06, 0.04705810546875}, + {0x2aa5, 0.051910400390625}, + {0x2bcb, 0.060882568359375}, + {0x2c18, 0.06396484375}, + {0x2c65, 0.06866455078125}, + {0x2c66, 0.0687255859375}, + {0x2c93, 0.07147216796875}, + {0x2d24, 0.080322265625}, + {0x2d35, 0.08135986328125}, + {0x2d4c, 0.082763671875}, + {0x2db7, 0.08929443359375}, + {0x2dec, 
0.092529296875}, + {0x2e31, 0.09674072265625}, + {0x2ec9, 0.10601806640625}, + {0x2f85, 0.11749267578125}, + {0x2f94, 0.118408203125}, + {0x302b, 0.1302490234375}, + {0x3094, 0.14306640625}, + {0x3096, 0.143310546875}, + {0x30ae, 0.146240234375}, + {0x30b9, 0.1475830078125}, + {0x310c, 0.15771484375}, + {0x31bd, 0.1793212890625}, + {0x3213, 0.1898193359375}, + {0x325b, 0.1986083984375}, + {0x32aa, 0.208251953125}, + {0x32c0, 0.2109375}, + {0x32d7, 0.2137451171875}, + {0x3391, 0.2364501953125}, + {0x340d, 0.253173828125}, + {0x343d, 0.264892578125}, + {0x3566, 0.33740234375}, + {0x35e6, 0.36865234375}, + {0x35f4, 0.3720703125}, + {0x363b, 0.389404296875}, + {0x363e, 0.39013671875}, + {0x3650, 0.39453125}, + {0x3698, 0.412109375}, + {0x36e7, 0.431396484375}, + {0x36fe, 0.43701171875}, + {0x374a, 0.45556640625}, + {0x3760, 0.4609375}, + {0x3761, 0.461181640625}, + {0x379e, 0.47607421875}, + {0x37cc, 0.4873046875}, + {0x37fd, 0.499267578125}, + {0x3828, 0.51953125}, + {0x3841, 0.53173828125}, + {0x3877, 0.55810546875}, + {0x38a4, 0.580078125}, + {0x38d3, 0.60302734375}, + {0x39b2, 0.7119140625}, + {0x3a60, 0.796875}, + {0x3aa3, 0.82958984375}, + {0x3aa6, 0.8310546875}, + {0x3ac9, 0.84814453125}, + {0x3acf, 0.85107421875}, + {0x3b14, 0.884765625}, + {0x3b42, 0.9072265625}, + {0x3b5c, 0.919921875}, + {0x3bde, 0.9833984375}, + {0x3c67, 1.1005859375}, + {0x3cb5, 1.1767578125}, + {0x3cca, 1.197265625}, + {0x3cdd, 1.2158203125}, + {0x3cfc, 1.24609375}, + {0x3d1f, 1.2802734375}, + {0x3e0c, 1.51171875}, + {0x3e1c, 1.52734375}, + {0x3e5b, 1.5888671875}, + {0x3e7f, 1.6240234375}, + {0x3eae, 1.669921875}, + {0x3efe, 1.748046875}, + {0x3f3e, 1.810546875}, + {0x3f9d, 1.9033203125}, + {0x400a, 2.01953125}, + {0x4070, 2.21875}, + {0x40a0, 2.3125}, + {0x40ce, 2.40234375}, + {0x40e6, 2.44921875}, + {0x410e, 2.52734375}, + {0x4129, 2.580078125}, + {0x4144, 2.6328125}, + {0x41a4, 2.8203125}, + {0x41f3, 2.974609375}, + {0x42f1, 3.470703125}, + {0x438f, 3.779296875}, + {0x43b0, 3.84375}, + {0x43c3, 3.880859375}, + {0x43de, 3.93359375}, + {0x4483, 4.51171875}, + {0x44f8, 4.96875}, + {0x4505, 5.01953125}, + {0x45dd, 5.86328125}, + {0x45f3, 5.94921875}, + {0x460e, 6.0546875}, + {0x46ce, 6.8046875}, + {0x4704, 7.015625}, + {0x471a, 7.1015625}, + {0x475e, 7.3671875}, + {0x4761, 7.37890625}, + {0x479f, 7.62109375}, + {0x47ca, 7.7890625}, + {0x47db, 7.85546875}, + {0x47fc, 7.984375}, + {0x481e, 8.234375}, + {0x4839, 8.4453125}, + {0x483d, 8.4765625}, + {0x48ac, 9.34375}, + {0x48da, 9.703125}, + {0x4919, 10.1953125}, + {0x4950, 10.625}, + {0x4987, 11.0546875}, + {0x49bb, 11.4609375}, + {0x4a14, 12.15625}, + {0x4a92, 13.140625}, + {0x4b25, 14.2890625}, + {0x4b81, 15.0078125}, + {0x4b99, 15.1953125}, + {0x4bbe, 15.484375}, + {0x4bf8, 15.9375}, + {0x4c1f, 16.484375}, + {0x4c49, 17.140625}, + {0x4d21, 20.515625}, + {0x4d4a, 21.15625}, + {0x4d51, 21.265625}, + {0x4de2, 23.53125}, + {0x4e05, 24.078125}, + {0x4ea3, 26.546875}, + {0x4eb0, 26.75}, + {0x4f0e, 28.21875}, + {0x4f4a, 29.15625}, + {0x4f6b, 29.671875}, + {0x4fa6, 30.59375}, + {0x4fae, 30.71875}, + {0x4ff6, 31.84375}, + {0x503c, 33.875}, + {0x50e4, 39.125}, + {0x514e, 42.4375}, + {0x516b, 43.34375}, + {0x51d3, 46.59375}, + {0x5213, 48.59375}, + {0x526e, 51.4375}, + {0x52a6, 53.1875}, + {0x52b4, 53.625}, + {0x52b6, 53.6875}, + {0x52bc, 53.875}, + {0x5300, 56}, + {0x5389, 60.28125}, + {0x5406, 64.375}, + {0x5498, 73.5}, + {0x54bd, 75.8125}, + {0x54cf, 76.9375}, + {0x5502, 80.125}, + {0x558e, 88.875}, + {0x5597, 89.4375}, + {0x55eb, 94.6875}, + {0x55f6, 95.375}, + {0x5629, 
98.5625}, + {0x562b, 98.6875}, + {0x5635, 99.3125}, + {0x564e, 100.875}, + {0x5671, 103.0625}, + {0x5681, 104.0625}, + {0x56d1, 109.0625}, + {0x571c, 113.75}, + {0x5756, 117.375}, + {0x5790, 121}, + {0x57fd, 127.8125}, + {0x582d, 133.625}, + {0x5869, 141.125}, + {0x58ab, 149.375}, + {0x58ad, 149.625}, + {0x58c9, 153.125}, + {0x58f7, 158.875}, + {0x5904, 160.5}, + {0x59c2, 184.25}, + {0x59e6, 188.75}, + {0x5a88, 209}, + {0x5ada, 219.25}, + {0x5aef, 221.875}, + {0x5af5, 222.625}, + {0x5b7f, 239.875}, + {0x5ba4, 244.5}, + {0x5c08, 258}, + {0x5cbf, 303.75}, + {0x5d4d, 339.25}, + {0x5dc2, 368.5}, + {0x5dc4, 369}, + {0x5e31, 396.25}, + {0x5e38, 398}, + {0x5e7c, 415}, + {0x5e8d, 419.25}, + {0x5ead, 427.25}, + {0x5eb4, 429}, + {0x5ec0, 432}, + {0x5eef, 443.75}, + {0x5f04, 449}, + {0x5f41, 464.25}, + {0x5f58, 470}, + {0x5f61, 472.25}, + {0x5f77, 477.75}, + {0x5f7b, 478.75}, + {0x6029, 532.5}, + {0x6046, 547}, + {0x6055, 554.5}, + {0x60a8, 596}, + {0x60d7, 619.5}, + {0x6139, 668.5}, + {0x6167, 691.5}, + {0x61b5, 730.5}, + {0x61c0, 736}, + {0x61e6, 755}, + {0x625b, 813.5}, + {0x62c4, 866}, + {0x62fd, 894.5}, + {0x62fe, 895}, + {0x6332, 921}, + {0x636a, 949}, + {0x6374, 954}, + {0x6376, 955}, + {0x639f, 975.5}, + {0x63d6, 1003}, + {0x6417, 1047}, + {0x642e, 1070}, + {0x6431, 1073}, + {0x644f, 1103}, + {0x6459, 1113}, + {0x645b, 1115}, + {0x6480, 1152}, + {0x648d, 1165}, + {0x649f, 1183}, + {0x64bb, 1211}, + {0x6516, 1302}, + {0x6571, 1393}, + {0x6585, 1413}, + {0x65aa, 1450}, + {0x660c, 1548}, + {0x6694, 1684}, + {0x66d0, 1744}, + {0x6721, 1825}, + {0x672d, 1837}, + {0x6734, 1844}, + {0x6766, 1894}, + {0x6773, 1907}, + {0x677d, 1917}, + {0x679a, 1946}, + {0x690f, 2590}, + {0x6934, 2664}, + {0x6955, 2730}, + {0x697d, 2810}, + {0x698e, 2844}, + {0x6a3a, 3188}, + {0x6a63, 3270}, + {0x6a67, 3278}, + {0x6a7c, 3320}, + {0x6a87, 3342}, + {0x6b07, 3598}, + {0x6b11, 3618}, + {0x6b36, 3692}, + {0x6b3c, 3704}, + {0x6b75, 3818}, + {0x6b88, 3856}, + {0x6be6, 4044}, + {0x6bee, 4060}, + {0x6c62, 4488}, + {0x6c8b, 4652}, + {0x6d30, 5312}, + {0x6d48, 5408}, + {0x6ddd, 6004}, + {0x6de9, 6052}, + {0x6e39, 6372}, + {0x6e7e, 6648}, + {0x6ea5, 6804}, + {0x6ec5, 6932}, + {0x6ee1, 7044}, + {0x6ef1, 7108}, + {0x6fa2, 7816}, + {0x6fbc, 7920}, + {0x704c, 8800}, + {0x7083, 9240}, + {0x7108, 10304}, + {0x7115, 10408}, + {0x7128, 10560}, + {0x71af, 11640}, + {0x7222, 12560}, + {0x7228, 12608}, + {0x72a5, 13608}, + {0x72e0, 14080}, + {0x72e6, 14128}, + {0x731e, 14576}, + {0x7377, 15288}, + {0x741d, 16848}, + {0x7423, 16944}, + {0x7424, 16960}, + {0x7466, 18016}, + {0x74b0, 19200}, + {0x74ce, 19680}, + {0x74f0, 20224}, + {0x754b, 21680}, + {0x7575, 22352}, + {0x7594, 22848}, + {0x75b1, 23312}, + {0x7614, 24896}, + {0x7618, 24960}, + {0x7631, 25360}, + {0x7660, 26112}, + {0x76c8, 27776}, + {0x7773, 30512}, + {0x77af, 31472}, + {0x77b9, 31632}, + {0x77de, 32224}, + {0x7844, 34944}, + {0x78d2, 39488}, + {0x7924, 42112}, + {0x793b, 42848}, + {0x79db, 47968}, + {0x7a0f, 49632}, + {0x7a1a, 49984}, + {0x7a6c, 52608}, + {0x7a99, 54048}, + {0x7ada, 56128}, + {0x7b0f, 57824}, + {0x7b15, 58016}, + {0x7b41, 59424}, + {0x7b51, 59936}, + {0x7b9c, 62336}, + {0x7ba3, 62560}, + {0x7c00, std::numeric_limits::infinity()}, + {0x7c05, std::numeric_limits::quiet_NaN()}, + {0x7c0e, std::numeric_limits::quiet_NaN()}, + {0x7c3e, std::numeric_limits::quiet_NaN()}, + {0x7c4e, std::numeric_limits::quiet_NaN()}, + {0x7c55, std::numeric_limits::quiet_NaN()}, + {0x7c58, std::numeric_limits::quiet_NaN()}, + {0x7c66, std::numeric_limits::quiet_NaN()}, + {0x7cc9, 
std::numeric_limits::quiet_NaN()}, + {0x7cd8, std::numeric_limits::quiet_NaN()}, + {0x7d2d, std::numeric_limits::quiet_NaN()}, + {0x7d60, std::numeric_limits::quiet_NaN()}, + {0x7d79, std::numeric_limits::quiet_NaN()}, + {0x7dc7, std::numeric_limits::quiet_NaN()}, + {0x7dcf, std::numeric_limits::quiet_NaN()}, + {0x7dd8, std::numeric_limits::quiet_NaN()}, + {0x7dfb, std::numeric_limits::quiet_NaN()}, + {0x7e0f, std::numeric_limits::quiet_NaN()}, + {0x7e56, std::numeric_limits::quiet_NaN()}, + {0x7e89, std::numeric_limits::quiet_NaN()}, + {0x7e9c, std::numeric_limits::quiet_NaN()}, + {0x7eb2, std::numeric_limits::quiet_NaN()}, + {0x7ec3, std::numeric_limits::quiet_NaN()}, + {0x7ef9, std::numeric_limits::quiet_NaN()}, + {0x7f36, std::numeric_limits::quiet_NaN()}, + {0x8040, -0.0000038146972656}, + {0x8101, -0.0000153183937073}, + {0x813d, -0.0000188946723938}, + {0x81a8, -0.0000252723693848}, + {0x81bc, -0.0000264644622803}, + {0x81c2, -0.0000268220901489}, + {0x8259, -0.00003582239151}, + {0x8330, -0.0000486373901367}, + {0x8366, -0.0000518560409546}, + {0x8392, -0.0000544786453247}, + {0x83e4, -0.0000593662261963}, + {0x83ee, -0.000059962272644}, + {0x8402, -0.0000611543655396}, + {0x845e, -0.0000666379928589}, + {0x84ac, -0.0000712871551514}, + {0x84b1, -0.0000715851783752}, + {0x84fb, -0.0000759959220886}, + {0x8546, -0.0000804662704468}, + {0x856f, -0.0000829100608826}, + {0x85b5, -0.0000870823860168}, + {0x8638, -0.0000948905944824}, + {0x8656, -0.0000966787338257}, + {0x86b9, -0.0001025795936584}, + {0x86ba, -0.0001026391983032}, + {0x86fe, -0.0001066923141479}, + {0x8731, -0.0001097321510315}, + {0x8740, -0.0001106262207031}, + {0x8793, -0.0001155734062195}, + {0x87bd, -0.0001180768013}, + {0x87f1, -0.0001211762428284}, + {0x87f4, -0.0001213550567627}, + {0x8809, -0.000123143196106}, + {0x882a, -0.0001270771026611}, + {0x8848, -0.0001306533813477}, + {0x8852, -0.0001318454742432}, + {0x8874, -0.0001358985900879}, + {0x8892, -0.0001394748687744}, + {0x88a7, -0.000141978263855}, + {0x88c8, -0.0001459121704102}, + {0x8927, -0.0001572370529175}, + {0x892a, -0.0001575946807861}, + {0x8989, -0.0001689195632935}, + {0x89b9, -0.0001746416091919}, + {0x8b18, -0.0002164840698242}, + {0x8b4b, -0.0002225637435913}, + {0x8b62, -0.000225305557251}, + {0x8b7f, -0.0002287626266479}, + {0x8bca, -0.0002377033233643}, + {0x8bcf, -0.000238299369812}, + {0x8bff, -0.0002440214157104}, + {0x8c0b, -0.0002467632293701}, + {0x8c55, -0.0002644062042236}, + {0x8c63, -0.0002677440643311}, + {0x8d53, -0.0003249645233154}, + {0x8dba, -0.0003495216369629}, + {0x8e03, -0.0003669261932373}, + {0x8e82, -0.0003972053527832}, + {0x8e9c, -0.0004034042358398}, + {0x8faa, -0.0004677772521973}, + {0x902f, -0.0005106925964355}, + {0x9051, -0.0005269050598145}, + {0x9066, -0.0005369186401367}, + {0x907e, -0.0005483627319336}, + {0x9080, -0.00054931640625}, + {0x908e, -0.0005559921264648}, + {0x9102, -0.0006113052368164}, + {0x91eb, -0.0007224082946777}, + {0x9215, -0.0007424354553223}, + {0x9252, -0.0007715225219727}, + {0x9294, -0.0008029937744141}, + {0x9297, -0.0008044242858887}, + {0x933d, -0.0008835792541504}, + {0x936f, -0.0009074211120605}, + {0x93aa, -0.0009355545043945}, + {0x93f2, -0.0009698867797852}, + {0x941d, -0.0010042190551758}, + {0x945a, -0.0010623931884766}, + {0x94ad, -0.0011415481567383}, + {0x94d2, -0.0011768341064453}, + {0x951c, -0.0012474060058594}, + {0x9520, -0.001251220703125}, + {0x952f, -0.0012655258178711}, + {0x953f, -0.0012807846069336}, + {0x9549, -0.0012903213500977}, + {0x95c6, 
-0.0014095306396484}, + {0x9602, -0.0014667510986328}, + {0x969b, -0.001612663269043}, + {0x96fa, -0.0017032623291016}, + {0x977d, -0.0018281936645508}, + {0x97c3, -0.0018949508666992}, + {0x97c6, -0.0018978118896484}, + {0x97db, -0.001917839050293}, + {0x97f9, -0.0019464492797852}, + {0x983f, -0.0020732879638672}, + {0x984e, -0.0021018981933594}, + {0x985a, -0.0021247863769531}, + {0x988c, -0.0022201538085938}, + {0x990d, -0.0024662017822266}, + {0x9958, -0.0026092529296875}, + {0x9971, -0.0026569366455078}, + {0x9a4e, -0.0030784606933594}, + {0x9a8f, -0.0032024383544922}, + {0x9abe, -0.0032920837402344}, + {0x9ace, -0.0033226013183594}, + {0x9b1e, -0.0034751892089844}, + {0x9b3e, -0.0035362243652344}, + {0x9b77, -0.0036449432373047}, + {0x9b89, -0.0036792755126953}, + {0x9b90, -0.003692626953125}, + {0x9bec, -0.0038681030273438}, + {0x9c03, -0.0039176940917969}, + {0x9c75, -0.0043525695800781}, + {0x9d6c, -0.0052947998046875}, + {0x9d74, -0.0053253173828125}, + {0x9da7, -0.0055198669433594}, + {0x9e73, -0.0062980651855469}, + {0x9e94, -0.0064239501953125}, + {0x9f17, -0.0069236755371094}, + {0x9f3a, -0.0070571899414062}, + {0x9f6c, -0.0072479248046875}, + {0x9f89, -0.0073585510253906}, + {0x9fbd, -0.0075569152832031}, + {0xa003, -0.0078353881835938}, + {0xa014, -0.007965087890625}, + {0xa019, -0.0080032348632812}, + {0xa01d, -0.0080337524414062}, + {0xa090, -0.0089111328125}, + {0xa1cf, -0.0113449096679688}, + {0xa1dd, -0.0114517211914062}, + {0xa249, -0.0122756958007812}, + {0xa26d, -0.0125503540039062}, + {0xa288, -0.01275634765625}, + {0xa2fb, -0.0136337280273438}, + {0xa390, -0.0147705078125}, + {0xa3b3, -0.0150375366210938}, + {0xa3ed, -0.0154800415039062}, + {0xa434, -0.01641845703125}, + {0xa476, -0.017425537109375}, + {0xa571, -0.0212554931640625}, + {0xa57d, -0.0214385986328125}, + {0xa597, -0.0218353271484375}, + {0xa5d1, -0.0227203369140625}, + {0xa5f9, -0.0233306884765625}, + {0xa680, -0.025390625}, + {0xa6e3, -0.0269012451171875}, + {0xa6f0, -0.027099609375}, + {0xa72d, -0.0280303955078125}, + {0xa77e, -0.029266357421875}, + {0xa7d0, -0.030517578125}, + {0xa7ee, -0.030975341796875}, + {0xa7f3, -0.0310516357421875}, + {0xa80c, -0.0316162109375}, + {0xa827, -0.032440185546875}, + {0xa89f, -0.036102294921875}, + {0xa8a0, -0.0361328125}, + {0xa8a5, -0.036285400390625}, + {0xa948, -0.041259765625}, + {0xaa0c, -0.0472412109375}, + {0xaa16, -0.04754638671875}, + {0xaa9a, -0.05157470703125}, + {0xaaeb, -0.054046630859375}, + {0xab5c, -0.0574951171875}, + {0xac7e, -0.0701904296875}, + {0xad33, -0.08123779296875}, + {0xad37, -0.08148193359375}, + {0xad90, -0.0869140625}, + {0xada0, -0.087890625}, + {0xade5, -0.09210205078125}, + {0xadf8, -0.09326171875}, + {0xae02, -0.0938720703125}, + {0xae04, -0.093994140625}, + {0xae4f, -0.09857177734375}, + {0xae63, -0.09979248046875}, + {0xaebe, -0.1053466796875}, + {0xaee1, -0.10748291015625}, + {0xaef9, -0.10894775390625}, + {0xaf0b, -0.11004638671875}, + {0xaf78, -0.11669921875}, + {0xaf7d, -0.11700439453125}, + {0xaf7f, -0.11712646484375}, + {0xaf8c, -0.117919921875}, + {0xafcb, -0.12176513671875}, + {0xb06b, -0.1380615234375}, + {0xb07b, -0.1400146484375}, + {0xb088, -0.1416015625}, + {0xb0b2, -0.146728515625}, + {0xb0ed, -0.1539306640625}, + {0xb0f9, -0.1553955078125}, + {0xb16c, -0.16943359375}, + {0xb189, -0.1729736328125}, + {0xb1c5, -0.1802978515625}, + {0xb1f7, -0.1864013671875}, + {0xb22d, -0.1929931640625}, + {0xb23c, -0.19482421875}, + {0xb258, -0.1982421875}, + {0xb2c7, -0.2117919921875}, + {0xb2de, -0.214599609375}, + {0xb2e1, 
-0.2149658203125}, + {0xb317, -0.2215576171875}, + {0xb31d, -0.2222900390625}, + {0xb3ef, -0.2479248046875}, + {0xb3f8, -0.2490234375}, + {0xb45a, -0.27197265625}, + {0xb548, -0.330078125}, + {0xb5d8, -0.365234375}, + {0xb64e, -0.39404296875}, + {0xb69f, -0.413818359375}, + {0xb6e6, -0.43115234375}, + {0xb6ed, -0.432861328125}, + {0xb6f7, -0.435302734375}, + {0xb79a, -0.47509765625}, + {0xb7b6, -0.48193359375}, + {0xb7ee, -0.49560546875}, + {0xb856, -0.5419921875}, + {0xb8c0, -0.59375}, + {0xb96f, -0.67919921875}, + {0xb9a5, -0.70556640625}, + {0xba1e, -0.7646484375}, + {0xba2d, -0.77197265625}, + {0xba48, -0.78515625}, + {0xba65, -0.79931640625}, + {0xbaaf, -0.83544921875}, + {0xbab0, -0.8359375}, + {0xbb12, -0.8837890625}, + {0xbb35, -0.90087890625}, + {0xbb47, -0.90966796875}, + {0xbb97, -0.94873046875}, + {0xbba3, -0.95458984375}, + {0xbbcb, -0.97412109375}, + {0xbbe8, -0.98828125}, + {0xbbee, -0.9912109375}, + {0xbd03, -1.2529296875}, + {0xbd4b, -1.3232421875}, + {0xbd4c, -1.32421875}, + {0xbd8a, -1.384765625}, + {0xbdb6, -1.427734375}, + {0xbde1, -1.4697265625}, + {0xbe04, -1.50390625}, + {0xbe50, -1.578125}, + {0xbe54, -1.58203125}, + {0xbe6a, -1.603515625}, + {0xbf31, -1.7978515625}, + {0xbf87, -1.8818359375}, + {0xbfa2, -1.908203125}, + {0xc016, -2.04296875}, + {0xc074, -2.2265625}, + {0xc0ca, -2.39453125}, + {0xc100, -2.5}, + {0xc1b7, -2.857421875}, + {0xc1b9, -2.861328125}, + {0xc1d3, -2.912109375}, + {0xc23f, -3.123046875}, + {0xc2d5, -3.416015625}, + {0xc32f, -3.591796875}, + {0xc3e3, -3.943359375}, + {0xc412, -4.0703125}, + {0xc49a, -4.6015625}, + {0xc4ca, -4.7890625}, + {0xc4cf, -4.80859375}, + {0xc523, -5.13671875}, + {0xc55d, -5.36328125}, + {0xc5aa, -5.6640625}, + {0xc604, -6.015625}, + {0xc61b, -6.10546875}, + {0xc642, -6.2578125}, + {0xc68b, -6.54296875}, + {0xc69e, -6.6171875}, + {0xc6b0, -6.6875}, + {0xc6ca, -6.7890625}, + {0xc71e, -7.1171875}, + {0xc721, -7.12890625}, + {0xc73b, -7.23046875}, + {0xc7d4, -7.828125}, + {0xc831, -8.3828125}, + {0xc89a, -9.203125}, + {0xc8be, -9.484375}, + {0xc8dc, -9.71875}, + {0xc8e4, -9.78125}, + {0xc8fa, -9.953125}, + {0xc8fe, -9.984375}, + {0xc969, -10.8203125}, + {0xca0f, -12.1171875}, + {0xca1a, -12.203125}, + {0xca6f, -12.8671875}, + {0xca7b, -12.9609375}, + {0xca8f, -13.1171875}, + {0xcaca, -13.578125}, + {0xcafd, -13.9765625}, + {0xcb05, -14.0390625}, + {0xcb6b, -14.8359375}, + {0xcbaf, -15.3671875}, + {0xcbb4, -15.40625}, + {0xcbdf, -15.7421875}, + {0xcc2d, -16.703125}, + {0xcc74, -17.8125}, + {0xccac, -18.6875}, + {0xcd11, -20.265625}, + {0xce04, -24.0625}, + {0xce0f, -24.234375}, + {0xceaf, -26.734375}, + {0xceb8, -26.875}, + {0xcf36, -28.84375}, + {0xcfad, -30.703125}, + {0xd019, -32.78125}, + {0xd08d, -36.40625}, + {0xd115, -40.65625}, + {0xd119, -40.78125}, + {0xd128, -41.25}, + {0xd1a4, -45.125}, + {0xd1b7, -45.71875}, + {0xd1b8, -45.75}, + {0xd203, -48.09375}, + {0xd20a, -48.3125}, + {0xd28b, -52.34375}, + {0xd2ac, -53.375}, + {0xd2ae, -53.4375}, + {0xd2c5, -54.15625}, + {0xd2f2, -55.5625}, + {0xd326, -57.1875}, + {0xd337, -57.71875}, + {0xd343, -58.09375}, + {0xd34e, -58.4375}, + {0xd40c, -64.75}, + {0xd43b, -67.6875}, + {0xd45a, -69.625}, + {0xd464, -70.25}, + {0xd4c3, -76.1875}, + {0xd505, -80.3125}, + {0xd52d, -82.8125}, + {0xd5cf, -92.9375}, + {0xd5f0, -95}, + {0xd607, -96.4375}, + {0xd635, -99.3125}, + {0xd63d, -99.8125}, + {0xd644, -100.25}, + {0xd658, -101.5}, + {0xd789, -120.5625}, + {0xd863, -140.375}, + {0xd866, -140.75}, + {0xd884, -144.5}, + {0xd88d, -145.625}, + {0xd89b, -147.375}, + {0xd8da, -155.25}, + 
{0xd93b, -167.375}, + {0xd982, -176.25}, + {0xd995, -178.625}, + {0xd99d, -179.625}, + {0xd9cf, -185.875}, + {0xdaaf, -213.875}, + {0xdabd, -215.625}, + {0xdb54, -234.5}, + {0xdc10, -260}, + {0xdca1, -296.25}, + {0xdd0a, -322.5}, + {0xdd56, -341.5}, + {0xddcf, -371.75}, + {0xde04, -385}, + {0xde0d, -387.25}, + {0xde3d, -399.25}, + {0xde4f, -403.75}, + {0xde66, -409.5}, + {0xdeae, -427.5}, + {0xdf52, -468.5}, + {0xdf63, -472.75}, + {0xdf6a, -474.5}, + {0xdf77, -477.75}, + {0xdf7b, -478.75}, + {0xdfc5, -497.25}, + {0xdfcf, -499.75}, + {0xdfd2, -500.5}, + {0xdfd8, -502}, + {0xdfe1, -504.25}, + {0xe022, -529}, + {0xe046, -547}, + {0xe092, -585}, + {0xe0b0, -600}, + {0xe0be, -607}, + {0xe0f4, -634}, + {0xe11b, -653.5}, + {0xe19c, -718}, + {0xe213, -777.5}, + {0xe232, -793}, + {0xe25b, -813.5}, + {0xe262, -817}, + {0xe279, -828.5}, + {0xe2cc, -870}, + {0xe2da, -877}, + {0xe326, -915}, + {0xe330, -920}, + {0xe3c3, -993.5}, + {0xe3cc, -998}, + {0xe566, -1382}, + {0xe57e, -1406}, + {0xe5c8, -1480}, + {0xe609, -1545}, + {0xe628, -1576}, + {0xe663, -1635}, + {0xe6ac, -1708}, + {0xe710, -1808}, + {0xe77f, -1919}, + {0xe7e7, -2023}, + {0xe868, -2256}, + {0xe885, -2314}, + {0xe8ea, -2516}, + {0xe919, -2610}, + {0xe92c, -2648}, + {0xea60, -3264}, + {0xeac1, -3458}, + {0xeacb, -3478}, + {0xeb22, -3652}, + {0xeb2c, -3672}, + {0xeb59, -3762}, + {0xeba5, -3914}, + {0xec53, -4428}, + {0xec97, -4700}, + {0xed16, -5208}, + {0xed4a, -5416}, + {0xed69, -5540}, + {0xee14, -6224}, + {0xee59, -6500}, + {0xee8a, -6696}, + {0xee93, -6732}, + {0xeed7, -7004}, + {0xef0b, -7212}, + {0xef59, -7524}, + {0xef61, -7556}, + {0xef67, -7580}, + {0xefb6, -7896}, + {0xf03a, -8656}, + {0xf04e, -8816}, + {0xf05f, -8952}, + {0xf09f, -9464}, + {0xf0c0, -9728}, + {0xf173, -11160}, + {0xf1d7, -11960}, + {0xf225, -12584}, + {0xf2ca, -13904}, + {0xf2d8, -14016}, + {0xf2e5, -14120}, + {0xf317, -14520}, + {0xf35d, -15080}, + {0xf3bd, -15848}, + {0xf3d3, -16024}, + {0xf3e6, -16176}, + {0xf3fb, -16344}, + {0xf477, -18288}, + {0xf4e0, -19968}, + {0xf4e5, -20048}, + {0xf50b, -20656}, + {0xf5a2, -23072}, + {0xf5c1, -23568}, + {0xf634, -25408}, + {0xf651, -25872}, + {0xf68a, -26784}, + {0xf69c, -27072}, + {0xf6ce, -27872}, + {0xf816, -33472}, + {0xf849, -35104}, + {0xf869, -36128}, + {0xf878, -36608}, + {0xf8cf, -39392}, + {0xf90a, -41280}, + {0xf916, -41664}, + {0xf91e, -41920}, + {0xf9c1, -47136}, + {0xfa0a, -49472}, + {0xfa11, -49696}, + {0xfa1d, -50080}, + {0xfa51, -51744}, + {0xfa86, -53440}, + {0xfaac, -54656}, + {0xfb95, -62112}, + {0xfbd1, -64032}, + {0xfbe0, -64512}, + {0xfbf5, -65184}, + {0xfc00, -std::numeric_limits<float>::infinity()}, + {0xfca5, std::numeric_limits<float>::quiet_NaN()}, + {0xfcb9, std::numeric_limits<float>::quiet_NaN()}, + {0xfcc6, std::numeric_limits<float>::quiet_NaN()}, + {0xfd72, std::numeric_limits<float>::quiet_NaN()}, + {0xfd77, std::numeric_limits<float>::quiet_NaN()}, + {0xfda3, std::numeric_limits<float>::quiet_NaN()}, + {0xfe3e, std::numeric_limits<float>::quiet_NaN()}, + {0xfe89, std::numeric_limits<float>::quiet_NaN()}, + {0xfe91, std::numeric_limits<float>::quiet_NaN()}, + {0xfe93, std::numeric_limits<float>::quiet_NaN()}, + {0xfed1, std::numeric_limits<float>::quiet_NaN()}, + {0xff7a, std::numeric_limits<float>::quiet_NaN()}, + {0xffa3, std::numeric_limits<float>::quiet_NaN()}, + }; + return result; +} +
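+// binary16 layout: 1 sign bit, 5 exponent bits (bias 15), 10 fraction bits.
+// Worked decode of one entry above: 0x5756 has exponent 0x15 (21) and fraction
+// 0x356 (854), so it decodes to 2^(21-15) * (1 + 854/1024) = 117.375. Subnormals
+// use exponent 0: 0x8040 decodes to -(2^-14 * 64/1024) = -2^-18 = -0.0000038146972656.
+// NaN payloads are not preserved by the float-to-half conversion, so
+// check_half_values below uses isnan for the NaN entries rather than bit_equal.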
+TEST_CASE(check_half_values) +{ + for(auto [x, f] : half_lut()) + { + auto h = migraphx::bit_cast<migraphx::half>(x); + if(std::isnan(f)) + { + CHECK(std::isnan(h)); + } + else if(std::isinf(f)) + { + CHECK(std::isinf(h)); + CHECK((h < 0) == (f < 0)); + CHECK(bit_equal(x, migraphx::half(f))); + } + else + { + CHECK(bit_equal(x, migraphx::half(f))); + CHECK(migraphx::float_equal(float(h), f)); + } + } +} + +TEST_CASE(check_flows) +{ + // check positive underflow + CHECK(bit_equal(std::numeric_limits<migraphx::half>::min() * + std::numeric_limits<migraphx::half>::min(), + migraphx::half(0))); + + // check overflow + CHECK(bit_equal(std::numeric_limits<migraphx::half>::infinity() + + std::numeric_limits<migraphx::half>::infinity(), + std::numeric_limits<migraphx::half>::infinity())); + CHECK(bit_equal(std::numeric_limits<migraphx::half>::max() + + std::numeric_limits<migraphx::half>::max(), + std::numeric_limits<migraphx::half>::infinity())); + CHECK(bit_equal(std::numeric_limits<migraphx::half>::max() / + std::numeric_limits<migraphx::half>::epsilon(), + std::numeric_limits<migraphx::half>::infinity()));
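+    // min() is 2^-14, far below half the ULP at max() (the spacing there is 32), so the sum rounds back to max()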
+ CHECK(bit_equal(std::numeric_limits<migraphx::half>::max() + + std::numeric_limits<migraphx::half>::min(), + std::numeric_limits<migraphx::half>::max())); + + // check negative underflow + CHECK(bit_equal(std::numeric_limits<migraphx::half>::lowest() + + std::numeric_limits<migraphx::half>::lowest(), + -std::numeric_limits<migraphx::half>::infinity())); + CHECK(bit_equal(-std::numeric_limits<migraphx::half>::infinity() - + std::numeric_limits<migraphx::half>::infinity(), + -std::numeric_limits<migraphx::half>::infinity())); + CHECK(bit_equal(std::numeric_limits<migraphx::half>::lowest() - + std::numeric_limits<migraphx::half>::min(), + std::numeric_limits<migraphx::half>::lowest())); +} + +TEST_CASE(test_nan) +{ + float f_qnan = std::numeric_limits<float>::quiet_NaN(); + migraphx::half half_qnan(f_qnan); + EXPECT(half_qnan.is_nan()); + EXPECT(std::isnan(half_qnan)); + + float f_snan = std::numeric_limits<float>::signaling_NaN(); + migraphx::half half_snan(f_snan); + EXPECT(half_snan.is_nan()); + EXPECT(std::isnan(half_snan)); +} + +TEST_CASE(test_bool) +{ + float zero = 0.0; + float two = 2.0; + float other = -0.375; + migraphx::half half_zero(zero); + migraphx::half half_two(two); + migraphx::half half_other(other); + EXPECT(not static_cast<bool>(half_zero)); + EXPECT(static_cast<bool>(half_two)); + EXPECT(static_cast<bool>(half_other)); +} + +TEST_CASE(test_pos_infinity) +{ + float finf = std::numeric_limits<float>::infinity(); + migraphx::half half_inf_1(finf); + CHECK(bit_equal(half_inf_1, std::numeric_limits<migraphx::half>::infinity())); +} + +TEST_CASE(test_neg_infinity) +{ + float finf = -1.0 * std::numeric_limits<float>::infinity(); + migraphx::half half_neginf_1(finf); + CHECK(bit_equal(half_neginf_1, -std::numeric_limits<migraphx::half>::infinity())); +} + +TEST_CASE(test_numeric_max_1) +{ + float fmax = std::numeric_limits<float>::max(); // fp32 max is fp16 inf + migraphx::half half_inf(fmax); + CHECK(bit_equal(half_inf, std::numeric_limits<migraphx::half>::infinity())); +} + +TEST_CASE(test_numeric_lowest_1) +{ + float flowest = std::numeric_limits<float>::lowest(); + migraphx::half half_neginf(flowest); + CHECK(bit_equal(half_neginf, -std::numeric_limits<migraphx::half>::infinity())); +} + +TEST_CASE(test_max_eq_lowest) +{ + EXPECT(migraphx::float_equal(std::numeric_limits<migraphx::half>::lowest(), + -1 * std::numeric_limits<migraphx::half>::max())); +} + +TEST_CASE(test_isfinite) +{ + EXPECT(std::isfinite(migraphx::half(0.0))); + EXPECT(std::isfinite(migraphx::half(-0.0))); + EXPECT(not std::isfinite(migraphx::half(std::numeric_limits<migraphx::half>::quiet_NaN()))); +} + +TEST_CASE(test_binary_ops) +{ + auto a = migraphx::half(-1.0); + auto b = migraphx::half(1.0); + auto c = migraphx::half(0.0); + auto d = migraphx::half(-0.0); + EXPECT(migraphx::float_equal((c + d), c)); + EXPECT(migraphx::float_equal((c + d), d)); + EXPECT(migraphx::float_equal((a + b), c)); + EXPECT(migraphx::float_equal((a + b), d)); + + auto e = migraphx::half(10.0); + auto f = migraphx::half(-10.0); + EXPECT(e > f); + EXPECT(f < e); + EXPECT(f <= e); + EXPECT(e >= f); + EXPECT(e <= e); + EXPECT(f >= f); + EXPECT(not migraphx::float_equal(f, e)); +} + +TEST_CASE(test_stream_op) +{ + auto a = migraphx::half(-1.0); + std::stringstream ss; + ss << a; + EXPECT(std::string("-1") == ss.str()); + ss = std::stringstream(); + auto b = std::numeric_limits<migraphx::half>::quiet_NaN(); + ss << b; + EXPECT(std::string("nan") == ss.str()); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/include/layernorm.hpp b/test/include/layernorm.hpp index ed8cc008bd7..d800500345e 100644 --- a/test/include/layernorm.hpp +++ b/test/include/layernorm.hpp @@ -62,3 +62,43 @@ inline migraphx::instruction_ref add_layernorm(migraphx::module& m, m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), bias); return m.add_instruction(migraphx::make_op("add"), mul, bias_mbcast); } + +inline migraphx::instruction_ref add_pointwise_layernorm(migraphx::module& m, + migraphx::instruction_ref x, + const std::vector<size_t>& dims, + float eps = 1e-12f) +{ + auto mgx_type = x->get_shape().type(); + auto scale = m.add_parameter("scale", migraphx::shape{mgx_type, {1, 1, dims.back()}}); + auto bias = m.add_parameter("bias", migraphx::shape{mgx_type, {1, 1, dims.back()}}); + + auto epsilon = m.add_literal(migraphx::literal{migraphx::shape{mgx_type}, {eps}}); + auto one = m.add_literal(migraphx::literal{migraphx::shape{mgx_type}, {1}}); + + auto mean = m.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), x); + auto mean_mbcast = + m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), mean); + auto x_minus_mean = m.add_instruction(migraphx::make_op("sub"), x, mean_mbcast); + auto sqdiff = m.add_instruction(migraphx::make_op("sqdiff"), x, mean_mbcast); + auto var = m.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), sqdiff); + + auto epsilon_mbcast = m.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", var->get_shape().lens()}}), epsilon); + auto var_stable = m.add_instruction(migraphx::make_op("add"), var, epsilon_mbcast); + auto inv_stddev_x = m.add_instruction(migraphx::make_op("rsqrt"), var_stable); + auto inv_stddev_x_mbcast = + m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), inv_stddev_x); + auto norm = m.add_instruction(migraphx::make_op("mul"), x_minus_mean, inv_stddev_x_mbcast); + + auto one_mbcast = m.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", scale->get_shape().lens()}}), one); + auto add_scale = m.add_instruction(migraphx::make_op("add"), scale, one_mbcast); + auto scale_mbcast = + m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), add_scale); + auto scale_norm = m.add_instruction(migraphx::make_op("mul"), norm, scale_mbcast); + + auto bias_mbcast = + m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), bias); + + return m.add_instruction(migraphx::make_op("add"), scale_norm, bias_mbcast); +} diff --git a/test/include/test.hpp b/test/include/test.hpp index b8740e82ddb..c0112bb8e57 100644 --- a/test/include/test.hpp +++ b/test/include/test.hpp @@ -229,7 +229,15 @@ struct lhs_expression std::string op = Operator::as_string(); if(not op.empty()) s << Operator::as_string() << " "; - s << self.lhs; + if constexpr(std::is_pointer_v<T>) + { + s << static_cast<const void*>(self.lhs); + } + else + { + // NOLINTNEXTLINE + s << self.lhs; + } return s; } diff --git a/test/instruction.cpp b/test/instruction.cpp index 134658e336b..0ee22e13553 100644 --- a/test/instruction.cpp +++ b/test/instruction.cpp @@ -67,4 +67,24 @@ TEST_CASE(check_replace_shape) EXPECT(add->get_shape() == r); } +TEST_CASE(check_replace_dag) +{ + migraphx::module m; + migraphx::shape s{migraphx::shape::float_type, {3, 2}}; + auto input = m.add_parameter("x", s); + auto reduce = m.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0}}}), input); + auto abs = m.add_instruction(migraphx::make_op("abs"), reduce); + auto sin = m.add_instruction(migraphx::make_op("sin"), reduce); + auto add = m.add_instruction(migraphx::make_op("add"), abs, sin); + auto add2 = m.add_instruction(migraphx::make_op("add"), add, reduce); +
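+    // Swapping the reduce axes in place: every shape downstream of the replaced instruction must be recomputed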
+ reduce->replace(migraphx::make_op("reduce_sum", {{"axes", {1}}})); + + migraphx::shape r{migraphx::shape::float_type, {3, 1}}; + EXPECT(reduce->get_shape() == r); + EXPECT(abs->get_shape() == r); + EXPECT(sin->get_shape() == r); + EXPECT(add->get_shape() == r); + EXPECT(add2->get_shape() == r); +} int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/layout_nhwc.cpp b/test/layout_convolution.cpp similarity index 58% rename from test/layout_nhwc.cpp rename to test/layout_convolution.cpp index 7dae574d113..64e8830d67b 100644 --- a/test/layout_nhwc.cpp +++ b/test/layout_convolution.cpp @@ -21,7 +21,7 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ -#include <migraphx/layout_nhwc.hpp> +#include <migraphx/layout_convolution.hpp> #include #include #include @@ -32,9 +32,9 @@ #include -void run_pass(migraphx::module& m) +void run_pass(migraphx::module& m, migraphx::layout_convolution lc = {}) { - migraphx::run_passes(m, {migraphx::layout_nhwc{}, migraphx::dead_code_elimination{}}); + migraphx::run_passes(m, {lc, migraphx::dead_code_elimination{}}); } migraphx::operation layout(std::vector<int64_t> permutation = {0, 1, 2, 3}) @@ -47,7 +47,7 @@ migraphx::instruction_ref add_layout_nhwc(migraphx::module& m, migraphx::instruc return m.add_instruction(layout({0, 2, 3, 1}), ins); } -TEST_CASE(conv_relu) +TEST_CASE(auto_conv_nchw) { migraphx::module m1; { @@ -59,9 +59,128 @@ TEST_CASE(conv_relu) {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}), x, w); - m1.add_instruction(migraphx::make_op("relu"), conv); + auto relu = m1.add_instruction(migraphx::make_op("relu"), conv); + m1.add_return({relu}); } + migraphx::module m2 = m1; run_pass(m1); + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(auto_conv_nhwc) +{ + auto transpose = migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}); + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::float_type, {1, 16, 16, 8}}); + auto xtranspose = m1.add_instruction(transpose, x); + auto w = m1.add_literal( + migraphx::generate_literal({migraphx::shape::float_type, {16, 3, 3, 8}})); + auto wtranspose = m1.add_instruction(transpose, w); + auto conv = m1.add_instruction( + migraphx::make_op("convolution", + {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}), + xtranspose, + wtranspose); + auto relu = m1.add_instruction(migraphx::make_op("relu"), conv); + m1.add_return({relu}); + } + migraphx::module m2 = m1; + run_pass(m1); + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(auto_conv_mixed) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::float_type, {1, 8, 16, 16}}); + auto w = m1.add_literal( + migraphx::generate_literal({migraphx::shape::float_type, {3, 3, 16, 8}})); + auto wtranspose = + m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 3, 0, 1}}}), w); + auto conv = m1.add_instruction( + migraphx::make_op("convolution", + {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}), + x, + wtranspose); + auto relu = m1.add_instruction(migraphx::make_op("relu"), conv); + m1.add_return({relu}); + } +
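+    // With default options the pass should keep the NCHW input as-is and only normalize the transposed weights with an explicit layout op (compare m2 below)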
run_pass(m1); + migraphx::module m2; + { + auto x = m2.add_parameter("x", {migraphx::shape::float_type, {1, 8, 16, 16}}); + auto w = m2.add_literal( + migraphx::generate_literal({migraphx::shape::float_type, {3, 3, 16, 8}})); + auto wtranspose = + m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 3, 0, 1}}}), w); + auto wlayout = m2.add_instruction( + migraphx::make_op("layout", {{"permutation", {0, 1, 2, 3}}}), wtranspose); + auto conv = m2.add_instruction( + migraphx::make_op("convolution", + {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}), + x, + wlayout); + auto relu = m2.add_instruction(migraphx::make_op("relu"), conv); + m2.add_return({relu}); + } + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(auto_quant_conv_mixed) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::int8_type, {1, 8, 16, 16}}); + auto w = + m1.add_literal(migraphx::generate_literal({migraphx::shape::int8_type, {3, 3, 16, 8}})); + auto wtranspose = + m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 3, 0, 1}}}), w); + auto conv = m1.add_instruction( + migraphx::make_op("quant_convolution", + {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}), + x, + wtranspose); + auto relu = m1.add_instruction(migraphx::make_op("relu"), conv); + m1.add_return({relu}); + } + run_pass(m1); + migraphx::module m2; + { + auto x = m2.add_parameter("x", {migraphx::shape::int8_type, {1, 8, 16, 16}}); + auto w = + m2.add_literal(migraphx::generate_literal({migraphx::shape::int8_type, {3, 3, 16, 8}})); + auto wtranspose = + m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 3, 0, 1}}}), w); + auto wlayout = m2.add_instruction( + migraphx::make_op("layout", {{"permutation", {0, 1, 2, 3}}}), wtranspose); + auto conv = m2.add_instruction( + migraphx::make_op("quant_convolution", + {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}), + x, + wlayout); + auto relu = m2.add_instruction(migraphx::make_op("relu"), conv); + m2.add_return({relu}); + } + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(nhwc_conv_relu) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::float_type, {1, 8, 16, 16}}); + auto w = m1.add_literal( + migraphx::generate_literal({migraphx::shape::float_type, {16, 8, 3, 3}})); + auto conv = m1.add_instruction( + migraphx::make_op("convolution", + {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}), + x, + w); + m1.add_instruction(migraphx::make_op("relu"), conv); + } + run_pass(m1, {.channels_last = true}); migraphx::module m2; { @@ -81,7 +200,7 @@ TEST_CASE(conv_relu) EXPECT(m1.sort() == m2.sort()); } -TEST_CASE(conv_add) +TEST_CASE(nhwc_conv_add) { migraphx::module m1; { @@ -99,7 +218,7 @@ TEST_CASE(conv_add) y); m1.add_instruction(migraphx::make_op("add"), conv, b); } - run_pass(m1); + run_pass(m1, {.channels_last = true}); migraphx::module m2; { @@ -114,7 +233,7 @@ TEST_CASE(conv_add) {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}), x, w); - auto b = m2.add_instruction( + auto b = m2.add_instruction( migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", conv->get_shape().lens()}}), y); auto add = m2.add_instruction(migraphx::make_op("add"), conv, b); @@ -123,7 +242,49 @@ TEST_CASE(conv_add) EXPECT(m1.sort() == m2.sort()); } -TEST_CASE(conv_conv) +TEST_CASE(nhwc_quant_conv_add) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::int8_type, {1, 8, 16, 16}}); + auto w = + 
m1.add_literal(migraphx::generate_literal({migraphx::shape::int8_type, {16, 8, 3, 3}})); + auto y = m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {16}})); + auto conv = m1.add_instruction( + migraphx::make_op("quant_convolution", + {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}), + x, + w); + auto b = m1.add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", conv->get_shape().lens()}}), + y); + m1.add_instruction(migraphx::make_op("add"), conv, b); + } + run_pass(m1, {.channels_last = true}); + + migraphx::module m2; + { + auto x = add_layout_nhwc( + m2, m2.add_parameter("x", {migraphx::shape::int8_type, {1, 8, 16, 16}})); + auto w = add_layout_nhwc(m2, + m2.add_literal(migraphx::generate_literal( + {migraphx::shape::int8_type, {16, 8, 3, 3}}))); + auto y = m2.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {16}})); + auto conv = m2.add_instruction( + migraphx::make_op("quant_convolution", + {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}), + x, + w); + auto b = m2.add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", conv->get_shape().lens()}}), + y); + auto add = m2.add_instruction(migraphx::make_op("add"), conv, b); + m2.add_instruction(layout(), add); + } + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(nhwc_conv_conv) { migraphx::module m1; { @@ -149,7 +310,7 @@ TEST_CASE(conv_conv) auto relu2 = m1.add_instruction(migraphx::make_op("relu"), add2); m1.add_return({relu2}); } - run_pass(m1); + run_pass(m1, {.channels_last = true}); migraphx::module m2; { @@ -182,7 +343,7 @@ TEST_CASE(conv_conv) EXPECT(m1.sort() == m2.sort()); } -TEST_CASE(conv_reduce) +TEST_CASE(nhwc_conv_reduce) { migraphx::module m1; { @@ -201,7 +362,7 @@ TEST_CASE(conv_reduce) auto squeeze = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {2, 3}}}), reduce); m1.add_return({squeeze}); } - run_pass(m1); + run_pass(m1, {.channels_last = true}); migraphx::module m2; { diff --git a/test/onnx/.onnxrt-commit b/test/onnx/.onnxrt-commit index d4bb74f680a..75454c00440 100644 --- a/test/onnx/.onnxrt-commit +++ b/test/onnx/.onnxrt-commit @@ -1 +1 @@ -7964d3aef6038ea82b0982ec5a520b5708c8a136 +62e7e24f172a062242acae11575f7ea11529dd09 diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index 05713398294..c46ce85d080 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -7465,6 +7465,348 @@ def matmulinteger_int8_uint8_dual_zero_zp_test(): return ([node], [m1, m2], [y], [zp1, zp2]) +@onnx_test() +def matmulintegertofloat_test(): + m1 = helper.make_tensor_value_info('1', TensorProto.INT8, [4, 3]) + m2 = helper.make_tensor_value_info('2', TensorProto.INT8, [3, 2]) + s1 = helper.make_tensor_value_info('3', TensorProto.FLOAT, [3]) + s2 = helper.make_tensor_value_info('4', TensorProto.FLOAT, [2]) + y = helper.make_tensor_value_info('y', TensorProto.INT32, [4, 2]) + + node = onnx.helper.make_node( + 'MatMulIntegerToFloat', + inputs=['1', '2', '3', '4'], + outputs=['y'], + ) + + return ([node], [m1, m2, s1, s2], [y], []) + + +@onnx_test() +def matmulintegertofloat_zp_test(): + m1 = helper.make_tensor_value_info('1', TensorProto.INT8, [4, 3]) + m2 = helper.make_tensor_value_info('2', TensorProto.INT8, [3, 2]) + s1 = helper.make_tensor_value_info('3', TensorProto.FLOAT, [3]) + s2 = helper.make_tensor_value_info('4', TensorProto.FLOAT, [2]) + zp1 = helper.make_tensor_value_info('5', TensorProto.INT8, [3]) + zp2 = helper.make_tensor_value_info('6', TensorProto.INT8, [2]) + y = 
helper.make_tensor_value_info('y', TensorProto.INT32, [4, 2]) + + node = onnx.helper.make_node( + 'MatMulIntegerToFloat', + inputs=['1', '2', '3', '4', '5', '6'], + outputs=['y'], + ) + + return ([node], [m1, m2, s1, s2, zp1, zp2], [y], []) + + +@onnx_test() +def matmulintegertofloat_scalar_zp_test(): + m1 = helper.make_tensor_value_info('1', TensorProto.INT8, [4, 3]) + m2 = helper.make_tensor_value_info('2', TensorProto.INT8, [3, 2]) + s1 = helper.make_tensor_value_info('3', TensorProto.FLOAT, [3]) + s2 = helper.make_tensor_value_info('4', TensorProto.FLOAT, [2]) + zp1 = helper.make_tensor_value_info('5', TensorProto.INT8, [3]) + zp2 = helper.make_tensor('6', TensorProto.INT8, [], [129]) + y = helper.make_tensor_value_info('y', TensorProto.INT32, [4, 2]) + + node = onnx.helper.make_node( + 'MatMulIntegerToFloat', + inputs=['1', '2', '3', '4', '5', '6'], + outputs=['y'], + ) + + return ([node], [m1, m2, s1, s2, zp1], [y], [zp2]) + + +@onnx_test() +def matmulintegertofloat_scalar_scale_test(): + m1 = helper.make_tensor_value_info('1', TensorProto.INT8, [4, 3]) + m2 = helper.make_tensor_value_info('2', TensorProto.INT8, [3, 2]) + s1 = helper.make_tensor_value_info('3', TensorProto.FLOAT, [3]) + s2 = helper.make_tensor('4', TensorProto.FLOAT, [], [10]) + y = helper.make_tensor_value_info('y', TensorProto.INT32, [4, 2]) + + node = onnx.helper.make_node( + 'MatMulIntegerToFloat', + inputs=['1', '2', '3', '4'], + outputs=['y'], + ) + + return ([node], [m1, m2, s1], [y], [s2]) + + +@onnx_test() +def matmulintegertofloat_zp_bias_test(): + m1 = helper.make_tensor_value_info('1', TensorProto.INT8, [4, 3]) + m2 = helper.make_tensor_value_info('2', TensorProto.UINT8, [3, 2]) + s1 = helper.make_tensor_value_info('3', TensorProto.FLOAT, [3]) + s2 = helper.make_tensor_value_info('4', TensorProto.FLOAT, [2]) + zp1 = helper.make_tensor_value_info('5', TensorProto.INT8, [3]) + zp2 = helper.make_tensor_value_info('6', TensorProto.UINT8, [2]) + b1 = helper.make_tensor_value_info('7', TensorProto.FLOAT, [2]) + y = helper.make_tensor_value_info('y', TensorProto.INT32, [4, 2]) + + node = onnx.helper.make_node( + 'MatMulIntegerToFloat', + inputs=['1', '2', '3', '4', '5', '6', '7'], + outputs=['y'], + ) + + return ([node], [m1, m2, s1, s2, zp1, zp2, b1], [y], []) + + +@onnx_test() +def matmulintegertofloat_bad_scale_test(): + m1 = helper.make_tensor_value_info('1', TensorProto.INT8, [4, 3]) + m2 = helper.make_tensor_value_info('2', TensorProto.INT8, [3, 2]) + s1 = helper.make_tensor_value_info('3', TensorProto.INT8, [4, 3]) + s2 = helper.make_tensor_value_info('4', TensorProto.FLOAT16, [3, 2]) + zp1 = helper.make_tensor('5', TensorProto.INT8, [], [0]) + zp2 = helper.make_tensor('6', TensorProto.UINT8, [], [128]) + y = helper.make_tensor_value_info('y', TensorProto.INT32, [4, 2]) + + node = onnx.helper.make_node( + 'MatMulIntegerToFloat', + inputs=['1', '2', '3', '4', '5', '6'], + outputs=['y'], + ) + + return ([node], [m1, m2, s1, s2], [y], [zp1, zp2]) + + +@onnx_test() +def matmulintegertofloat_bad_scale2_test(): + m1 = helper.make_tensor_value_info('1', TensorProto.INT8, [4, 3]) + m2 = helper.make_tensor_value_info('2', TensorProto.INT8, [3, 2]) + s1 = helper.make_tensor_value_info('3', TensorProto.FLOAT16, [4, 3]) + s2 = helper.make_tensor_value_info('4', TensorProto.INT8, [3, 2]) + zp1 = helper.make_tensor('5', TensorProto.INT8, [], [0]) + zp2 = helper.make_tensor('6', TensorProto.UINT8, [], [128]) + y = helper.make_tensor_value_info('y', TensorProto.INT32, [4, 2]) + + node = onnx.helper.make_node( + 
'MatMulIntegerToFloat', + inputs=['1', '2', '3', '4', '5', '6'], + outputs=['y'], + ) + + return ([node], [m1, m2, s1, s2], [y], [zp1, zp2]) + + +@onnx_test() +def matmulintegertofloat_bad_scale3_test(): + m1 = helper.make_tensor_value_info('1', TensorProto.INT8, [4, 3]) + m2 = helper.make_tensor_value_info('2', TensorProto.UINT8, [3, 2]) + s1 = helper.make_tensor_value_info('3', TensorProto.FLOAT16, [4, 3]) + s2 = helper.make_tensor_value_info('4', TensorProto.FLOAT, [3, 2]) + zp1 = helper.make_tensor('5', TensorProto.INT8, [], [0]) + zp2 = helper.make_tensor('6', TensorProto.UINT8, [], [128]) + y = helper.make_tensor_value_info('y', TensorProto.INT32, [4, 2]) + + node = onnx.helper.make_node( + 'MatMulIntegerToFloat', + inputs=['1', '2', '3', '4', '5', '6'], + outputs=['y'], + ) + + return ([node], [m1, m2, s1, s2], [y], [zp1, zp2]) + + +@onnx_test() +def matmulintegertofloat_bad_scale4_test(): + m1 = helper.make_tensor_value_info('1', TensorProto.INT8, [4, 3]) + m2 = helper.make_tensor_value_info('2', TensorProto.UINT8, [3, 2]) + s1 = helper.make_tensor_value_info('3', TensorProto.INT8, [3]) + s2 = helper.make_tensor_value_info('4', TensorProto.UINT8, [2, 2]) + zp1 = helper.make_tensor('5', TensorProto.INT8, [], [0]) + zp2 = helper.make_tensor('6', TensorProto.UINT8, [], [128]) + y = helper.make_tensor_value_info('y', TensorProto.INT32, [4, 2]) + + node = onnx.helper.make_node( + 'MatMulIntegerToFloat', + inputs=['1', '2', '3', '4', '5', '6'], + outputs=['y'], + ) + + return ([node], [m1, m2, s1, s2], [y], [zp1, zp2]) + + +@onnx_test() +def matmulintegertofloat_bad_scale5_test(): + m1 = helper.make_tensor_value_info('1', TensorProto.INT8, [4, 3]) + m2 = helper.make_tensor_value_info('2', TensorProto.UINT8, [3, 2]) + s1 = helper.make_tensor_value_info('3', TensorProto.INT8, [3]) + s2 = helper.make_tensor_value_info('4', TensorProto.UINT8, [7]) + zp1 = helper.make_tensor('5', TensorProto.INT8, [], [0]) + zp2 = helper.make_tensor('6', TensorProto.UINT8, [], [128]) + y = helper.make_tensor_value_info('y', TensorProto.INT32, [4, 2]) + + node = onnx.helper.make_node( + 'MatMulIntegerToFloat', + inputs=['1', '2', '3', '4', '5', '6'], + outputs=['y'], + ) + + return ([node], [m1, m2, s1, s2], [y], [zp1, zp2]) + + +@onnx_test() +def matmulintegertofloat_bad_bias_test(): + m1 = helper.make_tensor_value_info('1', TensorProto.INT8, [4, 3]) + m2 = helper.make_tensor_value_info('2', TensorProto.UINT8, [3, 2]) + s1 = helper.make_tensor_value_info('3', TensorProto.FLOAT, [4, 3]) + s2 = helper.make_tensor_value_info('4', TensorProto.FLOAT, [3, 2]) + zp1 = helper.make_tensor('5', TensorProto.INT8, [], [0]) + zp2 = helper.make_tensor('6', TensorProto.UINT8, [], [128]) + b1 = helper.make_tensor('7', TensorProto.UINT8, [2], [128, 128]) + y = helper.make_tensor_value_info('y', TensorProto.INT32, [4, 2]) + + node = onnx.helper.make_node( + 'MatMulIntegerToFloat', + inputs=['1', '2', '3', '4', '5', '6', '7'], + outputs=['y'], + ) + + return ([node], [m1, m2, s1, s2], [y], [zp1, zp2, b1]) + + +@onnx_test() +def matmulintegertofloat_bad_bias_test2(): + m1 = helper.make_tensor_value_info('1', TensorProto.INT8, [4, 3]) + m2 = helper.make_tensor_value_info('2', TensorProto.UINT8, [3, 2]) + s1 = helper.make_tensor_value_info('3', TensorProto.FLOAT, [4, 3]) + s2 = helper.make_tensor_value_info('4', TensorProto.FLOAT, [3, 2]) + zp1 = helper.make_tensor('5', TensorProto.INT8, [], [0]) + zp2 = helper.make_tensor('6', TensorProto.UINT8, [], [128]) + b1 = helper.make_tensor('7', TensorProto.FLOAT16, [2], [128, -128]) + y 
= helper.make_tensor_value_info('y', TensorProto.INT32, [4, 2]) + + node = onnx.helper.make_node( + 'MatMulIntegerToFloat', + inputs=['1', '2', '3', '4', '5', '6', '7'], + outputs=['y'], + ) + + return ([node], [m1, m2, s1, s2], [y], [zp1, zp2, b1]) + + +@onnx_test() +def matmulintegertofloat_bad_bias_test3(): + m1 = helper.make_tensor_value_info('1', TensorProto.INT8, [4, 3]) + m2 = helper.make_tensor_value_info('2', TensorProto.UINT8, [3, 2]) + s1 = helper.make_tensor_value_info('3', TensorProto.FLOAT16, [4, 3]) + s2 = helper.make_tensor_value_info('4', TensorProto.FLOAT, [3, 2]) + zp1 = helper.make_tensor('5', TensorProto.INT8, [], [0]) + zp2 = helper.make_tensor('6', TensorProto.UINT8, [], [128]) + b1 = helper.make_tensor('7', TensorProto.FLOAT16, [], [128]) + y = helper.make_tensor_value_info('y', TensorProto.INT32, [4, 2]) + + node = onnx.helper.make_node( + 'MatMulIntegerToFloat', + inputs=['1', '2', '3', '4', '5', '6', '7'], + outputs=['y'], + ) + + return ([node], [m1, m2, s1, s2], [y], [zp1, zp2, b1]) + + +@onnx_test() +def matmulintegertofloat_half_test(): + m1 = helper.make_tensor_value_info('1', TensorProto.INT8, [4, 3]) + m2 = helper.make_tensor_value_info('2', TensorProto.UINT8, [3, 2]) + s1 = helper.make_tensor_value_info('3', TensorProto.FLOAT16, [3]) + s2 = helper.make_tensor_value_info('4', TensorProto.FLOAT16, [2]) + y = helper.make_tensor_value_info('y', TensorProto.INT32, [4, 2]) + + node = onnx.helper.make_node( + 'MatMulIntegerToFloat', + inputs=['1', '2', '3', '4'], + outputs=['y'], + ) + + return ([node], [m1, m2, s1, s2], [y], []) + + +@onnx_test() +def matmulintegertofloat_half_zp_test(): + m1 = helper.make_tensor_value_info('1', TensorProto.INT8, [4, 3]) + m2 = helper.make_tensor_value_info('2', TensorProto.UINT8, [3, 2]) + s1 = helper.make_tensor_value_info('3', TensorProto.FLOAT16, [3]) + s2 = helper.make_tensor_value_info('4', TensorProto.FLOAT16, [2]) + zp1 = helper.make_tensor_value_info('5', TensorProto.INT8, [3]) + zp2 = helper.make_tensor_value_info('6', TensorProto.UINT8, [2]) + y = helper.make_tensor_value_info('y', TensorProto.INT32, [4, 2]) + + node = onnx.helper.make_node( + 'MatMulIntegerToFloat', + inputs=['1', '2', '3', '4', '5', '6'], + outputs=['y'], + ) + + return ([node], [m1, m2, s1, s2, zp1, zp2], [y], []) + + +@onnx_test() +def matmulintegertofloat_half_scalar_zp_test(): + m1 = helper.make_tensor_value_info('1', TensorProto.INT8, [4, 3]) + m2 = helper.make_tensor_value_info('2', TensorProto.UINT8, [3, 2]) + s1 = helper.make_tensor_value_info('3', TensorProto.FLOAT16, [3]) + s2 = helper.make_tensor_value_info('4', TensorProto.FLOAT16, [2]) + zp1 = helper.make_tensor('5', TensorProto.INT8, [], [0]) + zp2 = helper.make_tensor('6', TensorProto.UINT8, [], [128]) + y = helper.make_tensor_value_info('y', TensorProto.INT32, [4, 2]) + + node = onnx.helper.make_node( + 'MatMulIntegerToFloat', + inputs=['1', '2', '3', '4', '5', '6'], + outputs=['y'], + ) + + return ([node], [m1, m2, s1, s2], [y], [zp1, zp2]) + + +@onnx_test() +def matmulintegertofloat_half_zp_bias_test(): + m1 = helper.make_tensor_value_info('1', TensorProto.INT8, [4, 3]) + m2 = helper.make_tensor_value_info('2', TensorProto.UINT8, [3, 2]) + s1 = helper.make_tensor_value_info('3', TensorProto.FLOAT16, [3]) + s2 = helper.make_tensor_value_info('4', TensorProto.FLOAT16, [2]) + zp1 = helper.make_tensor('5', TensorProto.INT8, [], [0]) + zp2 = helper.make_tensor('6', TensorProto.UINT8, [], [128]) + b1 = helper.make_tensor('7', TensorProto.FLOAT16, [2], [128, -128]) + y = 
helper.make_tensor_value_info('y', TensorProto.INT32, [4, 2]) + + node = onnx.helper.make_node( + 'MatMulIntegerToFloat', + inputs=['1', '2', '3', '4', '5', '6', '7'], + outputs=['y'], + ) + + return ([node], [m1, m2, s1, s2], [y], [zp1, zp2, b1]) + + +@onnx_test() +def matmulintegertofloat_zp_bias_3d_test(): + m1 = helper.make_tensor_value_info('1', TensorProto.INT8, [4, 3, 2]) + m2 = helper.make_tensor_value_info('2', TensorProto.UINT8, [4, 2, 3]) + s1 = helper.make_tensor_value_info('3', TensorProto.FLOAT, [2]) + s2 = helper.make_tensor_value_info('4', TensorProto.FLOAT, [3]) + zp1 = helper.make_tensor_value_info('5', TensorProto.INT8, [2]) + zp2 = helper.make_tensor_value_info('6', TensorProto.UINT8, [3]) + b1 = helper.make_tensor_value_info('7', TensorProto.FLOAT, [3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [4, 3, 3]) + + node = onnx.helper.make_node( + 'MatMulIntegerToFloat', + inputs=['1', '2', '3', '4', '5', '6', '7'], + outputs=['y'], + ) + + return ([node], [m1, m2, s1, s2, zp1, zp2, b1], [y]) + + @onnx_test() def max_test(): a = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3]) diff --git a/test/onnx/matmulintegertofloat_bad_bias_test.onnx b/test/onnx/matmulintegertofloat_bad_bias_test.onnx new file mode 100644 index 00000000000..f733a4642b7 Binary files /dev/null and b/test/onnx/matmulintegertofloat_bad_bias_test.onnx differ diff --git a/test/onnx/matmulintegertofloat_bad_bias_test2.onnx b/test/onnx/matmulintegertofloat_bad_bias_test2.onnx new file mode 100644 index 00000000000..cd0379b482f Binary files /dev/null and b/test/onnx/matmulintegertofloat_bad_bias_test2.onnx differ diff --git a/test/onnx/matmulintegertofloat_bad_bias_test3.onnx b/test/onnx/matmulintegertofloat_bad_bias_test3.onnx new file mode 100644 index 00000000000..67a09539612 Binary files /dev/null and b/test/onnx/matmulintegertofloat_bad_bias_test3.onnx differ diff --git a/test/onnx/matmulintegertofloat_bad_scale2_test.onnx b/test/onnx/matmulintegertofloat_bad_scale2_test.onnx new file mode 100644 index 00000000000..d900df4348b Binary files /dev/null and b/test/onnx/matmulintegertofloat_bad_scale2_test.onnx differ diff --git a/test/onnx/matmulintegertofloat_bad_scale3_test.onnx b/test/onnx/matmulintegertofloat_bad_scale3_test.onnx new file mode 100644 index 00000000000..10264b6f740 Binary files /dev/null and b/test/onnx/matmulintegertofloat_bad_scale3_test.onnx differ diff --git a/test/onnx/matmulintegertofloat_bad_scale4_test.onnx b/test/onnx/matmulintegertofloat_bad_scale4_test.onnx new file mode 100644 index 00000000000..9c1867cc42f Binary files /dev/null and b/test/onnx/matmulintegertofloat_bad_scale4_test.onnx differ diff --git a/test/onnx/matmulintegertofloat_bad_scale5_test.onnx b/test/onnx/matmulintegertofloat_bad_scale5_test.onnx new file mode 100644 index 00000000000..718a4404b59 Binary files /dev/null and b/test/onnx/matmulintegertofloat_bad_scale5_test.onnx differ diff --git a/test/onnx/matmulintegertofloat_bad_scale_test.onnx b/test/onnx/matmulintegertofloat_bad_scale_test.onnx new file mode 100644 index 00000000000..70599385980 Binary files /dev/null and b/test/onnx/matmulintegertofloat_bad_scale_test.onnx differ diff --git a/test/onnx/matmulintegertofloat_half_test.onnx b/test/onnx/matmulintegertofloat_half_test.onnx new file mode 100644 index 00000000000..aba774ea814 --- /dev/null +++ b/test/onnx/matmulintegertofloat_half_test.onnx @@ -0,0 +1,29 @@ + +matmulintegertofloat_half_test:¨ +% +1 +2 +3 +4y"MatMulIntegerToFloatmatmulintegertofloat_half_testZ +1 +  + +Z +2 + 
 + +Z +3 + + + +Z +4 + + + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/matmulintegertofloat_half_zp_bias_test.onnx b/test/onnx/matmulintegertofloat_half_zp_bias_test.onnx new file mode 100644 index 00000000000..668c596f904 Binary files /dev/null and b/test/onnx/matmulintegertofloat_half_zp_bias_test.onnx differ diff --git a/test/onnx/matmulintegertofloat_half_zp_test.onnx b/test/onnx/matmulintegertofloat_half_zp_test.onnx new file mode 100644 index 00000000000..c2d7cf0b8a1 --- /dev/null +++ b/test/onnx/matmulintegertofloat_half_zp_test.onnx @@ -0,0 +1,39 @@ + +!matmulintegertofloat_half_zp_test:Ó ++ +1 +2 +3 +4 +5 +6y"MatMulIntegerToFloat!matmulintegertofloat_half_zp_testZ +1 +  + +Z +2 +  + +Z +3 + + + +Z +4 + + + +Z +5 + + +Z +6 + + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/matmulintegertofloat_scalar_scale_test.onnx b/test/onnx/matmulintegertofloat_scalar_scale_test.onnx new file mode 100644 index 00000000000..3aec4245278 Binary files /dev/null and b/test/onnx/matmulintegertofloat_scalar_scale_test.onnx differ diff --git a/test/onnx/matmulintegertofloat_scalar_zp_test.onnx b/test/onnx/matmulintegertofloat_scalar_zp_test.onnx new file mode 100644 index 00000000000..94895811211 --- /dev/null +++ b/test/onnx/matmulintegertofloat_scalar_zp_test.onnx @@ -0,0 +1,33 @@ + +#matmulintegertofloat_scalar_zp_test:Ï ++ +1 +2 +3 +4 +5 +6y"MatMulIntegerToFloat#matmulintegertofloat_scalar_zp_test* *B6Z +1 +  + +Z +2 +  + +Z +3 + + +Z +4 + + +Z +5 + + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/matmulintegertofloat_test.onnx b/test/onnx/matmulintegertofloat_test.onnx new file mode 100644 index 00000000000..64bcb9f9160 --- /dev/null +++ b/test/onnx/matmulintegertofloat_test.onnx @@ -0,0 +1,27 @@ + +matmulintegertofloat_test:£ +% +1 +2 +3 +4y"MatMulIntegerToFloatmatmulintegertofloat_testZ +1 +  + +Z +2 +  + +Z +3 + + +Z +4 + + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/matmulintegertofloat_zp_bias_3d_test.onnx b/test/onnx/matmulintegertofloat_zp_bias_3d_test.onnx new file mode 100644 index 00000000000..9b86cf3cf57 --- /dev/null +++ b/test/onnx/matmulintegertofloat_zp_bias_3d_test.onnx @@ -0,0 +1,45 @@ + +$matmulintegertofloat_zp_bias_3d_test:ö +. +1 +2 +3 +4 +5 +6 +7y"MatMulIntegerToFloat$matmulintegertofloat_zp_bias_3d_testZ +1 + + + +Z +2 + + + +Z +3 + + +Z +4 + + +Z +5 + + +Z +6 + + +Z +7 + + +b +y + + + +B \ No newline at end of file diff --git a/test/onnx/matmulintegertofloat_zp_bias_test.onnx b/test/onnx/matmulintegertofloat_zp_bias_test.onnx new file mode 100644 index 00000000000..cb73467cc9d --- /dev/null +++ b/test/onnx/matmulintegertofloat_zp_bias_test.onnx @@ -0,0 +1,42 @@ + +!matmulintegertofloat_zp_bias_test:ç +. 
+1 +2 +3 +4 +5 +6 +7y"MatMulIntegerToFloat!matmulintegertofloat_zp_bias_testZ +1 +  + +Z +2 +  + +Z +3 + + +Z +4 + + +Z +5 + + +Z +6 + + +Z +7 + + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/matmulintegertofloat_zp_test.onnx b/test/onnx/matmulintegertofloat_zp_test.onnx new file mode 100644 index 00000000000..62bf77fe49b --- /dev/null +++ b/test/onnx/matmulintegertofloat_zp_test.onnx @@ -0,0 +1,37 @@ + +matmulintegertofloat_zp_test:Î ++ +1 +2 +3 +4 +5 +6y"MatMulIntegerToFloatmatmulintegertofloat_zp_testZ +1 +  + +Z +2 +  + +Z +3 + + +Z +4 + + +Z +5 + + +Z +6 + + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/parse/matmulintegerToFloat_bad_bias_test.cpp b/test/onnx/parse/matmulintegerToFloat_bad_bias_test.cpp new file mode 100644 index 00000000000..9527bf8fd03 --- /dev/null +++ b/test/onnx/parse/matmulintegerToFloat_bad_bias_test.cpp @@ -0,0 +1,30 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include <onnx_test.hpp> +
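+// Per gen_onnx.py, this model's bias ('7') is uint8 rather than a float type, so the parser is expected to throw.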
+TEST_CASE(matmulintegertofloat_bad_bias_test) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("matmulintegertofloat_bad_bias_test.onnx"); })); +} diff --git a/test/onnx/parse/matmulintegerToFloat_bad_bias_test2.cpp b/test/onnx/parse/matmulintegerToFloat_bad_bias_test2.cpp new file mode 100644 index 00000000000..c7fe8161666 --- /dev/null +++ b/test/onnx/parse/matmulintegerToFloat_bad_bias_test2.cpp @@ -0,0 +1,30 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include <onnx_test.hpp> + +TEST_CASE(matmulintegertofloat_bad_bias_test2) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("matmulintegertofloat_bad_bias_test2.onnx"); })); +} diff --git a/test/onnx/parse/matmulintegerToFloat_bad_bias_test3.cpp b/test/onnx/parse/matmulintegerToFloat_bad_bias_test3.cpp new file mode 100644 index 00000000000..becd4dba1a8 --- /dev/null +++ b/test/onnx/parse/matmulintegerToFloat_bad_bias_test3.cpp @@ -0,0 +1,30 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include <onnx_test.hpp> + +TEST_CASE(matmulintegertofloat_bad_bias_test3) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("matmulintegertofloat_bad_bias_test3.onnx"); })); +} diff --git a/test/onnx/parse/matmulintegerToFloat_bad_scale2_test.cpp b/test/onnx/parse/matmulintegerToFloat_bad_scale2_test.cpp new file mode 100644 index 00000000000..f8505c7af7e --- /dev/null +++ b/test/onnx/parse/matmulintegerToFloat_bad_scale2_test.cpp @@ -0,0 +1,31 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include <onnx_test.hpp> +
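+// Per gen_onnx.py, scale '4' is int8 here rather than a floating point type, so the parser is expected to throw.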
+TEST_CASE(matmulintegertofloat_bad_scale2_test) +{ + EXPECT( + test::throws([&] { migraphx::parse_onnx("matmulintegertofloat_bad_scale2_test.onnx"); })); +} diff --git a/test/onnx/parse/matmulintegerToFloat_bad_scale3_test.cpp b/test/onnx/parse/matmulintegerToFloat_bad_scale3_test.cpp new file mode 100644 index 00000000000..f286e875322 --- /dev/null +++ b/test/onnx/parse/matmulintegerToFloat_bad_scale3_test.cpp @@ -0,0 +1,31 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include <onnx_test.hpp> + +TEST_CASE(matmulintegertofloat_bad_scale3_test) +{ + EXPECT( + test::throws([&] { migraphx::parse_onnx("matmulintegertofloat_bad_scale3_test.onnx"); })); +} diff --git a/test/onnx/parse/matmulintegerToFloat_bad_scale4_test.cpp b/test/onnx/parse/matmulintegerToFloat_bad_scale4_test.cpp new file mode 100644 index 00000000000..bc8370587a2 --- /dev/null +++ b/test/onnx/parse/matmulintegerToFloat_bad_scale4_test.cpp @@ -0,0 +1,31 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include <onnx_test.hpp> + +TEST_CASE(matmulintegertofloat_bad_scale4_test) +{ + EXPECT( + test::throws([&] { migraphx::parse_onnx("matmulintegertofloat_bad_scale4_test.onnx"); })); +}
 diff --git a/test/onnx/parse/matmulintegerToFloat_bad_scale5_test.cpp b/test/onnx/parse/matmulintegerToFloat_bad_scale5_test.cpp new file mode 100644 index 00000000000..127fc771197 --- /dev/null +++ b/test/onnx/parse/matmulintegerToFloat_bad_scale5_test.cpp @@ -0,0 +1,31 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include <onnx_test.hpp> + +TEST_CASE(matmulintegertofloat_bad_scale5_test) +{ + EXPECT( + test::throws([&] { migraphx::parse_onnx("matmulintegertofloat_bad_scale5_test.onnx"); })); +} diff --git a/test/onnx/parse/matmulintegerToFloat_bad_scale_test.cpp b/test/onnx/parse/matmulintegerToFloat_bad_scale_test.cpp new file mode 100644 index 00000000000..46663319192 --- /dev/null +++ b/test/onnx/parse/matmulintegerToFloat_bad_scale_test.cpp @@ -0,0 +1,30 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include <onnx_test.hpp> +
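+// Per gen_onnx.py, scale '3' is int8 rather than a floating point type, so the parser is expected to throw.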
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include + +TEST_CASE(matmulintegertofloat_bad_scale_test) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("matmulintegertofloat_bad_scale_test.onnx"); })); +} diff --git a/test/onnx/parse/matmulinteger_dual_zp_test.cpp b/test/onnx/parse/matmulinteger_dual_zp_test.cpp index d7b569a37c2..8f434a27182 100644 --- a/test/onnx/parse/matmulinteger_dual_zp_test.cpp +++ b/test/onnx/parse/matmulinteger_dual_zp_test.cpp @@ -79,5 +79,5 @@ TEST_CASE(matmulinteger_dual_zp_test) auto prog = optimize_onnx("matmulinteger_int8_uint8_dual_zp_test.onnx"); - EXPECT(p == prog); + EXPECT(p.sort() == prog.sort()); } diff --git a/test/onnx/parse/matmulinteger_one_zp_test.cpp b/test/onnx/parse/matmulinteger_one_zp_test.cpp index 4e34182fd4a..f35c28fe0f7 100644 --- a/test/onnx/parse/matmulinteger_one_zp_test.cpp +++ b/test/onnx/parse/matmulinteger_one_zp_test.cpp @@ -60,5 +60,5 @@ TEST_CASE(matmulinteger_one_zp_test) auto prog = optimize_onnx("matmulinteger_int8_uint8_one_zp_test.onnx"); - EXPECT(p == prog); + EXPECT(p.sort() == prog.sort()); } diff --git a/test/onnx/parse/matmulintegertofloat_half_test.cpp b/test/onnx/parse/matmulintegertofloat_half_test.cpp new file mode 100644 index 00000000000..4cce1f18cfa --- /dev/null +++ b/test/onnx/parse/matmulintegertofloat_half_test.cpp @@ -0,0 +1,54 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
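// Aside on the EXPECT(p.sort() == prog.sort()) hunks above: the parser is free
// to emit independent instructions (for example the two dequantize chains) in a
// different order than the hand-built reference program, so comparing raw
// instruction sequences is brittle. sort() presumably canonicalizes instruction
// order, so the tests assert dataflow equality rather than emission order.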
+ */ + +#include + +TEST_CASE(matmulintegertofloat_half_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto x0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::int8_type, {4, 3}}); + auto x1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::uint8_type, {3, 2}}); + auto scale_x0 = mm->add_parameter("3", migraphx::shape{migraphx::shape::half_type, {3}}); + auto scale_x1 = mm->add_parameter("4", migraphx::shape{migraphx::shape::half_type, {2}}); + + auto sq_scale_x0 = + mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), scale_x0); + auto sq_scale_x1 = + mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), scale_x1); + + auto bc_scale_x0 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x0->get_shape().lens()}}), sq_scale_x0); + auto r0 = mm->add_instruction(migraphx::make_op("dequantizelinear"), x0, bc_scale_x0); + + auto bc_scale_x1 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x1->get_shape().lens()}}), sq_scale_x1); + + auto r1 = mm->add_instruction(migraphx::make_op("dequantizelinear"), x1, bc_scale_x1); + mm->add_instruction(migraphx::make_op("dot"), r0, r1); + + auto prog = optimize_onnx("matmulintegertofloat_half_test.onnx"); + + EXPECT(p == prog); +} diff --git a/test/onnx/parse/matmulintegertofloat_half_zp_test.cpp b/test/onnx/parse/matmulintegertofloat_half_zp_test.cpp new file mode 100644 index 00000000000..679ae1b5806 --- /dev/null +++ b/test/onnx/parse/matmulintegertofloat_half_zp_test.cpp @@ -0,0 +1,64 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#include + +TEST_CASE(matmulintegertofloat_half_zp_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto x0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::int8_type, {4, 3}}); + auto x1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::uint8_type, {3, 2}}); + auto scale_x0 = mm->add_parameter("3", migraphx::shape{migraphx::shape::half_type, {3}}); + auto scale_x1 = mm->add_parameter("4", migraphx::shape{migraphx::shape::half_type, {2}}); + auto zp_x0 = mm->add_parameter("5", migraphx::shape{migraphx::shape::int8_type, {3}}); + auto zp_x1 = mm->add_parameter("6", migraphx::shape{migraphx::shape::uint8_type, {2}}); + + auto sq_scale_x0 = + mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), scale_x0); + auto sq_scale_x1 = + mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), scale_x1); + auto sq_zp_x0 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), zp_x0); + auto sq_zp_x1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), zp_x1); + + auto bc_scale_x0 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x0->get_shape().lens()}}), sq_scale_x0); + auto bc_zp_x0 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x0->get_shape().lens()}}), sq_zp_x0); + + auto r0 = mm->add_instruction(migraphx::make_op("dequantizelinear"), x0, bc_scale_x0, bc_zp_x0); + + auto bc_scale_x1 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x1->get_shape().lens()}}), sq_scale_x1); + + auto bc_zp_x1 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x1->get_shape().lens()}}), sq_zp_x1); + + auto r1 = mm->add_instruction(migraphx::make_op("dequantizelinear"), x1, bc_scale_x1, bc_zp_x1); + mm->add_instruction(migraphx::make_op("dot"), r0, r1); + + auto prog = optimize_onnx("matmulintegertofloat_half_zp_test.onnx"); + + EXPECT(p == prog); +} diff --git a/test/onnx/parse/matmulintegertofloat_scalar_scale_test.cpp b/test/onnx/parse/matmulintegertofloat_scalar_scale_test.cpp new file mode 100644 index 00000000000..75ce8ce0d76 --- /dev/null +++ b/test/onnx/parse/matmulintegertofloat_scalar_scale_test.cpp @@ -0,0 +1,58 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#include + +TEST_CASE(matmulintegertofloat_scalar_scale_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto scale_x1 = mm->add_literal( + migraphx::literal(migraphx::shape{migraphx::shape::float_type, {1}, {0}}, {10})); + auto x0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::int8_type, {4, 3}}); + auto x1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::int8_type, {3, 2}}); + auto scale_x0 = mm->add_parameter("3", migraphx::shape{migraphx::shape::float_type, {3}}); + + auto sq_scale_x0 = + mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), scale_x0); + + auto sq_scale_x1 = + mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), scale_x1); + + sq_scale_x1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), sq_scale_x1); + + auto bc_scale_x0 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x0->get_shape().lens()}}), sq_scale_x0); + auto r0 = mm->add_instruction(migraphx::make_op("dequantizelinear"), x0, bc_scale_x0); + + auto bc_scale_x1 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x1->get_shape().lens()}}), sq_scale_x1); + + auto r1 = mm->add_instruction(migraphx::make_op("dequantizelinear"), x1, bc_scale_x1); + mm->add_instruction(migraphx::make_op("dot"), r0, r1); + + auto prog = optimize_onnx("matmulintegertofloat_scalar_scale_test.onnx"); + + EXPECT(p == prog); +} diff --git a/test/onnx/parse/matmulintegertofloat_scalar_zp_test.cpp b/test/onnx/parse/matmulintegertofloat_scalar_zp_test.cpp new file mode 100644 index 00000000000..cb2979df010 --- /dev/null +++ b/test/onnx/parse/matmulintegertofloat_scalar_zp_test.cpp @@ -0,0 +1,66 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#include + +TEST_CASE(matmulintegertofloat_scalar_zp_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto zp_x1 = mm->add_literal( + migraphx::literal(migraphx::shape{migraphx::shape::int8_type, {1}, {0}}, {129})); + auto x0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::int8_type, {4, 3}}); + auto x1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::int8_type, {3, 2}}); + auto scale_x0 = mm->add_parameter("3", migraphx::shape{migraphx::shape::float_type, {3}}); + auto scale_x1 = mm->add_parameter("4", migraphx::shape{migraphx::shape::float_type, {2}}); + auto zp_x0 = mm->add_parameter("5", migraphx::shape{migraphx::shape::int8_type, {3}}); + + auto sq_scale_x0 = + mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), scale_x0); + auto sq_scale_x1 = + mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), scale_x1); + auto sq_zp_x0 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), zp_x0); + auto sq_zp_x1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), zp_x1); + sq_zp_x1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), sq_zp_x1); + + auto bc_scale_x0 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x0->get_shape().lens()}}), sq_scale_x0); + auto bc_zp_x0 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x0->get_shape().lens()}}), sq_zp_x0); + + auto r0 = mm->add_instruction(migraphx::make_op("dequantizelinear"), x0, bc_scale_x0, bc_zp_x0); + + auto bc_scale_x1 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x1->get_shape().lens()}}), sq_scale_x1); + + auto bc_zp_x1 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x1->get_shape().lens()}}), sq_zp_x1); + + auto r1 = mm->add_instruction(migraphx::make_op("dequantizelinear"), x1, bc_scale_x1, bc_zp_x1); + mm->add_instruction(migraphx::make_op("dot"), r0, r1); + + auto prog = optimize_onnx("matmulintegertofloat_scalar_zp_test.onnx"); + + EXPECT(p == prog); +} diff --git a/test/onnx/parse/matmulintegertofloat_test.cpp b/test/onnx/parse/matmulintegertofloat_test.cpp new file mode 100644 index 00000000000..fcdbcca9710 --- /dev/null +++ b/test/onnx/parse/matmulintegertofloat_test.cpp @@ -0,0 +1,53 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#include + +TEST_CASE(matmulintegertofloat_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto x0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::int8_type, {4, 3}}); + auto x1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::int8_type, {3, 2}}); + auto scale_x0 = mm->add_parameter("3", migraphx::shape{migraphx::shape::float_type, {3}}); + auto scale_x1 = mm->add_parameter("4", migraphx::shape{migraphx::shape::float_type, {2}}); + + auto sq_scale_x0 = + mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), scale_x0); + auto sq_scale_x1 = + mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), scale_x1); + auto bc_scale_x0 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x0->get_shape().lens()}}), sq_scale_x0); + auto r0 = mm->add_instruction(migraphx::make_op("dequantizelinear"), x0, bc_scale_x0); + + auto bc_scale_x1 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x1->get_shape().lens()}}), sq_scale_x1); + + auto r1 = mm->add_instruction(migraphx::make_op("dequantizelinear"), x1, bc_scale_x1); + mm->add_instruction(migraphx::make_op("dot"), r0, r1); + + auto prog = optimize_onnx("matmulintegertofloat_test.onnx"); + + EXPECT(p == prog); +} diff --git a/test/onnx/parse/matmulintegertofloat_zp_bias_test.cpp b/test/onnx/parse/matmulintegertofloat_zp_bias_test.cpp new file mode 100644 index 00000000000..4408870aac5 --- /dev/null +++ b/test/onnx/parse/matmulintegertofloat_zp_bias_test.cpp @@ -0,0 +1,71 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
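// Aside: the expected programs in the parse tests above all follow one lowering
// pattern. Assuming per-column 1-D scales and zero points (as these tests use),
// a length-N vector is first unsqueezed to {1, N}, then multibroadcast to the
// operand's shape so dequantizelinear can apply it elementwise:
//
//   scale {2}  --unsqueeze(axes={0})-->  {1, 2}  --multibroadcast-->  {3, 2}
//   dequant(x) = (x - zero_point) * scale         (elementwise)
//   result     = dot(dequant(A), dequant(B))      [then sub(result, bias) if a bias is given]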
+ */ + +#include + +TEST_CASE(matmulintegertofloat_zp_bias_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto x0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::int8_type, {4, 3}}); + auto x1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::uint8_type, {3, 2}}); + auto scale_x0 = mm->add_parameter("3", migraphx::shape{migraphx::shape::float_type, {3}}); + auto scale_x1 = mm->add_parameter("4", migraphx::shape{migraphx::shape::float_type, {2}}); + auto zp_x0 = mm->add_parameter("5", migraphx::shape{migraphx::shape::int8_type, {3}}); + auto zp_x1 = mm->add_parameter("6", migraphx::shape{migraphx::shape::uint8_type, {2}}); + auto bias = mm->add_parameter("7", migraphx::shape{migraphx::shape::float_type, {2}}); + + auto sq_scale_x0 = + mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), scale_x0); + auto sq_scale_x1 = + mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), scale_x1); + + auto sq_zp_x0 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), zp_x0); + auto sq_zp_x1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), zp_x1); + + auto bc_scale_x0 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x0->get_shape().lens()}}), sq_scale_x0); + auto bc_zp_x0 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x0->get_shape().lens()}}), sq_zp_x0); + + auto r0 = mm->add_instruction(migraphx::make_op("dequantizelinear"), x0, bc_scale_x0, bc_zp_x0); + + auto bc_scale_x1 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x1->get_shape().lens()}}), sq_scale_x1); + + auto bc_zp_x1 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x1->get_shape().lens()}}), sq_zp_x1); + + auto r1 = mm->add_instruction(migraphx::make_op("dequantizelinear"), x1, bc_scale_x1, bc_zp_x1); + auto dot = mm->add_instruction(migraphx::make_op("dot"), r0, r1); + + auto mb_bias = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {4, 2}}}), bias); + + mm->add_instruction(migraphx::make_op("sub"), dot, mb_bias); + + auto prog = optimize_onnx("matmulintegertofloat_zp_bias_test.onnx"); + + EXPECT(p == prog); +} diff --git a/test/onnx/parse/matmulintegertofloat_zp_test.cpp b/test/onnx/parse/matmulintegertofloat_zp_test.cpp new file mode 100644 index 00000000000..03cc8daec79 --- /dev/null +++ b/test/onnx/parse/matmulintegertofloat_zp_test.cpp @@ -0,0 +1,64 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include + +TEST_CASE(matmulintegertofloat_zp_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto x0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::int8_type, {4, 3}}); + auto x1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::int8_type, {3, 2}}); + auto scale_x0 = mm->add_parameter("3", migraphx::shape{migraphx::shape::float_type, {3}}); + auto scale_x1 = mm->add_parameter("4", migraphx::shape{migraphx::shape::float_type, {2}}); + auto zp_x0 = mm->add_parameter("5", migraphx::shape{migraphx::shape::int8_type, {3}}); + auto zp_x1 = mm->add_parameter("6", migraphx::shape{migraphx::shape::int8_type, {2}}); + + auto sq_scale_x0 = + mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), scale_x0); + auto sq_scale_x1 = + mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), scale_x1); + auto sq_zp_x0 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), zp_x0); + auto sq_zp_x1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), zp_x1); + + auto bc_scale_x0 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x0->get_shape().lens()}}), sq_scale_x0); + auto bc_zp_x0 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x0->get_shape().lens()}}), sq_zp_x0); + + auto r0 = mm->add_instruction(migraphx::make_op("dequantizelinear"), x0, bc_scale_x0, bc_zp_x0); + + auto bc_scale_x1 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x1->get_shape().lens()}}), sq_scale_x1); + + auto bc_zp_x1 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x1->get_shape().lens()}}), sq_zp_x1); + + auto r1 = mm->add_instruction(migraphx::make_op("dequantizelinear"), x1, bc_scale_x1, bc_zp_x1); + mm->add_instruction(migraphx::make_op("dot"), r0, r1); + + auto prog = optimize_onnx("matmulintegertofloat_zp_test.onnx"); + + EXPECT(p == prog); +} diff --git a/test/onnx/verify/matmulintegertofloat_int8_uint8_scales_zp_bias.cpp b/test/onnx/verify/matmulintegertofloat_int8_uint8_scales_zp_bias.cpp new file mode 100644 index 00000000000..417b09d5864 --- /dev/null +++ b/test/onnx/verify/matmulintegertofloat_int8_uint8_scales_zp_bias.cpp @@ -0,0 +1,68 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include + +TEST_CASE(matmulintegertofloat_int8_uint8_scales_zp_bias_test) +{ + migraphx::program p = read_onnx("matmulintegertofloat_zp_bias_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape s0{migraphx::shape::int8_type, {4, 3}}; + std::vector data0 = {-1, 5, -9, -2, 6, 10, -3, 7, -11, -4, 8, 0}; + migraphx::shape s1{migraphx::shape::uint8_type, {3, 2}}; + std::vector data1 = {128, 129, 126, 131, 124, 133}; + + migraphx::shape scale1_s{migraphx::shape::float_type, {3}}; + std::vector scale1 = {1.0f, 1.0f, 1.0f}; + + migraphx::shape scale2_s{migraphx::shape::float_type, {2}}; + std::vector scale2 = {2.0f, 2.0f}; + + migraphx::shape zp1_s{migraphx::shape::int8_type, {3}}; + std::vector zp1 = {1, 2, 3}; + + migraphx::shape zp2_s{migraphx::shape::uint8_type, {2}}; + std::vector zp2 = {3, 5}; + + migraphx::shape bias_s{migraphx::shape::float_type, {2}}; + std::vector bias = {-10.0f, -1.0f}; + + migraphx::parameter_map pp; + pp["1"] = migraphx::argument(s0, data0.data()); + pp["2"] = migraphx::argument(s1, data1.data()); + pp["3"] = migraphx::argument(scale1_s, scale1.data()); + pp["4"] = migraphx::argument(scale2_s, scale2.data()); + pp["5"] = migraphx::argument(zp1_s, zp1.data()); + pp["6"] = migraphx::argument(zp2_s, zp2.data()); + pp["7"] = migraphx::argument(bias_s, bias.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + std::vector gold = {-2656, -2811, 1938, 2057, -3148, -3315, -490, -495}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} diff --git a/test/onnx/verify/matmulintegertofloat_int8_uint8_scales_zp_bias_3d.cpp b/test/onnx/verify/matmulintegertofloat_int8_uint8_scales_zp_bias_3d.cpp new file mode 100644 index 00000000000..d9254180157 --- /dev/null +++ b/test/onnx/verify/matmulintegertofloat_int8_uint8_scales_zp_bias_3d.cpp @@ -0,0 +1,72 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
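// Aside: a minimal standalone sketch (plain C++, not MIGraphX code; values are
// copied from the verify test above) showing how gold[0] == -2656 falls out of
// the dequantize-dot-subtract lowering. It assumes A's scale/zero point
// broadcast along the reduction axis and B's along the output columns, as the
// expected parse programs construct.
#include <cassert>

int main()
{
    int a[3]    = {-1, 5, -9};     // first row of A (int8 input "1")
    int zp_a[3] = {1, 2, 3};       // A zero points, one per reduction index
    int b[3]    = {128, 126, 124}; // first column of B (uint8 input "2")
    int zp_b    = 3;               // B zero point for column 0
    float sa = 1.0f, sb = 2.0f;    // scales for this row/column
    float acc = 0.0f;
    for(int k = 0; k < 3; ++k)     // dot product of the dequantized operands
        acc += float(a[k] - zp_a[k]) * sa * float(b[k] - zp_b) * sb;
    acc -= -10.0f;                 // the bias input is lowered as a subtraction
    assert(acc == -2656.0f);       // matches gold[0] in the test above
    return 0;
}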
+ */ + +#include <migraphx/register_target.hpp> +#include <migraphx/verify.hpp> +#include <onnx_test.hpp> + +TEST_CASE(matmulintegertofloat_int8_uint8_scales_zp_bias_3d_test) +{ + migraphx::program p = read_onnx("matmulintegertofloat_zp_bias_3d_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape s0{migraphx::shape::int8_type, {4, 3, 2}}; + std::vector<int8_t> data0 = {-1, 5, -9, -2, 6, 10, -3, 7, -11, -4, 8, 0, + -1, 5, -9, -2, 6, 10, -3, 7, -11, -4, 8, 0}; + migraphx::shape s1{migraphx::shape::uint8_type, {4, 2, 3}}; + std::vector<uint8_t> data1(s1.elements(), 0); + std::iota(data1.begin(), data1.end(), 0); + + migraphx::shape scale1_s{migraphx::shape::float_type, {2}}; + std::vector<float> scale1 = {1.0f, 1.0f}; + + migraphx::shape scale2_s{migraphx::shape::float_type, {3}}; + std::vector<float> scale2 = {2.0f, 2.0f, 2.0f}; + + migraphx::shape zp1_s{migraphx::shape::int8_type, {2}}; + std::vector<int8_t> zp1 = {1, 3}; + + migraphx::shape zp2_s{migraphx::shape::uint8_type, {3}}; + std::vector<uint8_t> zp2 = {3, 5, 1}; + + migraphx::shape bias_s{migraphx::shape::float_type, {3}}; + std::vector<float> bias = {-10.0f, -1.0f, 0.0f}; + + migraphx::parameter_map pp; + pp["1"] = migraphx::argument(s0, data0.data()); + pp["2"] = migraphx::argument(s1, data1.data()); + pp["3"] = migraphx::argument(scale1_s, scale1.data()); + pp["4"] = migraphx::argument(scale2_s, scale2.data()); + pp["5"] = migraphx::argument(zp1_s, zp1.data()); + pp["6"] = migraphx::argument(zp2_s, zp2.data()); + pp["7"] = migraphx::argument(bias_s, bias.data()); + + auto result = p.eval(pp).back(); + std::vector<float> result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + std::vector<float> gold = {22, 13, 12, 70, 91, -60, -20, -53, 66, 34, 25, 24, + -146, -117, -308, 16, -1, 38, 22, 13, 12, -290, -269, -420, + 268, 235, 354, 34, 25, 24, -602, -573, -764, 112, 95, 134}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} diff --git a/test/onnx/verify/negativelogliklihood_kd_dim_weighted.cpp b/test/onnx/verify/negativelogliklihood_kd_dim_weighted.cpp index 06865e637b2..69de5d2c15f 100644 --- a/test/onnx/verify/negativelogliklihood_kd_dim_weighted.cpp +++ b/test/onnx/verify/negativelogliklihood_kd_dim_weighted.cpp @@ -170,7 +170,7 @@ TEST_CASE(negativeloglikelihoodloss_kd_mean_reduction_weighted_test) pp["2"] = migraphx::argument(weight_shape, weight_data.data()); auto result = p.eval(pp).back(); - std::vector<float> result_vector; + std::vector<half> result_vector; result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); std::vector<half> gold = {half{-35.266666666666666}}; EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); @@ -200,7 +200,7 @@ TEST_CASE(negativeloglikelihoodloss_kd_mean_reduction_weighted_test2) migraphx::shape label_shape{migraphx::shape::int32_type, {2, 2}}; std::vector<int32_t> label_data = {2, 1, 0, 2}; migraphx::shape weight_shape{migraphx::shape::half_type, {3}}; - std::vector<float> weight_data = {half(0.2), half(0.3), half(0.1)}; + std::vector<half> weight_data = {half(0.2), half(0.3), half(0.1)}; migraphx::parameter_map pp; pp["0"] = migraphx::argument(score_shape, score_data.data()); @@ -208,7 +208,7 @@ pp["2"] = migraphx::argument(weight_shape, weight_data.data()); auto result = p.eval(pp).back(); - std::vector<float> result_vector; + std::vector<half> result_vector; result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); std::vector<half> gold = {half{-1.5714285714285714}}; EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); diff --git
a/test/onnx/verify/softmaxcrossentropyloss_kd_dim_weighted.cpp b/test/onnx/verify/softmaxcrossentropyloss_kd_dim_weighted.cpp index 14b5a0da963..34fb82c9070 100644 --- a/test/onnx/verify/softmaxcrossentropyloss_kd_dim_weighted.cpp +++ b/test/onnx/verify/softmaxcrossentropyloss_kd_dim_weighted.cpp @@ -180,7 +180,7 @@ TEST_CASE(softmaxcrossentropyloss_kd_mean_reduction_weighted_test) pp["2"] = migraphx::argument(weight_shape, weight_data.data()); auto result = p.eval(pp).back(); - std::vector<float> result_vector; + std::vector<half> result_vector; result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); std::vector<half> gold = {half{1.38629436}}; EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); @@ -207,7 +207,7 @@ TEST_CASE(softmaxcrossentropyloss_kd_mean_reduction_uneven_weighted_test) pp["2"] = migraphx::argument(weight_shape, weight_data.data()); auto result = p.eval(pp).back(); - std::vector<float> result_vector; + std::vector<half> result_vector; result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); std::vector<half> gold = {half{1.38629436}}; diff --git a/test/op_shape_test.cpp b/test/op_shape_test.cpp index 8d08455d814..66d54c8a460 100644 --- a/test/op_shape_test.cpp +++ b/test/op_shape_test.cpp @@ -201,6 +201,21 @@ TEST_CASE(binary_dyn_static_error) throws_shape(migraphx::make_op("add"), a_shape, b_shape); } +TEST_CASE(bit_cast_typesize_mismatch) +{ + migraphx::shape a_shape{migraphx::shape::int8_type, {1, 4, 4}}; + throws_shape(migraphx::make_op("bit_cast", {{"target_type", migraphx::shape::int32_type}}), + a_shape); +} + +TEST_CASE(bit_cast_dyn) +{ + migraphx::shape a_shape{migraphx::shape::int8_type, {{1, 1}, {4, 8}, {4, 8}}}; + expect_shape(migraphx::shape{migraphx::shape::uint8_type, {{1, 1}, {4, 8}, {4, 8}}}, + migraphx::make_op("bit_cast", {{"target_type", migraphx::shape::uint8_type}}), + a_shape); +} + TEST_CASE(bitwise_and_not_integral_error) { migraphx::shape a_shape{migraphx::shape::float_type, {1, 4, 4}}; diff --git a/test/quantization.cpp b/test/quantization.cpp index 3c8968cfdf4..6eafcca377c 100644 --- a/test/quantization.cpp +++ b/test/quantization.cpp @@ -30,8 +30,8 @@ #include #include #include +#include <migraphx/truncate_float.hpp> #include -#include <migraphx/quantize_fp16.hpp> #include #include #include @@ -261,8 +261,9 @@ TEST_CASE(param_add_sub) }; auto p0 = create_program_float(); - migraphx::run_passes( - p0, {migraphx::quantize_fp16_pass{{"all"}}, migraphx::dead_code_elimination{}}); + migraphx::run_passes(p0, + {migraphx::truncate_float_pass{{"all"}, migraphx::shape::half_type}, + migraphx::dead_code_elimination{}}); EXPECT(p0 == create_program_fp16()); auto p1 = create_program_float(); @@ -669,7 +670,6 @@ TEST_CASE(dot_float) auto pb = mm->add_parameter("b", sb); auto zp = mm->add_literal(static_cast<int8_t>(0)); auto scale = mm->add_literal(10.0f); - auto zp_out = mm->add_literal(std::int32_t{0}); auto scale_a = mm->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale); auto zp_a = @@ -684,10 +684,7 @@ auto scale_mb = mm->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}), scale); auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_mb, scale_mb); - zp_out = mm->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}), zp_out); - auto r = - mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, out_scale, zp_out); + auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, out_scale);
mm->add_return({r}); return p; @@ -704,11 +701,11 @@ TEST_CASE(dot_float) migraphx::dead_code_elimination{}}); auto qp = create_int8_quantized_prog(); - EXPECT(p == qp); + EXPECT(p.sort() == qp.sort()); optimize_prog_int8(p); auto op = create_int8_optimized_prog(); - EXPECT(p == op); + EXPECT(p.sort() == op.sort()); } TEST_CASE(dot_double_2args) @@ -784,11 +781,7 @@ TEST_CASE(dot_double_2args) migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), scale_b_lit); auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_a_mb, scale_b_mb); - auto zp_out = mm->add_literal(std::int32_t{0}); - zp_out = mm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), zp_out); - auto r = - mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, out_scale, zp_out); + auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, out_scale); mm->add_return({r}); return p; }; @@ -855,7 +848,6 @@ TEST_CASE(dot_half_1arg) auto zp = mm->add_literal(static_cast(0)); auto scale_lit = mm->add_literal(migraphx::literal({sa.type()}, {10.0})); - auto zp_out = mm->add_literal(std::int32_t{0}); auto scale = mm->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale_lit); zp = @@ -863,10 +855,7 @@ TEST_CASE(dot_half_1arg) auto qx = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale, zp); auto qdot = mm->add_instruction(migraphx::make_op("quant_dot"), qx, qx); auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale, scale); - zp_out = mm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), zp_out); - auto r = - mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, out_scale, zp_out); + auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, out_scale); mm->add_return({r}); return p; }; @@ -922,11 +911,7 @@ TEST_CASE(conv_float) migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}), scale_lit); auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_mb, scale_mb); - auto zp_out = mm->add_literal(std::int32_t{0}); - zp_out = mm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}), zp_out); - auto r = - mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, out_scale, zp_out); + auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, out_scale); mm->add_return({r}); return p; @@ -1004,11 +989,7 @@ TEST_CASE(conv_half) migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}), scale_lit); auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_mb, scale_mb); - auto zp_out = mm->add_literal(std::int32_t{0}); - zp_out = mm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}), zp_out); - auto r = - mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, out_scale, zp_out); + auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, out_scale); mm->add_return({r}); return p; @@ -1256,10 +1237,7 @@ TEST_CASE(int8_subgraph) auto s1_mb = then_mod->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), s1); auto so = then_mod->add_instruction(migraphx::make_op("mul"), s1_mb, s1_mb); - auto zp_out = then_mod->add_literal(std::int32_t{0}); - zp_out = then_mod->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", 
qdot->get_shape().lens()}}), zp_out); - auto r = then_mod->add_instruction(migraphx::make_op("dequantizelinear"), qdot, so, zp_out); + auto r = then_mod->add_instruction(migraphx::make_op("dequantizelinear"), qdot, so); then_mod->add_return({r}); migraphx::shape sd{migraphx::shape::float_type, {2, 2, 4, 6}}; @@ -1285,13 +1263,8 @@ TEST_CASE(int8_subgraph) auto ssw_mb = else_mod->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", qconv->get_shape().lens()}}), ssw_lit); - auto so1 = else_mod->add_instruction(migraphx::make_op("mul"), sax, ssw_mb); - auto zp1_out = else_mod->add_literal(std::int32_t{0}); - zp1_out = else_mod->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", qconv->get_shape().lens()}}), - zp1_out); - auto r1 = - else_mod->add_instruction(migraphx::make_op("dequantizelinear"), qconv, so1, zp1_out); + auto so1 = else_mod->add_instruction(migraphx::make_op("mul"), sax, ssw_mb); + auto r1 = else_mod->add_instruction(migraphx::make_op("dequantizelinear"), qconv, so1); else_mod->add_return({r1}); auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); diff --git a/test/ref/add.cpp b/test/ref/add.cpp index ec9b7a2c40e..525e74dfa11 100644 --- a/test/ref/add.cpp +++ b/test/ref/add.cpp @@ -137,6 +137,25 @@ TEST_CASE(fp16_test) EXPECT(migraphx::verify::verify_rms_range(results_vector, gold)); } +TEST_CASE(bf16_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::bf16_type, {1}}; + migraphx::bf16 a{1.5}; + migraphx::bf16 b{2.5}; + migraphx::bf16 c{4.0}; + auto l0 = mm->add_literal(migraphx::literal{s, {a}}); + auto l1 = mm->add_literal(migraphx::literal{s, {b}}); + mm->add_instruction(migraphx::make_op("add"), l0, l1); + p.compile(migraphx::make_target("ref")); + auto result = p.eval({}).back(); + std::vector results_vector(1); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{c}; + EXPECT(migraphx::verify::verify_rms_range(results_vector, gold)); +} + TEST_CASE(fp32_fp16_test) { auto create_program = [] { diff --git a/test/ref/bit_cast.cpp b/test/ref/bit_cast.cpp new file mode 100644 index 00000000000..4f9438ef4fd --- /dev/null +++ b/test/ref/bit_cast.cpp @@ -0,0 +1,75 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
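// Aside: a minimal standalone sketch (plain C++20, not MIGraphX code) of what
// the bit_cast tests below assert: the byte pattern is preserved and only the
// type is reinterpreted. For the fp8 case, e4m3fn and e4m3fnuz share a bit
// layout but fnuz uses exponent bias 8 instead of 7, so the same bits decode to
// half the value, which is why 26.0f, 3.0f, 96.0f, -1.25f become 13.0f, 1.5f,
// 48.0f, -0.625f in the gold data.
#include <bit>
#include <cassert>
#include <cstdint>

int main()
{
    std::int8_t s = -3;                      // two's complement byte 0xFD
    auto u = std::bit_cast<std::uint8_t>(s); // same byte, new type
    assert(u == 253);                        // matches the int8 -> uint8 gold
    return 0;
}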
+ */ +#include +#include +#include +#include +#include +#include + +#include + +TEST_CASE(bit_cast_fp8) +{ + using migraphx::fp8::fp8e4m3fn; + using migraphx::fp8::fp8e4m3fnuz; + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::fp8e4m3fn_type, {2, 2}}; + std::vector data; + data.push_back(fp8e4m3fn{26.0f}); + data.push_back(fp8e4m3fn{3.0f}); + data.push_back(fp8e4m3fn{96.0f}); + data.push_back(fp8e4m3fn{-1.25f}); + auto lit = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction( + migraphx::make_op("bit_cast", {{"target_type", migraphx::shape::fp8e4m3fnuz_type}}), lit); + p.compile(migraphx::make_target("ref")); + auto result = p.eval({}).back(); + std::vector results_vector(4); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold; + gold.push_back(fp8e4m3fnuz{13.0f}); + gold.push_back(fp8e4m3fnuz{1.5f}); + gold.push_back(fp8e4m3fnuz{48.0f}); + gold.push_back(fp8e4m3fnuz{-0.625f}); + EXPECT(results_vector == gold); +} + +TEST_CASE(bit_cast_uint8) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::int8_type, {2, 2}}; + std::vector data = {23, -3, 0, -1}; + auto lit = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction( + migraphx::make_op("bit_cast", {{"target_type", migraphx::shape::uint8_type}}), lit); + p.compile(migraphx::make_target("ref")); + auto result = p.eval({}).back(); + std::vector results_vector(4); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = {23, 253, 0, 255}; + EXPECT(results_vector == gold); +} diff --git a/test/ref/isinf.cpp b/test/ref/isinf.cpp index 900ebc13fda..4d86e7f1b27 100644 --- a/test/ref/isinf.cpp +++ b/test/ref/isinf.cpp @@ -83,6 +83,25 @@ TEST_CASE(isinf_half_test) EXPECT(migraphx::verify::verify_rms_range(results_vector, gold)); } +TEST_CASE(isinf_bf16_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::bf16_type, {2, 3}}; + auto inf_val = std::numeric_limits::infinity(); + migraphx::bf16 a{1.2}; + migraphx::bf16 b{5.2}; + std::vector data0 = {a, b, inf_val, -inf_val, b, a}; + auto l1 = mm->add_literal(migraphx::literal{s, data0}); + mm->add_instruction(migraphx::make_op("isinf"), l1); + p.compile(migraphx::make_target("ref")); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = {0, 0, 1, 1, 0, 0}; + EXPECT(migraphx::verify::verify_rms_range(results_vector, gold)); +} + TEST_CASE(isinf_dyn_test) { migraphx::program p; diff --git a/test/ref/quantizelinear.cpp b/test/ref/quantizelinear.cpp index 85c762f1b8e..61917e39c88 100644 --- a/test/ref/quantizelinear.cpp +++ b/test/ref/quantizelinear.cpp @@ -139,4 +139,5 @@ void quantizelinear_4() std::vector gold{2.5, 1.75, -1.75, 1.5, -1, 1, 0.625, 8}; EXPECT(results_vector == gold); } +TEST_CASE_REGISTER(quantizelinear_4); TEST_CASE_REGISTER(quantizelinear_4); diff --git a/test/shape_transform_descriptor.cpp b/test/shape_transform_descriptor.cpp index 8a5b7cf34f2..ca9be15c80d 100644 --- a/test/shape_transform_descriptor.cpp +++ b/test/shape_transform_descriptor.cpp @@ -24,6 +24,8 @@ */ #include #include +#include +#include #include using migraphx::make_op; @@ -35,6 +37,7 @@ using d_axes = std::vector>; using ops = std::vector; using dimension = shape_transform_descriptor::dimension; using sub = 
dimension::sub; +using axes_map = std::vector>; all_lens get_all_lens(const shape_transform_descriptor& d) { @@ -76,6 +79,34 @@ all_axes get_all_axes(const shape_transform_descriptor& d) return result; } +std::vector run_shape_transforms(const std::vector& dims, + const std::vector& ops) +{ + migraphx::shape s{migraphx::shape::int64_type, dims}; + std::vector data(s.elements()); + std::iota(data.begin(), data.end(), 0); + + migraphx::program p; + auto* mm = p.get_main_module(); + auto start = mm->add_literal(s, data); + for(const auto& op : ops) + start = mm->add_instruction(op, start); + mm->add_return({start}); + + auto result = p.eval({}).at(0); + return result.to_vector(); +} + +std::vector +check_optimize_shape_transforms(const std::vector& dims, + const std::vector& ops) +{ + auto result = migraphx::optimize_shape_transforms(dims, ops); + CHECK(run_shape_transforms(dims, ops) == run_shape_transforms(dims, result)); + CHECK(result == migraphx::optimize_shape_transforms(dims, result)); + return result; +} + template shape_transform_descriptor make_descriptor(const std::vector& dims, const Ts&... xs) { @@ -84,6 +115,15 @@ shape_transform_descriptor make_descriptor(const std::vector& dims, return desc; } +template +shape_transform_descriptor make_simple_descriptor(const std::vector& dims, + const Ts&... xs) +{ + auto desc = make_descriptor(dims, xs...); + desc.simplify(); + return desc; +} + TEST_CASE(dimension_len) { dimension dim; @@ -115,7 +155,7 @@ TEST_CASE(record_reshape_trailing_1s) EXPECT(get_final_lens(desc) == final_lens{3, 4, 4, 1, 1}); EXPECT(get_all_lens(desc) == all_lens{{3}, {4}, {4}, {1}, {1}}); EXPECT(get_all_axes(desc) == - all_axes{d_axes{{0}}, d_axes{{1}}, d_axes{{2}}, d_axes{{}}, d_axes{{}}}); + all_axes{d_axes{{0}}, d_axes{{1}}, d_axes{{2, 0}}, d_axes{{2, 1}}, d_axes{{2, 2}}}); } TEST_CASE(record_reshape_merge) @@ -158,7 +198,7 @@ TEST_CASE(record_reshape_squeeze_trailing_1s) make_op("reshape", {{"dims", {3, 4, 4}}})); EXPECT(get_final_lens(desc) == final_lens{3, 4, 4}); EXPECT(get_all_lens(desc) == all_lens{{3}, {4}, {4}}); - EXPECT(get_all_axes(desc) == all_axes{d_axes{{0}}, d_axes{{1}}, d_axes{{2}}}); + EXPECT(get_all_axes(desc) == all_axes{d_axes{{0}}, d_axes{{1}}, d_axes{{2, 0}}}); } TEST_CASE(record_reshape_non_divisible_fail) @@ -234,41 +274,41 @@ TEST_CASE(simplify_dimension_remove_1_dim) TEST_CASE(optimize_transpose_transpose) { - EXPECT(migraphx::optimize_shape_transforms( - {3, 5, 2}, - { - make_op("transpose", {{"permutation", {0, 2, 1}}}), - make_op("transpose", {{"permutation", {1, 0, 2}}}), - }) == ops{ - make_op("transpose", {{"permutation", {2, 0, 1}}}), - }); + EXPECT(check_optimize_shape_transforms({3, 5, 2}, + { + make_op("transpose", {{"permutation", {0, 2, 1}}}), + make_op("transpose", {{"permutation", {1, 0, 2}}}), + }) == + ops{ + make_op("transpose", {{"permutation", {2, 0, 1}}}), + }); } TEST_CASE(optimize_reshape_reshape1) { - EXPECT(migraphx::optimize_shape_transforms({3, 5, 2}, - { - make_op("reshape", {{"dims", {30}}}), - make_op("reshape", {{"dims", {3, 10}}}), - }) == ops{ - make_op("reshape", {{"dims", {3, 10}}}), - }); + EXPECT(check_optimize_shape_transforms({3, 5, 2}, + { + make_op("reshape", {{"dims", {30}}}), + make_op("reshape", {{"dims", {3, 10}}}), + }) == ops{ + make_op("reshape", {{"dims", {3, 10}}}), + }); } TEST_CASE(optimize_reshape_reshape2) { - EXPECT(migraphx::optimize_shape_transforms({15, 4}, - { - make_op("reshape", {{"dims", {3, 5, 2, 2}}}), - make_op("reshape", {{"dims", {15, 2, 2}}}), - }) == ops{ - 
make_op("reshape", {{"dims", {15, 2, 2}}}), - }); + EXPECT(check_optimize_shape_transforms({15, 4}, + { + make_op("reshape", {{"dims", {3, 5, 2, 2}}}), + make_op("reshape", {{"dims", {15, 2, 2}}}), + }) == ops{ + make_op("reshape", {{"dims", {15, 2, 2}}}), + }); } TEST_CASE(optimize_reshape_transpose_reshape_to_none) { - EXPECT(migraphx::optimize_shape_transforms( + EXPECT(check_optimize_shape_transforms( {6, 5, 2}, { make_op("reshape", {{"dims", {6, 5, 2, 1, 1}}}), @@ -279,22 +319,22 @@ TEST_CASE(optimize_reshape_transpose_reshape_to_none) TEST_CASE(optimize_reshape_transpose_reshape_to_same) { - EXPECT(migraphx::optimize_shape_transforms( - {1, 112, 56, 56}, + EXPECT(check_optimize_shape_transforms( + {1, 112, 7, 7}, { - make_op("reshape", {{"dims", {1, 4, 28, 56, 56}}}), + make_op("reshape", {{"dims", {1, 4, 28, 7, 7}}}), make_op("transpose", {{"permutation", {0, 2, 1, 3, 4}}}), - make_op("reshape", {{"dims", {1, 112, 56, 56}}}), + make_op("reshape", {{"dims", {1, 112, 7, 7}}}), }) == ops{ - make_op("reshape", {{"dims", {1, 4, 28, 56, 56}}}), + make_op("reshape", {{"dims", {1, 4, 28, 7, 7}}}), make_op("transpose", {{"permutation", {0, 2, 1, 3, 4}}}), - make_op("reshape", {{"dims", {1, 112, 56, 56}}}), + make_op("reshape", {{"dims", {1, 112, 7, 7}}}), }); } TEST_CASE(optimize_reshape_transpose_reshape_to_transpose) { - EXPECT(migraphx::optimize_shape_transforms( + EXPECT(check_optimize_shape_transforms( {6, 5, 2}, { make_op("reshape", {{"dims", {2, 3, 5, 2}}}), @@ -307,20 +347,20 @@ TEST_CASE(optimize_reshape_transpose_reshape_to_transpose) TEST_CASE(optimize_reshape_transpose_reshape_to_reshape) { - EXPECT(migraphx::optimize_shape_transforms( - {6, 5, 2}, - { - make_op("reshape", {{"dims", {6, 5, 2, 1}}}), - make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), - make_op("reshape", {{"dims", {6, 10}}}), - }) == ops{ - make_op("reshape", {{"dims", {6, 10}}}), - }); + EXPECT( + check_optimize_shape_transforms({6, 5, 2}, + { + make_op("reshape", {{"dims", {6, 5, 2, 1}}}), + make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), + make_op("reshape", {{"dims", {6, 10}}}), + }) == ops{ + make_op("reshape", {{"dims", {6, 10}}}), + }); } TEST_CASE(optimize_multibroadcast_transpose_reshape) { - EXPECT(migraphx::optimize_shape_transforms( + EXPECT(check_optimize_shape_transforms( {1, 5, 2}, { make_op("multibroadcast", {{"out_lens", {20, 5, 2}}}), @@ -335,7 +375,7 @@ TEST_CASE(optimize_multibroadcast_transpose_reshape) TEST_CASE(optimize_resize1) { - EXPECT(migraphx::optimize_shape_transforms( + EXPECT(check_optimize_shape_transforms( {3, 4, 4}, { make_op("reshape", {{"dims", {3, 1, 4, 1, 4}}}), @@ -350,7 +390,7 @@ TEST_CASE(optimize_resize1) TEST_CASE(optimize_resize2) { - EXPECT(migraphx::optimize_shape_transforms( + EXPECT(check_optimize_shape_transforms( {1, 1, 2, 2}, { make_op("reshape", {{"dims", {1, 1, 2, 1, 2, 1}}}), @@ -366,54 +406,53 @@ TEST_CASE(optimize_resize2) TEST_CASE(optimize_reshape_2_squeeze) { - EXPECT(migraphx::optimize_shape_transforms({3, 1, 5, 1, 2, 1, 1}, - { - make_op("reshape", {{"dims", {3, 5, 2}}}), - }) == - ops{ - make_op("squeeze", {{"axes", {1, 3, 5, 6}}}), - }); + EXPECT(check_optimize_shape_transforms({3, 1, 5, 1, 2, 1, 1}, + { + make_op("reshape", {{"dims", {3, 5, 2}}}), + }) == ops{ + make_op("squeeze", {{"axes", {1, 3, 5, 6}}}), + }); } TEST_CASE(optimize_reshape_2_unsqueeze) { - EXPECT(migraphx::optimize_shape_transforms( - {3, 5, 2}, - { - make_op("reshape", {{"dims", {3, 1, 5, 1, 2, 1, 1}}}), - }) == ops{ - make_op("unsqueeze", {{"axes", {1, 3, 5, 6}}}), - 
}); + EXPECT( + check_optimize_shape_transforms({3, 5, 2}, + { + make_op("reshape", {{"dims", {3, 1, 5, 1, 2, 1, 1}}}), + }) == ops{ + make_op("unsqueeze", {{"axes", {1, 3, 5, 6}}}), + }); } TEST_CASE(optimize_unsqueeze_multibroadcast) { - EXPECT(migraphx::optimize_shape_transforms( + EXPECT(check_optimize_shape_transforms( {32, 10}, { make_op("unsqueeze", {{"axes", {0, 3, 4}}}), - make_op("multibroadcast", {{"out_lens", {256, 32, 10, 16, 16}}}), + make_op("multibroadcast", {{"out_lens", {4, 32, 10, 16, 16}}}), }) == ops{ - make_op("broadcast", {{"axis", 1}, {"out_lens", {256, 32, 10, 16, 16}}}), + make_op("broadcast", {{"axis", 1}, {"out_lens", {4, 32, 10, 16, 16}}}), }); } TEST_CASE(optimize_multibroadcast_reshape) { - EXPECT(migraphx::optimize_shape_transforms( - {1, 4, 1}, - { - make_op("multibroadcast", {{"out_lens", {2, 4, 6}}}), - make_op("reshape", {{"dims", {2, 2, 2, 6}}}), - }) == ops{ - make_op("reshape", {{"dims", {1, 2, 2, 1}}}), - make_op("multibroadcast", {{"out_lens", {2, 2, 2, 6}}}), - }); + EXPECT(check_optimize_shape_transforms({1, 4, 1}, + { + make_op("multibroadcast", {{"out_lens", {2, 4, 6}}}), + make_op("reshape", {{"dims", {2, 2, 2, 6}}}), + }) == + ops{ + make_op("reshape", {{"dims", {1, 2, 2, 1}}}), + make_op("multibroadcast", {{"out_lens", {2, 2, 2, 6}}}), + }); } -TEST_CASE(optimize_squeeze_broadcast) +TEST_CASE(optimize_squeeze_broadcast1) { - EXPECT(migraphx::optimize_shape_transforms( + EXPECT(check_optimize_shape_transforms( {256, 1, 1}, { make_op("squeeze"), @@ -424,9 +463,22 @@ TEST_CASE(optimize_squeeze_broadcast) }); } +TEST_CASE(optimize_squeeze_broadcast2) +{ + EXPECT(check_optimize_shape_transforms( + {1, 128, 1}, + { + make_op("squeeze", {{"axes", {0}}}), + make_op("multibroadcast", {{"out_lens", {128, 768}}}), + }) == ops{ + make_op("squeeze", {{"axes", {0}}}), + make_op("multibroadcast", {{"out_lens", {128, 768}}}), + }); +} + TEST_CASE(optimize_squeeze_unsqueeze_broadcast_to_broadcast) { - EXPECT(migraphx::optimize_shape_transforms( + EXPECT(check_optimize_shape_transforms( {256}, { make_op("unsqueeze", {{"axes", {0}}}), @@ -439,7 +491,7 @@ TEST_CASE(optimize_squeeze_unsqueeze_broadcast_to_broadcast) TEST_CASE(optimize_transpose_reshape_to_transpose) { - EXPECT(migraphx::optimize_shape_transforms( + EXPECT(check_optimize_shape_transforms( {3, 3, 3, 1}, { make_op("transpose", {{"permutation", {3, 2, 0, 1}}}), @@ -451,14 +503,329 @@ TEST_CASE(optimize_transpose_reshape_to_transpose) TEST_CASE(optimize_scalar_broadcast_unsqueeze) { - EXPECT(migraphx::optimize_shape_transforms({1}, - { - make_op("multibroadcast", {{"out_lens", {2}}}), - make_op("unsqueeze", {{"axes", {1}}}), - }) == + EXPECT(check_optimize_shape_transforms({1}, + { + make_op("multibroadcast", {{"out_lens", {2}}}), + make_op("unsqueeze", {{"axes", {1}}}), + }) == ops{ make_op("multibroadcast", {{"out_lens", {2, 1}}}), }); } +TEST_CASE(optimize_broadcast_reshape_transpose) +{ + EXPECT(check_optimize_shape_transforms( + {2, 16, 1}, + { + make_op("multibroadcast", {{"out_lens", {2, 16, 10240}}}), + make_op("reshape", {{"dims", {2, 160, 32, 32}}}), + make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), + }) == ops{ + make_op("unsqueeze", {{"axes", {3, 4}}}), + make_op("transpose", {{"permutation", {0, 3, 4, 1, 2}}}), + make_op("multibroadcast", {{"out_lens", {2, 1, 1, 16, 10}}}), + make_op("reshape", {{"dims", {2, 1, 1, 160}}}), + make_op("multibroadcast", {{"out_lens", {2, 32, 32, 160}}}), + }); +} + +TEST_CASE(optimize_multibroadcast_transpose) +{ + 
EXPECT(check_optimize_shape_transforms( + {320, 1, 1}, + { + make_op("multibroadcast", {{"out_lens", {2, 320, 64, 64}}}), + make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), + }) == ops{ + make_op("unsqueeze", {{"axes", {0}}}), + make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), + make_op("multibroadcast", {{"out_lens", {2, 64, 64, 320}}}), + }); +} + +TEST_CASE(optimize_unsqueeze_transpose_squeeze_multibroadcast) +{ + EXPECT(check_optimize_shape_transforms( + {320, 1, 1}, + { + make_op("unsqueeze", {{"axes", {0}}}), + make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), + make_op("squeeze", {{"axes", {0, 1}}}), + make_op("multibroadcast", {{"out_lens", {320, 320}}}), + }) == ops{ + make_op("multibroadcast", {{"out_lens", {320, 1, 320}}}), + make_op("squeeze", {{"axes", {1}}}), + }); +} + +TEST_CASE(optimize_squeeze_multibroadcast_transpose) +{ + EXPECT(check_optimize_shape_transforms( + {16, 1, 16}, + { + make_op("squeeze", {{"axes", {1}}}), + make_op("multibroadcast", {{"out_lens", {4, 16, 16}}}), + make_op("transpose", {{"permutation", {1, 0, 2}}}), + }) == ops{ + make_op("multibroadcast", {{"out_lens", {16, 4, 16}}}), + }); +} + +TEST_CASE(common_dims_reshape_less) +{ + auto desc = + make_simple_descriptor({2, 32, 40, 8}, make_op("reshape", {{"dims", {2, 1280, 8}}})); + EXPECT(desc.common_dims() == final_lens{2, 32, 40, 8}); + EXPECT(desc.common_axes_map_from_src() == axes_map{{0}, {1}, {2}, {3}}); + EXPECT(desc.common_axes_map_from_dst() == axes_map{{0}, {1, 2}, {3}}); + EXPECT(desc.generate_common_from_src() == ops{}); + EXPECT(desc.generate_common_from_dst() == ops{make_op("reshape", {{"dims", {2, 32, 40, 8}}})}); + EXPECT(desc.generate_dst_from_common() == ops{make_op("reshape", {{"dims", {2, 1280, 8}}})}); +} + +TEST_CASE(common_dims_reshape1) +{ + auto desc = + make_simple_descriptor({2, 32, 2560}, make_op("reshape", {{"dims", {2, 1280, 8, 8}}})); + EXPECT(desc.common_dims() == final_lens{2, 32, 40, 8, 8}); + EXPECT(desc.common_axes_map_from_src() == axes_map{{{0}, {1}, {2, 3, 4}}}); + EXPECT(desc.common_axes_map_from_dst() == axes_map{{0}, {1, 2}, {3}, {4}}); + EXPECT(desc.generate_common_from_src() == + ops{make_op("reshape", {{"dims", {2, 32, 40, 8, 8}}})}); + EXPECT(desc.generate_common_from_dst() == + ops{make_op("reshape", {{"dims", {2, 32, 40, 8, 8}}})}); + EXPECT(desc.generate_dst_from_common() == ops{make_op("reshape", {{"dims", {2, 1280, 8, 8}}})}); +} + +TEST_CASE(common_dims_reshape2) +{ + auto desc = + make_simple_descriptor({2, 1280, 8, 8}, make_op("reshape", {{"dims", {2, 32, 2560}}})); + EXPECT(desc.common_dims() == final_lens{2, 32, 40, 8, 8}); + EXPECT(desc.common_axes_map_from_src() == axes_map{{0}, {1, 2}, {3}, {4}}); + EXPECT(desc.common_axes_map_from_dst() == axes_map{{{0}, {1}, {2, 3, 4}}}); + EXPECT(desc.generate_common_from_src() == + ops{make_op("reshape", {{"dims", {2, 32, 40, 8, 8}}})}); + EXPECT(desc.generate_common_from_dst() == + ops{make_op("reshape", {{"dims", {2, 32, 40, 8, 8}}})}); + EXPECT(desc.generate_dst_from_common() == ops{make_op("reshape", {{"dims", {2, 32, 2560}}})}); +} + +TEST_CASE(common_dims_reshape3) +{ + auto desc = + make_simple_descriptor({2, 32, 4096}, make_op("reshape", {{"dims", {4, 16, 64, 64}}})); + + EXPECT(desc.common_dims() == final_lens{2, 2, 16, 64, 64}); + EXPECT(desc.common_dims({2, 1, 4096}) == final_lens{2, 1, 1, 64, 64}); + EXPECT(desc.common_dims({2, 32, 1}) == final_lens{2, 2, 16, 1, 1}); + + EXPECT(desc.common_axes_map_from_src() == axes_map{{0}, {1, 2}, {3, 4}}); + 
EXPECT(desc.common_axes_map_from_dst() == axes_map{{0, 1}, {2}, {3}, {4}}); + + EXPECT(desc.generate_common_from_src() == + ops{make_op("reshape", {{"dims", {2, 2, 16, 64, 64}}})}); + EXPECT(desc.generate_common_from_src({2, 32, 1}) == + ops{make_op("reshape", {{"dims", {2, 2, 16, 1, 1}}})}); + EXPECT(desc.generate_common_from_src({2, 1, 4096}) == + ops{make_op("reshape", {{"dims", {2, 1, 1, 64, 64}}})}); + + EXPECT(desc.generate_common_from_dst() == + ops{make_op("reshape", {{"dims", {2, 2, 16, 64, 64}}})}); + EXPECT(desc.generate_common_from_dst({4, 16, 1, 1}) == + ops{make_op("reshape", {{"dims", {2, 2, 16, 1, 1}}})}); + EXPECT(desc.generate_common_from_dst({4, 1, 64, 64}) == + ops{make_op("reshape", {{"dims", {2, 2, 1, 64, 64}}})}); + + EXPECT(desc.generate_dst_from_common() == ops{make_op("reshape", {{"dims", {4, 16, 64, 64}}})}); + EXPECT(desc.generate_dst_from_common({2, 2, 1, 64, 64}) == + ops{make_op("reshape", {{"dims", {4, 1, 64, 64}}})}); + EXPECT(desc.generate_dst_from_common({2, 2, 16, 1, 1}) == + ops{make_op("reshape", {{"dims", {4, 16, 1, 1}}})}); + EXPECT(desc.generate_dst_from_common({2, 1, 16, 64, 64}) == + ops{make_op("squeeze", {{"axes", {1}}})}); +} + +TEST_CASE(common_dims_reshape4) +{ + auto desc = + make_simple_descriptor({4, 16, 64, 64}, make_op("reshape", {{"dims", {2, 32, 4096}}})); + + EXPECT(desc.common_dims() == final_lens{2, 2, 16, 64, 64}); + EXPECT(desc.common_dims({4, 16, 1, 1}) == final_lens{2, 2, 16, 1, 1}); + EXPECT(desc.common_dims({4, 1, 64, 64}) == final_lens{2, 2, 1, 64, 64}); + + EXPECT(desc.common_axes_map_from_src() == axes_map{{0, 1}, {2}, {3}, {4}}); + EXPECT(desc.common_axes_map_from_dst() == axes_map{{0}, {1, 2}, {3, 4}}); + + EXPECT(desc.generate_common_from_dst() == + ops{make_op("reshape", {{"dims", {2, 2, 16, 64, 64}}})}); + EXPECT(desc.generate_common_from_dst({2, 32, 1}) == + ops{make_op("reshape", {{"dims", {2, 2, 16, 1, 1}}})}); + EXPECT(desc.generate_common_from_dst({2, 1, 4096}) == + ops{make_op("reshape", {{"dims", {2, 1, 1, 64, 64}}})}); + + EXPECT(desc.generate_common_from_src() == + ops{make_op("reshape", {{"dims", {2, 2, 16, 64, 64}}})}); + EXPECT(desc.generate_common_from_src({4, 16, 1, 1}) == + ops{make_op("reshape", {{"dims", {2, 2, 16, 1, 1}}})}); + EXPECT(desc.generate_common_from_src({4, 1, 64, 64}) == + ops{make_op("reshape", {{"dims", {2, 2, 1, 64, 64}}})}); + + EXPECT(desc.generate_dst_from_common() == ops{make_op("reshape", {{"dims", {2, 32, 4096}}})}); + EXPECT(desc.generate_dst_from_common({2, 2, 1, 64, 64}) == + ops{make_op("reshape", {{"dims", {2, 2, 4096}}})}); + EXPECT(desc.generate_dst_from_common({2, 2, 16, 1, 1}) == + ops{make_op("reshape", {{"dims", {2, 32, 1}}})}); + EXPECT(desc.generate_dst_from_common({2, 1, 16, 64, 64}) == + ops{make_op("reshape", {{"dims", {2, 16, 4096}}})}); +} + +TEST_CASE(common_dims_transpose_reshape) +{ + auto desc = make_simple_descriptor({2, 16, 64, 64}, + make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), + make_op("reshape", {{"dims", {2, 32, 2048}}})); + EXPECT(desc.common_dims() == final_lens{2, 32, 2, 64, 16}); + + EXPECT(desc.common_axes_map_from_src() == axes_map{{0}, {4}, {1, 2}, {3}}); + EXPECT(desc.common_axes_map_from_dst() == axes_map{{0}, {1}, {2, 3, 4}}); + + EXPECT(desc.generate_common_from_dst() == + ops{make_op("reshape", {{"dims", {2, 32, 2, 64, 16}}})}); + EXPECT(desc.generate_common_from_dst({2, 32, 1}) == + ops{make_op("unsqueeze", {{"axes", {3, 4}}})}); + EXPECT(desc.generate_common_from_dst({2, 1, 2048}) == + ops{make_op("reshape", {{"dims", {2, 1, 2, 
64, 16}}})}); + + EXPECT(desc.generate_common_from_src() == + ops{make_op("reshape", {{"dims", {2, 16, 32, 2, 64}}}), + make_op("transpose", {{"permutation", {0, 2, 3, 4, 1}}})}); + EXPECT(desc.generate_common_from_src({2, 16, 1, 1}) == + ops{make_op("unsqueeze", {{"axes", {3}}}), + make_op("transpose", {{"permutation", {0, 2, 3, 4, 1}}})}); + EXPECT(desc.generate_common_from_src({2, 1, 64, 64}) == + ops{make_op("reshape", {{"dims", {2, 1, 32, 2, 64}}}), + make_op("transpose", {{"permutation", {0, 2, 3, 4, 1}}})}); + + EXPECT(desc.generate_dst_from_common() == ops{make_op("reshape", {{"dims", {2, 32, 2048}}})}); + EXPECT(desc.generate_dst_from_common({2, 1, 2, 64, 16}) == + ops{make_op("reshape", {{"dims", {2, 1, 2048}}})}); + EXPECT(desc.generate_dst_from_common({2, 1, 1, 1, 16}) == + ops{make_op("squeeze", {{"axes", {2, 3}}})}); + EXPECT(desc.generate_dst_from_common({2, 32, 2, 64, 1}) == + ops{make_op("reshape", {{"dims", {2, 32, 128}}})}); +} + +TEST_CASE(common_dims_broadcast_reshape) +{ + auto desc = make_simple_descriptor({2, 32, 1}, + make_op("multibroadcast", {{"out_lens", {2, 32, 4096}}}), + make_op("reshape", {{"dims", {4, 16, 64, 64}}})); + + EXPECT(desc.common_dims() == final_lens{2, 2, 16, 64, 64}); + EXPECT(desc.common_dims({2, 1, 1}) == final_lens{2, 1, 1, 64, 64}); + EXPECT(desc.common_dims({2, 1, 4096}) == final_lens{2, 1, 1, 64, 64}); + EXPECT(desc.common_dims({2, 32, 4096}) == final_lens{2, 2, 16, 64, 64}); + + EXPECT(desc.common_axes_map_from_src() == axes_map{{0}, {1, 2}, {3, 4}}); + EXPECT(desc.common_axes_map_from_dst() == axes_map{{0, 1}, {2}, {3}, {4}}); + + EXPECT(desc.generate_common_from_src() == + ops{make_op("reshape", {{"dims", {2, 2, 16, 1, 1}}}), + make_op("multibroadcast", {{"out_lens", {2, 2, 16, 64, 64}}})}); + EXPECT(desc.generate_common_from_src({2, 1, 1}) == + ops{make_op("unsqueeze", {{"axes", {2, 4}}}), + make_op("multibroadcast", {{"out_lens", {2, 1, 1, 64, 64}}})}); + + EXPECT(desc.generate_common_from_dst() == + ops{make_op("reshape", {{"dims", {2, 2, 16, 64, 64}}})}); + EXPECT(desc.generate_common_from_dst({4, 16, 1, 1}) == + ops{make_op("reshape", {{"dims", {2, 2, 16, 1, 1}}})}); + EXPECT(desc.generate_common_from_dst({4, 1, 64, 64}) == + ops{make_op("reshape", {{"dims", {2, 2, 1, 64, 64}}})}); + + EXPECT(desc.generate_dst_from_common() == ops{make_op("reshape", {{"dims", {4, 16, 64, 64}}})}); + EXPECT(desc.generate_dst_from_common({2, 2, 1, 64, 64}) == + ops{make_op("reshape", {{"dims", {4, 1, 64, 64}}})}); + EXPECT(desc.generate_dst_from_common({2, 2, 16, 1, 1}) == + ops{make_op("reshape", {{"dims", {4, 16, 1, 1}}})}); + EXPECT(desc.generate_dst_from_common({2, 1, 16, 64, 64}) == + ops{make_op("squeeze", {{"axes", {1}}})}); +} + +TEST_CASE(common_dims_resize) +{ + auto desc = + make_simple_descriptor({4, 16, 32, 32}, + make_op("reshape", {{"dims", {4, 16, 32, 1, 32, 1}}}), + make_op("multibroadcast", {{"out_lens", {4, 16, 32, 2, 32, 2}}}), + make_op("reshape", {{"dims", {4, 16, 64, 64}}})); + + EXPECT(desc.common_dims() == final_lens{4, 16, 32, 2, 32, 2}); + EXPECT(desc.common_dims({4, 16, 1, 1}) == final_lens{4, 16, 1, 2, 1, 2}); + EXPECT(desc.common_dims({4, 1, 32, 32}) == final_lens{4, 1, 32, 2, 32, 2}); + + EXPECT(desc.common_axes_map_from_src() == axes_map{{0}, {1}, {2}, {4}}); + EXPECT(desc.common_axes_map_from_dst() == axes_map{{0}, {1}, {2, 3}, {4, 5}}); + + EXPECT(desc.generate_common_from_src() == + ops{make_op("unsqueeze", {{"axes", {3, 5}}}), + make_op("multibroadcast", {{"out_lens", {4, 16, 32, 2, 32, 2}}})}); + 
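// Passing explicit source lens should emit the same unsqueeze +
+ // multibroadcast pattern, with the shrunken dims carried through.
+ 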
EXPECT(desc.generate_common_from_src({4, 16, 1, 1}) == + ops{make_op("unsqueeze", {{"axes", {3, 5}}}), + make_op("multibroadcast", {{"out_lens", {4, 16, 1, 2, 1, 2}}})}); + EXPECT(desc.generate_common_from_src({4, 1, 32, 32}) == + ops{make_op("unsqueeze", {{"axes", {3, 5}}}), + make_op("multibroadcast", {{"out_lens", {4, 1, 32, 2, 32, 2}}})}); + + EXPECT(desc.generate_common_from_dst() == + ops{make_op("reshape", {{"dims", {4, 16, 32, 2, 32, 2}}})}); + EXPECT(desc.generate_common_from_dst({4, 16, 1, 1}) == + ops{make_op("unsqueeze", {{"axes", {3, 5}}})}); + EXPECT(desc.generate_common_from_dst({4, 1, 64, 64}) == + ops{make_op("reshape", {{"dims", {4, 1, 32, 2, 32, 2}}})}); + + EXPECT(desc.generate_dst_from_common() == ops{make_op("reshape", {{"dims", {4, 16, 64, 64}}})}); + EXPECT(desc.generate_dst_from_common({4, 16, 1, 2, 1, 2}) == + ops{make_op("squeeze", {{"axes", {2, 4}}})}); + EXPECT(desc.generate_dst_from_common({4, 1, 32, 2, 32, 2}) == + ops{make_op("reshape", {{"dims", {4, 1, 64, 64}}})}); +} + +TEST_CASE(rebase_reshape_broadcast) +{ + auto base_desc = + make_simple_descriptor({3, 4, 64, 1}, + make_op("reshape", {{"dims", {12, 8, 8, 1, 1}}}), + make_op("multibroadcast", {{"out_lens", {12, 8, 8, 2, 2}}})); + + { + auto desc = base_desc.rebase({3, 4, 64, 4}); + EXPECT(get_final_lens(desc) == final_lens{12, 8, 8, 2, 2}); + EXPECT(get_all_lens(desc) == all_lens{{3, 4}, {8}, {8}, {2}, {2}}); + EXPECT(desc.generate() == ops{make_op("reshape", {{"dims", {3, 4, 8, 8, 2, 2}}}), + make_op("reshape", {{"dims", {12, 8, 8, 2, 2}}})}); + } + + { + auto desc = base_desc.rebase({3, 5, 64, 1}); + EXPECT(get_final_lens(desc) == final_lens{15, 8, 8, 2, 2}); + EXPECT(get_all_lens(desc) == all_lens{{3, 5}, {8}, {8}, {2}, {2}}); + EXPECT(desc.generate() == ops{make_op("reshape", {{"dims", {3, 5, 8, 8, 1, 1}}}), + make_op("reshape", {{"dims", {15, 8, 8, 1, 1}}}), + make_op("multibroadcast", {{"out_lens", {15, 8, 8, 2, 2}}})}); + } + + { + auto desc = base_desc.rebase({3, 4, 1, 1}); + EXPECT(get_final_lens(desc) == final_lens{12, 1, 1, 2, 2}); + EXPECT(get_all_lens(desc) == all_lens{{3, 4}, {1}, {1}, {2}, {2}}); + EXPECT(desc.generate() == ops{make_op("unsqueeze", {{"axes", {3, 5}}}), + make_op("reshape", {{"dims", {12, 1, 1, 1, 1}}}), + make_op("multibroadcast", {{"out_lens", {12, 1, 1, 2, 2}}})}); + } +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/simplify_algebra_test.cpp b/test/simplify_algebra_test.cpp index e6ce2780dd4..f0eaa8c4b0f 100644 --- a/test/simplify_algebra_test.cpp +++ b/test/simplify_algebra_test.cpp @@ -1036,6 +1036,32 @@ TEST_CASE(simplify_add_conv_no_fusion_asymetrical_strides2) m.begin(), m.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 2); } +TEST_CASE(simplify_concat_unpack_int4) +{ + auto s = migraphx::shape{migraphx::shape::int8_type, {11008, 2048}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto y = m1.add_parameter("y", s); + auto unpack1 = m1.add_instruction(migraphx::make_op("unpack_int4"), x); + auto unpack2 = m1.add_instruction(migraphx::make_op("unpack_int4"), y); + auto concat = + m1.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), unpack1, unpack2); + m1.add_return({concat}); + } + run_pass(m1); + + migraphx::module m2; + { + auto x = m2.add_parameter("x", s); + auto y = m2.add_parameter("y", s); + auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x, y); + auto unpack = m2.add_instruction(migraphx::make_op("unpack_int4"), concat); + 
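// Expected form: the packed int8 data is concatenated first and a single
+ // unpack_int4 is applied to the result, replacing the two per-input unpacks.
+ 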
m2.add_return({unpack}); + } + EXPECT(m1 == m2); +} + TEST_CASE(simplify_concat_add_relu) { auto s = migraphx::shape{migraphx::shape::int32_type, {1}}; @@ -1218,6 +1244,37 @@ TEST_CASE(simplify_concat_add_relu_broadcast_same_axis) EXPECT(m1 == m2); } +TEST_CASE(simplify_concat_clip) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {1}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto y = m1.add_parameter("y", s); + auto min = m1.add_literal({s, {0}}); + auto max = m1.add_literal({s, {10}}); + auto clip1 = m1.add_instruction(migraphx::make_op("clip"), x, min, max); + auto clip2 = m1.add_instruction(migraphx::make_op("clip"), y, min, max); + auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), clip1, clip2); + m1.add_instruction(pass_op{}, concat); + } + run_pass(m1); + + migraphx::module m2; + { + auto x = m2.add_parameter("x", s); + auto y = m2.add_parameter("y", s); + auto min = m2.add_literal({s, {0}}); + auto max = m2.add_literal({s, {10}}); + auto concat1 = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x, y); + auto concat2 = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), min, min); + auto concat3 = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), max, max); + auto clip = m2.add_instruction(migraphx::make_op("clip"), concat1, concat2, concat3); + m2.add_instruction(pass_op{}, clip); + } + EXPECT(m1 == m2); +} + TEST_CASE(concat_convert_fusion) { auto s = migraphx::shape{migraphx::shape::float_type, {64}}; @@ -4094,6 +4151,72 @@ TEST_CASE(mul_dot_b_not_k_broadcast) EXPECT(m1.sort() == m2.sort()); } +TEST_CASE(mul_dot_a_int4_dq) +{ + migraphx::shape as{migraphx::shape::float_type, {1, 32, 4096}}; + migraphx::shape bs{migraphx::shape::int8_type, {22016, 2048}}; + migraphx::shape cs{migraphx::shape::float_type, {22016, 4096}}; + migraphx::module m1; + { + auto a = m1.add_parameter("input", as); + + auto lit = + m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4096}})); + auto litb = m1.add_instruction( + migraphx::make_op("broadcast", {{"axis", 2}, {"out_lens", as.lens()}}), lit); + auto mul = m1.add_instruction(migraphx::make_op("mul"), a, litb); + + auto b = m1.add_literal(migraphx::generate_literal(bs)); + auto unpack = m1.add_instruction(migraphx::make_op("unpack_int4"), b); + auto scales = m1.add_literal(migraphx::generate_literal(cs)); + auto dq = m1.add_instruction(migraphx::make_op("dequantizelinear"), unpack, scales); + auto unsqueeze = m1.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), dq); + auto transpose = m1.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), unsqueeze); + auto dot = m1.add_instruction(migraphx::make_op("dot"), mul, transpose); + m1.add_return({dot}); + }; + migraphx::module m2 = m1; + run_pass(m1); + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(mul_dot_a_int4_dq_concat) +{ + migraphx::shape as{migraphx::shape::float_type, {1, 32, 4096}}; + migraphx::shape bs{migraphx::shape::int8_type, {4096, 5504}}; + migraphx::shape cs{migraphx::shape::float_type, {4096, 11008}}; + migraphx::module m1; + { + auto a = m1.add_parameter("input", as); + + auto lit = + m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4096}})); + auto litb = m1.add_instruction( + migraphx::make_op("broadcast", {{"axis", 2}, {"out_lens", as.lens()}}), lit); + auto mul = m1.add_instruction(migraphx::make_op("mul"), a, litb); + + std::vector concats; + for(int i = 0; i < 2; i++) + { + auto b = 
m1.add_literal(migraphx::generate_literal(bs)); + auto unpack = m1.add_instruction(migraphx::make_op("unpack_int4"), b); + auto scales = m1.add_literal(migraphx::generate_literal(cs)); + auto dq = m1.add_instruction(migraphx::make_op("dequantizelinear"), unpack, scales); + concats.push_back( + m1.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), dq)); + } + auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), concats); + auto dot = m1.add_instruction(migraphx::make_op("dot"), mul, concat); + m1.add_return({dot}); + }; + migraphx::module m2 = m1; + run_pass(m1); + + EXPECT(m1.sort() == m2.sort()); +} + TEST_CASE(dot_mul_a) { migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}}; diff --git a/test/simplify_qdq_test.cpp b/test/simplify_qdq_test.cpp index 61ec6fa0ae9..c3c50cb4172 100644 --- a/test/simplify_qdq_test.cpp +++ b/test/simplify_qdq_test.cpp @@ -41,7 +41,10 @@ namespace match = migraphx::match; bool is_convolution(const migraphx::instruction& ins) { return ins.name() == "convolution"; } bool is_dot(const migraphx::instruction& ins) { return ins.name() == "dot"; } -void run_pass(migraphx::module& m) { run_passes(m, {migraphx::simplify_qdq{}}); } +void run_pass(migraphx::module& m) +{ + run_passes(m, {migraphx::simplify_qdq{}, migraphx::dead_code_elimination{}}); +} void run_cse(migraphx::module& m) { run_passes(m, {migraphx::eliminate_common_subexpression{}, migraphx::dead_code_elimination{}}); @@ -162,7 +165,7 @@ TEST_CASE(qdq_different_scales) auto t2 = m1.add_parameter("t2", sh2); auto scale1 = m1.add_literal(0.5f); auto scale2 = m1.add_literal(0.4f); - auto zero = m1.add_literal(std::int8_t{0}); + auto zero = m1.add_literal(std::int8_t{1}); auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale1, zero); auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale2, zero); @@ -210,8 +213,7 @@ TEST_CASE(dot) auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2); auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, dot->get_shape().lens()); - auto out_zp = init_zero_point(m2, dot); - auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale, out_zp); + auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale); m2.add_return({d3}); } @@ -262,8 +264,7 @@ TEST_CASE(dot_fp16) auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2); auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, dot->get_shape().lens()); - auto out_zp = init_zero_point(m2, dot); - auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale, out_zp); + auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale); auto d3h = m2.add_instruction( migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), d3); m2.add_return({d3h}); @@ -308,8 +309,7 @@ TEST_CASE(dot_multi_scale) auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2); auto out_scale = add_scale_mul(m2, scale1, scale2, 0, 1, dot->get_shape().lens()); - auto out_zp = init_zero_point(m2, dot); - auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale, out_zp); + auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale); m2.add_return({d3}); } @@ -353,8 +353,7 @@ TEST_CASE(dot_broadcasted) auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2_mb); auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, dot->get_shape().lens()); - auto out_zp = init_zero_point(m2, dot); - auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale, out_zp); + auto d3 = add_quantize_op(m2, 
"dequantizelinear", dot, out_scale); m2.add_return({d3}); } @@ -398,8 +397,7 @@ TEST_CASE(dot_transposed) auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2_t); auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, dot->get_shape().lens()); - auto out_zp = init_zero_point(m2, dot); - auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale, out_zp); + auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale); m2.add_return({d3}); } @@ -441,8 +439,7 @@ TEST_CASE(dot_reshaped) auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2_t); auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, dot->get_shape().lens()); - auto out_zp = init_zero_point(m2, dot); - auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale, out_zp); + auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale); m2.add_return({d3}); } @@ -496,8 +493,7 @@ TEST_CASE(dot_multi_scale_all_skip_post_dq_ops) auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2_mb); auto out_scale = add_scale_mul(m2, scale1, scale2, 2, 3, dot->get_shape().lens()); - auto out_zp = init_zero_point(m2, dot); - auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale, out_zp); + auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale); m2.add_return({d3}); } @@ -757,8 +753,7 @@ TEST_CASE(dot_add) auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero); auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2); auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, dot->get_shape().lens()); - auto out_zp = init_zero_point(m2, dot); - auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale, out_zp); + auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale); auto add = m2.add_instruction(migraphx::make_op("add"), d3, ab); m2.add_return({add}); } @@ -811,13 +806,11 @@ TEST_CASE(dot_add_multiple_dq_use) auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero); auto dot_1 = m2.add_instruction(migraphx::make_op("quant_dot"), q1_tmbc, q2); auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, dot_1->get_shape().lens()); - auto out_zp = init_zero_point(m2, dot_1); - auto d3 = add_quantize_op(m2, "dequantizelinear", dot_1, out_scale, out_zp); + auto d3 = add_quantize_op(m2, "dequantizelinear", dot_1, out_scale); auto d3_q = add_quantize_op(m2, "quantizelinear", d3, scale, zero); auto dot_2 = m2.add_instruction(migraphx::make_op("quant_dot"), d3_q, q1); auto out_scale_2 = add_scale_mul(m2, scale, scale, 1, 1, dot_2->get_shape().lens()); - auto out_zp_2 = init_zero_point(m2, dot_2); - auto d4 = add_quantize_op(m2, "dequantizelinear", dot_2, out_scale_2, out_zp_2); + auto d4 = add_quantize_op(m2, "dequantizelinear", dot_2, out_scale_2); auto add = m2.add_instruction(migraphx::make_op("add"), d4, t1); m2.add_return({add}); } @@ -868,8 +861,7 @@ TEST_CASE(conv) q1, weights); auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, c1->get_shape().lens()); - auto out_zp = init_zero_point(m2, c1); - auto d6 = add_quantize_op(m2, "dequantizelinear", c1, out_scale, out_zp); + auto d6 = add_quantize_op(m2, "dequantizelinear", c1, out_scale); m2.add_return({d6}); } @@ -986,8 +978,7 @@ TEST_CASE(conv_multi_scale) q_inp, weights); auto out_scale = add_scale_mul(m2, inp_scale, w_scale, 1, 1, c1->get_shape().lens()); - auto out_zp = init_zero_point(m2, c1); - auto d1 = add_quantize_op(m2, "dequantizelinear", c1, out_scale, out_zp); + auto d1 = add_quantize_op(m2, "dequantizelinear", c1, out_scale); m2.add_return({d1}); } @@ -1027,9 
+1018,8 @@ TEST_CASE(conv_multi_scale_unsupported_axis) auto input = m2.add_parameter("input", s7); auto weights = m2.add_parameter("weights", s4); auto scale = m2.add_literal(migraphx::generate_literal(s8, 0)); - auto zero = m2.add_literal(std::int8_t{0}); - auto d1 = add_quantize_op(m2, "dequantizelinear", weights, scale, zero); + auto d1 = add_quantize_op(m2, "dequantizelinear", weights, scale); auto c1 = m2.add_instruction(migraphx::make_op("convolution", {{"padding", {0, 0, 0, 0}}, {"stride", {1, 1}}, @@ -1085,9 +1075,8 @@ TEST_CASE(conv_bias_add) auto bias = m2.add_parameter("bias", s6); auto scale = m2.add_literal(0.5f); auto zero = m2.add_literal(std::int8_t{0}); - auto zero32 = m2.add_literal(std::int32_t{0}); - auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero32); + auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale); auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero); auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution", {{"padding", {0, 0, 0, 0}}, @@ -1098,8 +1087,7 @@ TEST_CASE(conv_bias_add) q1, weights); auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, c1->get_shape().lens()); - auto out_zp = init_zero_point(m2, c1); - auto d6 = add_quantize_op(m2, "dequantizelinear", c1, out_scale, out_zp); + auto d6 = add_quantize_op(m2, "dequantizelinear", c1, out_scale); auto b1 = m2.add_instruction( migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2); auto a1 = m2.add_instruction(migraphx::make_op("add"), d6, b1); @@ -1176,10 +1164,9 @@ TEST_CASE(conv_pooling_dot) auto input = m2.add_parameter("input", s7); auto scale = m2.add_literal(0.5f); auto zero = m2.add_literal(std::int8_t{0}); - auto zero32 = m2.add_literal(std::int32_t{0}); - auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero32); - auto d3 = add_quantize_op(m2, "dequantizelinear", ab, scale, zero); + auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale); + auto d3 = add_quantize_op(m2, "dequantizelinear", ab, scale); auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero); auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution", {{"padding", {0, 0, 0, 0}}, @@ -1190,8 +1177,7 @@ TEST_CASE(conv_pooling_dot) q1, weights); auto out_scale1 = add_scale_mul(m2, scale, scale, 1, 1, c1->get_shape().lens()); - auto out_zp1 = init_zero_point(m2, c1); - auto d5 = add_quantize_op(m2, "dequantizelinear", c1, out_scale1, out_zp1); + auto d5 = add_quantize_op(m2, "dequantizelinear", c1, out_scale1); auto bc1 = m2.add_instruction( migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2); auto a1 = m2.add_instruction(migraphx::make_op("add"), d5, bc1); @@ -1208,8 +1194,7 @@ TEST_CASE(conv_pooling_dot) auto q4 = add_quantize_op(m2, "quantizelinear", fl, scale, zero); auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q4, db); auto out_scale2 = add_scale_mul(m2, scale, scale, 1, 0, dot->get_shape().lens()); - auto out_zp2 = init_zero_point(m2, dot); - auto d9 = add_quantize_op(m2, "dequantizelinear", dot, out_scale2, out_zp2); + auto d9 = add_quantize_op(m2, "dequantizelinear", dot, out_scale2); auto mb1 = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1000}}}), d3); auto a2 = m2.add_instruction(migraphx::make_op("add"), d9, mb1); @@ -1517,22 +1502,21 @@ TEST_CASE(dot_reused) auto w2 = m2.add_parameter("w2", sh); auto scale = m2.add_literal(0.5f); auto zero = m2.add_literal(std::int8_t{0}); - auto zero2 = m2.add_literal(std::int32_t{0}); auto 
q1 = add_quantize_op(m2, "quantizelinear", x, scale, zero); auto q2 = add_quantize_op(m2, "quantizelinear", w1, scale, zero); auto dot1 = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2); auto out_scale1 = add_scale_mul(m2, scale, scale, 1, 1, sh.lens()); - auto d1 = add_quantize_op(m2, "dequantizelinear", dot1, out_scale1, zero2); + auto d1 = add_quantize_op(m2, "dequantizelinear", dot1, out_scale1); auto add1 = m2.add_instruction(migraphx::make_op("add"), d1, y); auto q3 = add_quantize_op(m2, "quantizelinear", add1, scale, zero); auto q4 = add_quantize_op(m2, "quantizelinear", w2, scale, zero); auto dot2 = m2.add_instruction(migraphx::make_op("quant_dot"), q3, q4); auto out_scale2 = add_scale_mul(m2, scale, scale, 1, 1, sh.lens()); - auto d2 = add_quantize_op(m2, "dequantizelinear", dot2, out_scale2, zero2); - auto d3 = add_quantize_op(m2, "dequantizelinear", q3, q3->inputs()[1], q3->inputs()[2]); + auto d2 = add_quantize_op(m2, "dequantizelinear", dot2, out_scale2); + auto d3 = add_quantize_op(m2, "dequantizelinear", q3, q3->inputs()[1]); auto add2 = m2.add_instruction(migraphx::make_op("add"), d2, d3); m2.add_return({add2}); } diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index 5e2b0bf1fc8..3f1f5ebabb6 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -1193,6 +1193,122 @@ TEST_CASE(concat_transpose4) EXPECT(m1 == m); } +TEST_CASE(concat_unsqueeze) +{ + auto s = migraphx::shape{migraphx::shape::float_type, {11008, 4096}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto y = m1.add_parameter("y", s); + auto xunsqueeze = m1.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), x); + auto yunsqueeze = m1.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), y); + auto concat = + m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), xunsqueeze, yunsqueeze); + m1.add_return({concat}); + } + migraphx::module m2; + { + auto x = m2.add_parameter("x", s); + auto y = m2.add_parameter("y", s); + auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x, y); + auto unsqueeze = + m2.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 22016, 4096}}}), concat); + m2.add_return({unsqueeze}); + } + run_pass(m1); + EXPECT(m1 == m2); +} + +TEST_CASE(concat_reshape) +{ + auto s = migraphx::shape{migraphx::shape::float_type, {11008, 32, 128}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto y = m1.add_parameter("y", s); + auto xreshape = + m1.add_instruction(migraphx::make_op("reshape", {{"dims", {11008, 4096}}}), x); + auto yreshape = + m1.add_instruction(migraphx::make_op("reshape", {{"dims", {11008, 4096}}}), y); + auto concat = + m1.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), xreshape, yreshape); + m1.add_return({concat}); + } + migraphx::module m2; + { + auto x = m2.add_parameter("x", s); + auto y = m2.add_parameter("y", s); + auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x, y); + auto reshape = + m2.add_instruction(migraphx::make_op("reshape", {{"dims", {22016, 4096}}}), concat); + m2.add_return({reshape}); + } + run_pass(m1); + EXPECT(m1 == m2); +} + +TEST_CASE(concat_reshape_change_axis) +{ + auto s = migraphx::shape{migraphx::shape::float_type, {2, 256, 1280}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto y = m1.add_parameter("y", s); + auto xreshape = + m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 16, 16, 1280}}}), x); + auto 
yreshape = + m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 16, 16, 1280}}}), y); + auto concat = + m1.add_instruction(migraphx::make_op("concat", {{"axis", 3}}), xreshape, yreshape); + m1.add_return({concat}); + } + migraphx::module m2; + { + auto x = m2.add_parameter("x", s); + auto y = m2.add_parameter("y", s); + auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), x, y); + auto reshape = + m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 16, 16, 2560}}}), concat); + m2.add_return({reshape}); + } + run_pass(m1); + EXPECT(m1 == m2); +} + +TEST_CASE(concat_reshape_broadcast) +{ + auto s = migraphx::shape{migraphx::shape::float_type, {11008, 32, 1}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto y = m1.add_parameter("y", s); + auto xb = m1.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {11008, 32, 128}}}), x); + auto yb = m1.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {11008, 32, 128}}}), y); + auto xreshape = + m1.add_instruction(migraphx::make_op("reshape", {{"dims", {11008, 4096}}}), xb); + auto yreshape = + m1.add_instruction(migraphx::make_op("reshape", {{"dims", {11008, 4096}}}), yb); + auto concat = + m1.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), xreshape, yreshape); + m1.add_return({concat}); + } + migraphx::module m2; + { + auto x = m2.add_parameter("x", s); + auto y = m2.add_parameter("y", s); + auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x, y); + auto broadcast = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {22016, 32, 128}}}), concat); + auto reshape = + m2.add_instruction(migraphx::make_op("reshape", {{"dims", {22016, 4096}}}), broadcast); + m2.add_return({reshape}); + } + run_pass(m1); + EXPECT(m1 == m2); +} + TEST_CASE(nested_concat) { migraphx::module m; @@ -2630,4 +2746,26 @@ TEST_CASE(add_transpose) EXPECT(m1.sort() == m2.sort()); } +TEST_CASE(flatten) +{ + migraphx::shape s{migraphx::shape::float_type, {4608, 8, 2}}; + + migraphx::module m1; + { + auto inp = m1.add_parameter("input", s); + auto flat = m1.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), inp); + m1.add_return({flat}); + }; + run_pass(m1); + + migraphx::module m2; + { + auto inp = m2.add_parameter("input", s); + auto flat = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {4608, 16}}}), inp); + m2.add_return({flat}); + }; + + EXPECT(m1.sort() == m2.sort()); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/split_reduce.cpp b/test/split_reduce.cpp index e837b577882..9dd15eab203 100644 --- a/test/split_reduce.cpp +++ b/test/split_reduce.cpp @@ -41,6 +41,7 @@ void run_pass(migraphx::program& p) {migraphx::fuse_pointwise{}, migraphx::fuse_reduce{}, migraphx::split_reduce{.split_size = 8192}, + migraphx::fuse_pointwise{.enable_rewrite_broadcasts = true}, migraphx::dead_code_elimination{}}); } diff --git a/test/verify/CMakeLists.txt b/test/verify/CMakeLists.txt index 3e45e9bc3a9..bd57abea883 100644 --- a/test/verify/CMakeLists.txt +++ b/test/verify/CMakeLists.txt @@ -31,7 +31,7 @@ target_link_libraries(test_verify migraphx migraphx_all_targets) target_include_directories(test_verify PUBLIC ../include) rocm_clang_tidy_check(test_verify) -foreach(SECTION general rnn conv gemm) +foreach(SECTION general reduce rnn conv gemm) rocm_add_test(NAME test_verify_${SECTION} COMMAND test_verify ${SECTION}) set_tests_properties(test_verify_${SECTION} PROPERTIES COST 100 
diff --git a/test/verify/main.cpp b/test/verify/main.cpp index f2324acd91b..876db639644 100644 --- a/test/verify/main.cpp +++ b/test/verify/main.cpp @@ -99,10 +99,19 @@ int main(int argc, const char* argv[]) "float>", "test_quant_dot_3args_5, " "float>", + "test_batch_quant_dot_1, " + "float>", + "test_quant_dot_3args_4, " + "float>", + "test_quant_dot_3args_5, " + "float>", #else "test_batch_quant_dot_1", "test_quant_dot_3args_4", "test_quant_dot_3args_5", + "test_batch_quant_dot_1", + "test_quant_dot_3args_4", + "test_quant_dot_3args_5", "test_batch_quant_dot_1", "test_quant_dot_3args_4", "test_quant_dot_3args_5", @@ -120,20 +129,27 @@ int main(int argc, const char* argv[]) "test_block_reduce_small<67, migraphx::shape::int8_type>", "test_block_reduce_small<128, migraphx::shape::int8_type>", "test_block_reduce_small<129, migraphx::shape::int8_type>", + // disabled because CPU does eliminate_data_type to float for everything "test_bitwise_and", "test_bitwise_and", - "test_unpack_int4", "test_unpack_int4", "test_unpack_int4", - "test_unpack_int4"}); + "test_unpack_int4", + "test_bit_cast", + "test_bit_cast", + "test_bit_cast", + "test_bit_cast"}); rv.disable_test_for("gpu", { // These passes on MI300 but fails on others, same issue as CPU. "test_batch_quant_dot_1", "test_quant_dot_3args_4", "test_quant_dot_3args_5", + "test_batch_quant_dot_1", + "test_quant_dot_3args_4", + "test_quant_dot_3args_5", "test_batch_quant_dot_1", "test_quant_dot_3args_4", "test_quant_dot_3args_5", diff --git a/test/verify/test_abs.cpp b/test/verify/test_abs.cpp index 6b7e680a91b..0c7d7ef050c 100644 --- a/test/verify/test_abs.cpp +++ b/test/verify/test_abs.cpp @@ -43,5 +43,6 @@ struct test_abs : verify_program> template struct test_abs; template struct test_abs; template struct test_abs; +template struct test_abs; template struct test_abs; template struct test_abs; diff --git a/test/verify/test_acos.cpp b/test/verify/test_acos.cpp index 0e864c126b4..c22357afa77 100644 --- a/test/verify/test_acos.cpp +++ b/test/verify/test_acos.cpp @@ -44,5 +44,6 @@ struct test_acos : verify_program> template struct test_acos; template struct test_acos; template struct test_acos; +template struct test_acos; template struct test_acos; template struct test_acos; diff --git a/test/verify/test_add.cpp b/test/verify/test_add.cpp index d34a4674820..54dbedf23ad 100644 --- a/test/verify/test_add.cpp +++ b/test/verify/test_add.cpp @@ -45,5 +45,6 @@ struct test_add : verify_program> template struct test_add; template struct test_add; template struct test_add; +template struct test_add; template struct test_add; template struct test_add; diff --git a/test/verify/test_add_mixed_layout.cpp b/test/verify/test_add_mixed_layout.cpp index 3920df069d8..c4e94feda45 100644 --- a/test/verify/test_add_mixed_layout.cpp +++ b/test/verify/test_add_mixed_layout.cpp @@ -43,6 +43,5 @@ struct test_add_mixed_layout : verify_program> } }; -template struct test_add_mixed_layout; template struct test_add_mixed_layout; template struct test_add_mixed_layout; diff --git a/test/verify/test_arg_ops.cpp b/test/verify/test_arg_ops.cpp index 06e766be527..674e8cb8969 100644 --- a/test/verify/test_arg_ops.cpp +++ b/test/verify/test_arg_ops.cpp @@ -57,6 +57,8 @@ struct test_arg_ops : verify_programadd_instruction(T{Axis, LastIndex}, param); return p; } + + std::string section() const { return "reduce"; } }; // transpose argmax tests template struct test_arg_ops; @@ -162,318 +164,3 @@ template struct test_arg_ops; template struct test_arg_ops; template struct test_arg_ops; - 
-// transpose argmax tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// transpose argmin tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// broadcast argmax tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// broadcast argmin tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// slice argmax tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// slice argmin tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// default case, standard shape argmax tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// default case, standard shape argmin tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; - -// transpose argmax tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct 
test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// transpose argmin tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// broadcast argmax tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// broadcast argmin tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// slice argmax tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// slice argmin tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// default case, standard shape argmax tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// default case, standard shape argmin tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; - -// transpose argmax tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// transpose argmin tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template 
struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// broadcast argmax tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// broadcast argmin tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// slice argmax tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// slice argmin tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// default case, standard shape argmax tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// default case, standard shape argmin tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; diff --git a/test/verify/test_asin.cpp b/test/verify/test_asin.cpp index 5d7f065f2fa..34951182059 100644 --- a/test/verify/test_asin.cpp +++ b/test/verify/test_asin.cpp @@ -44,5 +44,6 @@ struct test_asin : verify_program> template struct test_asin; template struct test_asin; template struct test_asin; +template struct test_asin; template struct test_asin; template struct test_asin; diff --git a/test/verify/test_asinh.cpp b/test/verify/test_asinh.cpp index 3a18015e7a3..4bc6680dd2d 100644 --- a/test/verify/test_asinh.cpp +++ b/test/verify/test_asinh.cpp @@ -44,5 +44,6 @@ struct test_asinh : verify_program> template struct test_asinh; template struct test_asinh; template struct test_asinh; +template struct test_asinh; template struct test_asinh; template struct test_asinh; diff --git a/test/verify/test_atan.cpp b/test/verify/test_atan.cpp index 7d3be8c6a1d..63f8167d000 100644 --- a/test/verify/test_atan.cpp +++ b/test/verify/test_atan.cpp @@ -44,5 +44,6 @@ struct test_atan 
: verify_program> template struct test_atan; template struct test_atan; template struct test_atan; +template struct test_atan; template struct test_atan; template struct test_atan; diff --git a/test/verify/test_atanh.cpp b/test/verify/test_atanh.cpp index 038a7438bd9..8175f4f001a 100644 --- a/test/verify/test_atanh.cpp +++ b/test/verify/test_atanh.cpp @@ -39,14 +39,15 @@ struct test_atanh : verify_program> migraphx::shape::type_t dtype = migraphx::shape::get_type(); migraphx::shape s{dtype, {16}}; auto x = mm->add_parameter("x", s); - auto min_val = mm->add_literal(migraphx::literal{migraphx::shape{dtype}, {-0.95f}}); - auto max_val = mm->add_literal(migraphx::literal{migraphx::shape{dtype}, {0.95f}}); + auto min_val = mm->add_literal(migraphx::literal{migraphx::shape{dtype}, {-0.875f}}); + auto max_val = mm->add_literal(migraphx::literal{migraphx::shape{dtype}, {0.875f}}); min_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {16}}}), min_val); max_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {16}}}), max_val); auto cx = mm->add_instruction(migraphx::make_op("clip"), x, min_val, max_val); - mm->add_instruction(migraphx::make_op("atanh"), cx); + auto atanh_x = mm->add_instruction(migraphx::make_op("atanh"), cx); + mm->add_return({atanh_x}); return p; } }; @@ -54,5 +55,6 @@ struct test_atanh : verify_program> template struct test_atanh; template struct test_atanh; template struct test_atanh; +template struct test_atanh; template struct test_atanh; template struct test_atanh; diff --git a/test/verify/test_batch_quant_dot_1.cpp b/test/verify/test_batch_quant_dot_1.cpp index 3251ade7bc6..752f896fcf2 100644 --- a/test/verify/test_batch_quant_dot_1.cpp +++ b/test/verify/test_batch_quant_dot_1.cpp @@ -58,5 +58,6 @@ struct test_batch_quant_dot_1 : verify_program; template struct test_batch_quant_dot_1; +template struct test_batch_quant_dot_1; template struct test_batch_quant_dot_1; template struct test_batch_quant_dot_1; diff --git a/test/verify/test_batch_quant_dot_2.cpp b/test/verify/test_batch_quant_dot_2.cpp index a1e38f03840..ea5f3073797 100644 --- a/test/verify/test_batch_quant_dot_2.cpp +++ b/test/verify/test_batch_quant_dot_2.cpp @@ -54,5 +54,6 @@ struct test_batch_quant_dot_2 : verify_program; template struct test_batch_quant_dot_2; +template struct test_batch_quant_dot_2; template struct test_batch_quant_dot_2; template struct test_batch_quant_dot_2; diff --git a/test/verify/test_batch_quant_dot_3.cpp b/test/verify/test_batch_quant_dot_3.cpp index bbeac3af90b..e37e743b8c5 100644 --- a/test/verify/test_batch_quant_dot_3.cpp +++ b/test/verify/test_batch_quant_dot_3.cpp @@ -46,5 +46,6 @@ struct test_batch_quant_dot_3 : verify_program> }; template struct test_batch_quant_dot_3; template struct test_batch_quant_dot_3; +template struct test_batch_quant_dot_3; template struct test_batch_quant_dot_3; template struct test_batch_quant_dot_3; diff --git a/test/verify/test_batch_quant_dot_4.cpp b/test/verify/test_batch_quant_dot_4.cpp index 4763b8fee41..6d7d158baf3 100644 --- a/test/verify/test_batch_quant_dot_4.cpp +++ b/test/verify/test_batch_quant_dot_4.cpp @@ -50,5 +50,6 @@ struct test_batch_quant_dot_4 : verify_program> }; template struct test_batch_quant_dot_4; template struct test_batch_quant_dot_4; +template struct test_batch_quant_dot_4; template struct test_batch_quant_dot_4; template struct test_batch_quant_dot_4; diff --git a/test/verify/test_batch_quant_dot_5.cpp b/test/verify/test_batch_quant_dot_5.cpp index 
34cac234565..efc36a34228 100644 --- a/test/verify/test_batch_quant_dot_5.cpp +++ b/test/verify/test_batch_quant_dot_5.cpp @@ -52,5 +52,6 @@ struct test_batch_quant_dot_5 : verify_program> }; template struct test_batch_quant_dot_5; template struct test_batch_quant_dot_5; +template struct test_batch_quant_dot_5; template struct test_batch_quant_dot_5; template struct test_batch_quant_dot_5; diff --git a/test/verify/test_bit_cast.cpp b/test/verify/test_bit_cast.cpp new file mode 100644 index 00000000000..24f9a7fc745 --- /dev/null +++ b/test/verify/test_bit_cast.cpp @@ -0,0 +1,55 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +template +struct test_bit_cast : verify_program> +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{From, {8}}; + auto pa = mm->add_parameter("a", s); + auto pb = mm->add_parameter("b", s); + auto ia = mm->add_instruction( + migraphx::make_op("bit_cast", {{"target_type", migraphx::to_value(To)}}), pa); + auto ib = mm->add_instruction( + migraphx::make_op("bit_cast", {{"target_type", migraphx::to_value(To)}}), pb); + auto ret = mm->add_instruction(migraphx::make_op("add"), ia, ib); + mm->add_return({ret}); + return p; + }; +}; + +template struct test_bit_cast; +template struct test_bit_cast; +template struct test_bit_cast; +template struct test_bit_cast; diff --git a/test/verify/test_block_reduce_small.cpp b/test/verify/test_block_reduce_small.cpp index 0ef4f4bc824..c97e5b7002d 100644 --- a/test/verify/test_block_reduce_small.cpp +++ b/test/verify/test_block_reduce_small.cpp @@ -46,6 +46,8 @@ struct test_block_reduce_small : verify_program> mm->add_return({add}); return p; }; + + std::string section() const { return "reduce"; } }; template diff --git a/test/verify/test_ceil.cpp b/test/verify/test_ceil.cpp index 1bcb10af6e3..d6640165542 100644 --- a/test/verify/test_ceil.cpp +++ b/test/verify/test_ceil.cpp @@ -45,5 +45,6 @@ struct test_ceil : verify_program> template struct test_ceil; template struct test_ceil; template struct test_ceil; +template struct test_ceil; template struct test_ceil; template struct test_ceil; diff --git a/test/verify/test_concat_axis_0.cpp b/test/verify/test_concat_axis_0.cpp index fee3243c989..4d4801bf038 100644 --- a/test/verify/test_concat_axis_0.cpp +++ b/test/verify/test_concat_axis_0.cpp @@ -50,5 +50,6 @@ template struct test_concat_axis_0; template struct test_concat_axis_0; template struct test_concat_axis_0; template struct test_concat_axis_0; +template struct test_concat_axis_0; template struct test_concat_axis_0; template struct test_concat_axis_0; diff --git a/test/verify/test_contiguous.cpp b/test/verify/test_contiguous.cpp index d16907ca64a..603d81b10ee 100644 --- a/test/verify/test_contiguous.cpp +++ b/test/verify/test_contiguous.cpp @@ -46,5 +46,6 @@ struct test_contiguous : verify_program> template struct test_contiguous; template struct test_contiguous; +template struct test_contiguous; template struct test_contiguous; template struct test_contiguous; diff --git a/test/verify/test_conv.cpp b/test/verify/test_conv.cpp index e8ce63e5eed..d75edf4e830 100644 --- a/test/verify/test_conv.cpp +++ b/test/verify/test_conv.cpp @@ -44,5 +44,6 @@ struct test_conv : verify_program> template struct test_conv; template struct test_conv; +template struct test_conv; template struct test_conv; template struct test_conv; diff --git a/test/verify/test_conv2.cpp b/test/verify/test_conv2.cpp index 110f214be1a..eb14a550fab 100644 --- a/test/verify/test_conv2.cpp +++ b/test/verify/test_conv2.cpp @@ -47,5 +47,6 @@ struct test_conv2 : verify_program> }; template struct test_conv2; template struct test_conv2; +template struct test_conv2; template struct test_conv2; template struct test_conv2; diff --git a/test/verify/test_conv_add.cpp b/test/verify/test_conv_add.cpp index 66e1f10a038..2c8739cdc5c 100644 --- a/test/verify/test_conv_add.cpp +++ b/test/verify/test_conv_add.cpp @@ -49,5 +49,6 @@ struct test_conv_add : verify_program> template struct test_conv_add; template struct test_conv_add; +template struct 
diff --git a/test/verify/test_conv_add_1x1_diff_strides.cpp b/test/verify/test_conv_add_1x1_diff_strides.cpp
index 93c5be744da..7faf9840451 100644
--- a/test/verify/test_conv_add_1x1_diff_strides.cpp
+++ b/test/verify/test_conv_add_1x1_diff_strides.cpp
@@ -55,5 +55,6 @@ struct test_conv_add_1x1_diff_strides : verify_program<test_conv_add_1x1_diff_strides<DType>>
 template struct test_conv_add_1x1_diff_strides<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_conv_add_1x1_diff_strides<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_conv_add_1x1_diff_strides<migraphx::shape::fp8e4m3fn_type>;
 template struct test_conv_add_1x1_diff_strides<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_conv_add_relu.cpp b/test/verify/test_conv_add_relu.cpp
index 56b15b3db40..9839ac509cc 100644
--- a/test/verify/test_conv_add_relu.cpp
+++ b/test/verify/test_conv_add_relu.cpp
@@ -53,5 +53,6 @@ struct test_conv_add_relu : verify_program<test_conv_add_relu<DType>>
 template struct test_conv_add_relu<migraphx::shape::half_type>;
 template struct test_conv_add_relu<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_conv_add_relu<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_conv_add_relu<migraphx::shape::fp8e4m3fn_type>;
 template struct test_conv_add_relu<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_conv_add_tune.cpp b/test/verify/test_conv_add_tune.cpp
index 27234508707..1a2a0efa326 100644
--- a/test/verify/test_conv_add_tune.cpp
+++ b/test/verify/test_conv_add_tune.cpp
@@ -74,5 +74,6 @@ struct test_conv_add_tune : verify_program<test_conv_add_tune<DType>>
 template struct test_conv_add_tune<migraphx::shape::float_type>;
 template struct test_conv_add_tune<migraphx::shape::half_type>;
 template struct test_conv_add_tune<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_conv_add_tune<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_conv_add_tune<migraphx::shape::fp8e4m3fn_type>;
 template struct test_conv_add_tune<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_conv_bias_clipped_relu.cpp b/test/verify/test_conv_bias_clipped_relu.cpp
index 0ea1cd56110..fa42b428fef 100644
--- a/test/verify/test_conv_bias_clipped_relu.cpp
+++ b/test/verify/test_conv_bias_clipped_relu.cpp
@@ -59,5 +59,6 @@ struct test_conv_bias_clipped_relu : verify_program<test_conv_bias_clipped_relu<DType>>
 template struct test_conv_bias_clipped_relu<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_conv_bias_clipped_relu<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_conv_bias_clipped_relu<migraphx::shape::fp8e4m3fn_type>;
 template struct test_conv_bias_clipped_relu<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_conv_bn.cpp b/test/verify/test_conv_bn.cpp
index 0f8400c2723..99c1b6f1c0c 100644
--- a/test/verify/test_conv_bn.cpp
+++ b/test/verify/test_conv_bn.cpp
@@ -87,5 +87,6 @@ struct test_conv_bn : verify_program<test_conv_bn<DType>>
 template struct test_conv_bn<migraphx::shape::half_type>;
 template struct test_conv_bn<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_conv_bn<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_conv_bn<migraphx::shape::fp8e4m3fn_type>;
 template struct test_conv_bn<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_conv_bn_add.cpp b/test/verify/test_conv_bn_add.cpp
index 0803b060f6b..e8fd3a77de2 100644
--- a/test/verify/test_conv_bn_add.cpp
+++ b/test/verify/test_conv_bn_add.cpp
@@ -94,5 +94,6 @@ struct test_conv_bn_add : verify_program<test_conv_bn_add<DType>>
 template struct test_conv_bn_add<migraphx::shape::half_type>;
 template struct test_conv_bn_add<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_conv_bn_add<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_conv_bn_add<migraphx::shape::fp8e4m3fn_type>;
 template struct test_conv_bn_add<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_conv_bn_relu_pooling.cpp b/test/verify/test_conv_bn_relu_pooling.cpp
index 603f77e8611..5fe0d8fa4a1 100644
--- a/test/verify/test_conv_bn_relu_pooling.cpp
+++ b/test/verify/test_conv_bn_relu_pooling.cpp
@@ -93,5 +93,6 @@ struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling<DType>>
 template struct test_conv_bn_relu_pooling<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_conv_bn_relu_pooling<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_conv_bn_relu_pooling<migraphx::shape::fp8e4m3fn_type>;
 template struct test_conv_bn_relu_pooling<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_conv_bn_relu_pooling2.cpp b/test/verify/test_conv_bn_relu_pooling2.cpp
index 43ee3e3ecba..d003e146cd0 100644
--- a/test/verify/test_conv_bn_relu_pooling2.cpp
+++ b/test/verify/test_conv_bn_relu_pooling2.cpp
@@ -109,5 +109,6 @@ struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2<DType>>
 template struct test_conv_bn_relu_pooling2<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_conv_bn_relu_pooling2<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_conv_bn_relu_pooling2<migraphx::shape::fp8e4m3fn_type>;
 template struct test_conv_bn_relu_pooling2<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_conv_pooling.cpp b/test/verify/test_conv_pooling.cpp
index 42b9f147cd1..5194fb4977e 100644
--- a/test/verify/test_conv_pooling.cpp
+++ b/test/verify/test_conv_pooling.cpp
@@ -48,5 +48,6 @@ struct test_conv_pooling : verify_program<test_conv_pooling<DType>>
 template struct test_conv_pooling<migraphx::shape::half_type>;
 template struct test_conv_pooling<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_conv_pooling<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_conv_pooling<migraphx::shape::fp8e4m3fn_type>;
 template struct test_conv_pooling<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_convert.cpp b/test/verify/test_convert.cpp
index 46e8f8c5391..92bcb58dba6 100644
--- a/test/verify/test_convert.cpp
+++ b/test/verify/test_convert.cpp
@@ -53,5 +53,6 @@ struct test_convert : verify_program<test_convert<DType>>
 template struct test_convert<migraphx::shape::half_type>;
 template struct test_convert<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_convert<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_convert<migraphx::shape::fp8e4m3fn_type>;
 template struct test_convert<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_cos.cpp b/test/verify/test_cos.cpp
index 693d7e21043..2482f0dad14 100644
--- a/test/verify/test_cos.cpp
+++ b/test/verify/test_cos.cpp
@@ -44,5 +44,6 @@ struct test_cos : verify_program<test_cos<DType>>
 template struct test_cos<migraphx::shape::float_type>;
 template struct test_cos<migraphx::shape::half_type>;
 template struct test_cos<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_cos<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_cos<migraphx::shape::fp8e4m3fn_type>;
 template struct test_cos<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_cosh.cpp b/test/verify/test_cosh.cpp
index 038eaadc96c..1e8dd150a59 100644
--- a/test/verify/test_cosh.cpp
+++ b/test/verify/test_cosh.cpp
@@ -44,5 +44,6 @@ struct test_cosh : verify_program<test_cosh<DType>>
 template struct test_cosh<migraphx::shape::float_type>;
 template struct test_cosh<migraphx::shape::half_type>;
 template struct test_cosh<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_cosh<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_cosh<migraphx::shape::fp8e4m3fn_type>;
 template struct test_cosh<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_erf.cpp b/test/verify/test_erf.cpp
index bdacd622a3d..1edcc199b9b 100644
--- a/test/verify/test_erf.cpp
+++ b/test/verify/test_erf.cpp
@@ -44,5 +44,6 @@ struct test_erf : verify_program<test_erf<DType>>
 template struct test_erf<migraphx::shape::float_type>;
 template struct test_erf<migraphx::shape::half_type>;
 template struct test_erf<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_erf<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_erf<migraphx::shape::fp8e4m3fn_type>;
 template struct test_erf<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_exp.cpp b/test/verify/test_exp.cpp
index 66fd7b88ded..c6437603e17 100644
--- a/test/verify/test_exp.cpp
+++ b/test/verify/test_exp.cpp
@@ -44,5 +44,6 @@ struct test_exp : verify_program<test_exp<DType>>
 template struct test_exp<migraphx::shape::float_type>;
 template struct test_exp<migraphx::shape::half_type>;
 template struct test_exp<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_exp<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_exp<migraphx::shape::fp8e4m3fn_type>;
 template struct test_exp<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_floor.cpp b/test/verify/test_floor.cpp
index 1c4b9dbd83b..1f5c4baad97 100644
--- a/test/verify/test_floor.cpp
+++ b/test/verify/test_floor.cpp
@@ -45,5 +45,6 @@ struct test_floor : verify_program<test_floor<DType>>
 template struct test_floor<migraphx::shape::float_type>;
 template struct test_floor<migraphx::shape::half_type>;
 template struct test_floor<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_floor<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_floor<migraphx::shape::fp8e4m3fn_type>;
 template struct test_floor<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_fmod_mod.cpp b/test/verify/test_fmod_mod.cpp
index 17afc9d9d44..e92cb15406c 100644
--- a/test/verify/test_fmod_mod.cpp
+++ b/test/verify/test_fmod_mod.cpp
@@ -84,5 +84,6 @@ struct test_mod : verify_program<test_mod<DType>>
 template struct test_mod<migraphx::shape::float_type>;
 template struct test_mod<migraphx::shape::half_type>;
 template struct test_mod<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_mod<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_mod<migraphx::shape::fp8e4m3fn_type>;
 template struct test_mod<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_gather.cpp b/test/verify/test_gather.cpp
index adb6f4fbac8..19d8e6ae77c 100644
--- a/test/verify/test_gather.cpp
+++ b/test/verify/test_gather.cpp
@@ -49,11 +49,13 @@ struct test_gather : verify_program<test_gather<Axis, DType>>
 template struct test_gather<0, migraphx::shape::float_type>;
 template struct test_gather<0, migraphx::shape::half_type>;
 template struct test_gather<0, migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_gather<0, migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_gather<0, migraphx::shape::fp8e4m3fn_type>;
 template struct test_gather<0, migraphx::shape::fp8e5m2_type>;
 
 // Test Negative axis
 template struct test_gather<-2, migraphx::shape::float_type>;
 template struct test_gather<-2, migraphx::shape::half_type>;
 template struct test_gather<-2, migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_gather<-2, migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_gather<-2, migraphx::shape::fp8e4m3fn_type>;
 template struct test_gather<-2, migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_gathernd_default.cpp b/test/verify/test_gathernd_default.cpp
index 38f8afe5bc3..5ff2ec40358 100644
--- a/test/verify/test_gathernd_default.cpp
+++ b/test/verify/test_gathernd_default.cpp
@@ -46,5 +46,6 @@ struct test_gathernd_default : verify_program<test_gathernd_default<DType>>
 template struct test_gathernd_default<migraphx::shape::float_type>;
 template struct test_gathernd_default<migraphx::shape::half_type>;
 template struct test_gathernd_default<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_gathernd_default<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_gathernd_default<migraphx::shape::fp8e4m3fn_type>;
 template struct test_gathernd_default<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_gemm.cpp b/test/verify/test_gemm.cpp
index 848323ea4f8..c6d93f29ac4 100644
--- a/test/verify/test_gemm.cpp
+++ b/test/verify/test_gemm.cpp
@@ -45,5 +45,6 @@ struct test_gemm : verify_program<test_gemm<DType>>
 template struct test_gemm<migraphx::shape::float_type>;
 template struct test_gemm<migraphx::shape::half_type>;
 template struct test_gemm<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_gemm<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_gemm<migraphx::shape::fp8e4m3fn_type>;
 template struct test_gemm<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_gemm_2args_bmv.cpp b/test/verify/test_gemm_2args_bmv.cpp
index c4b1c318436..782810fdf79 100644
--- a/test/verify/test_gemm_2args_bmv.cpp
+++ b/test/verify/test_gemm_2args_bmv.cpp
@@ -52,5 +52,6 @@ struct test_gemm_2args_bmv : verify_program<test_gemm_2args_bmv<DType>>
 template struct test_gemm_2args_bmv<migraphx::shape::float_type>;
 template struct test_gemm_2args_bmv<migraphx::shape::half_type>;
 template struct test_gemm_2args_bmv<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_gemm_2args_bmv<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_gemm_2args_bmv<migraphx::shape::fp8e4m3fn_type>;
 template struct test_gemm_2args_bmv<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_gemm_2args_mm_1.cpp b/test/verify/test_gemm_2args_mm_1.cpp
index f1d18d4e966..7ddb70fc300 100644
--- a/test/verify/test_gemm_2args_mm_1.cpp
+++ b/test/verify/test_gemm_2args_mm_1.cpp
@@ -51,5 +51,6 @@ struct test_gemm_2args_mm_1 : verify_program<test_gemm_2args_mm_1<DType>>
 template struct test_gemm_2args_mm_1<migraphx::shape::float_type>;
 template struct test_gemm_2args_mm_1<migraphx::shape::half_type>;
 template struct test_gemm_2args_mm_1<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_gemm_2args_mm_1<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_gemm_2args_mm_1<migraphx::shape::fp8e4m3fn_type>;
 template struct test_gemm_2args_mm_1<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_gemm_2args_mm_2.cpp b/test/verify/test_gemm_2args_mm_2.cpp
index 8eb4436d1d6..cd2d0df5d39 100644
--- a/test/verify/test_gemm_2args_mm_2.cpp
+++ b/test/verify/test_gemm_2args_mm_2.cpp
@@ -52,5 +52,6 @@ struct test_gemm_2args_mm_2 : verify_program<test_gemm_2args_mm_2<DType>>
 template struct test_gemm_2args_mm_2<migraphx::shape::float_type>;
 template struct test_gemm_2args_mm_2<migraphx::shape::half_type>;
 template struct test_gemm_2args_mm_2<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_gemm_2args_mm_2<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_gemm_2args_mm_2<migraphx::shape::fp8e4m3fn_type>;
 template struct test_gemm_2args_mm_2<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_gemm_2args_mm_3.cpp b/test/verify/test_gemm_2args_mm_3.cpp
index c065890f925..00be7cee4bc 100644
--- a/test/verify/test_gemm_2args_mm_3.cpp
+++ b/test/verify/test_gemm_2args_mm_3.cpp
@@ -52,5 +52,6 @@ struct test_gemm_2args_mm_3 : verify_program<test_gemm_2args_mm_3<DType>>
 template struct test_gemm_2args_mm_3<migraphx::shape::float_type>;
 template struct test_gemm_2args_mm_3<migraphx::shape::half_type>;
 template struct test_gemm_2args_mm_3<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_gemm_2args_mm_3<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_gemm_2args_mm_3<migraphx::shape::fp8e4m3fn_type>;
 template struct test_gemm_2args_mm_3<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_gemm_2args_mm_5.cpp b/test/verify/test_gemm_2args_mm_5.cpp
index e379c1c5b2c..45336a1c2f2 100644
--- a/test/verify/test_gemm_2args_mm_5.cpp
+++ b/test/verify/test_gemm_2args_mm_5.cpp
@@ -51,5 +51,6 @@ struct test_gemm_2args_mm_5 : verify_program<test_gemm_2args_mm_5<DType>>
 template struct test_gemm_2args_mm_5<migraphx::shape::float_type>;
 template struct test_gemm_2args_mm_5<migraphx::shape::half_type>;
 template struct test_gemm_2args_mm_5<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_gemm_2args_mm_5<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_gemm_2args_mm_5<migraphx::shape::fp8e4m3fn_type>;
 template struct test_gemm_2args_mm_5<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_gemm_2args_mm_6.cpp b/test/verify/test_gemm_2args_mm_6.cpp
index 5d048448a08..e2d00cb169b 100644
--- a/test/verify/test_gemm_2args_mm_6.cpp
+++ b/test/verify/test_gemm_2args_mm_6.cpp
@@ -54,5 +54,6 @@ struct test_gemm_2args_mm_6 : verify_program<test_gemm_2args_mm_6<DType>>
 template struct test_gemm_2args_mm_6<migraphx::shape::float_type>;
 template struct test_gemm_2args_mm_6<migraphx::shape::half_type>;
 template struct test_gemm_2args_mm_6<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_gemm_2args_mm_6<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_gemm_2args_mm_6<migraphx::shape::fp8e4m3fn_type>;
 template struct test_gemm_2args_mm_6<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_gemm_2args_mm_7.cpp b/test/verify/test_gemm_2args_mm_7.cpp
index 02bf8c8e5c8..7e03df20abc 100644
--- a/test/verify/test_gemm_2args_mm_7.cpp
+++ b/test/verify/test_gemm_2args_mm_7.cpp
@@ -51,5 +51,6 @@ struct test_gemm_2args_mm_7 : verify_program<test_gemm_2args_mm_7<DType>>
 template struct test_gemm_2args_mm_7<migraphx::shape::float_type>;
 template struct test_gemm_2args_mm_7<migraphx::shape::half_type>;
 template struct test_gemm_2args_mm_7<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_gemm_2args_mm_7<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_gemm_2args_mm_7<migraphx::shape::fp8e4m3fn_type>;
 template struct test_gemm_2args_mm_7<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_gemm_2args_mm_8.cpp b/test/verify/test_gemm_2args_mm_8.cpp
index 88cbe2ad80e..175105d1003 100644
--- a/test/verify/test_gemm_2args_mm_8.cpp
+++ b/test/verify/test_gemm_2args_mm_8.cpp
@@ -51,5 +51,6 @@ struct test_gemm_2args_mm_8 : verify_program<test_gemm_2args_mm_8<DType>>
 template struct test_gemm_2args_mm_8<migraphx::shape::float_type>;
 // template struct test_gemm_2args_mm_8<migraphx::shape::half_type>; // fails with CK, issue#2514
 template struct test_gemm_2args_mm_8<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_gemm_2args_mm_8<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_gemm_2args_mm_8<migraphx::shape::fp8e4m3fn_type>;
 template struct test_gemm_2args_mm_8<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_gemm_2args_mv.cpp b/test/verify/test_gemm_2args_mv.cpp
index 94f644e55a4..5a02d7cd765 100644
--- a/test/verify/test_gemm_2args_mv.cpp
+++ b/test/verify/test_gemm_2args_mv.cpp
@@ -50,5 +50,6 @@ struct test_gemm_2args_mv : verify_program<test_gemm_2args_mv<DType>>
 template struct test_gemm_2args_mv<migraphx::shape::float_type>;
 template struct test_gemm_2args_mv<migraphx::shape::half_type>;
 template struct test_gemm_2args_mv<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_gemm_2args_mv<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_gemm_2args_mv<migraphx::shape::fp8e4m3fn_type>;
 template struct test_gemm_2args_mv<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_gemm_2args_vbm.cpp b/test/verify/test_gemm_2args_vbm.cpp
index d6b8f3b983b..b1c560ee331 100644
--- a/test/verify/test_gemm_2args_vbm.cpp
+++ b/test/verify/test_gemm_2args_vbm.cpp
@@ -54,5 +54,6 @@ struct test_gemm_2args_vbm : verify_program<test_gemm_2args_vbm<DType>>
 template struct test_gemm_2args_vbm<migraphx::shape::float_type>;
 template struct test_gemm_2args_vbm<migraphx::shape::half_type>;
 template struct test_gemm_2args_vbm<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_gemm_2args_vbm<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_gemm_2args_vbm<migraphx::shape::fp8e4m3fn_type>;
 template struct test_gemm_2args_vbm<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_gemm_2args_vm.cpp b/test/verify/test_gemm_2args_vm.cpp
index 48ae7090b4f..5c9b8283a24 100644
--- a/test/verify/test_gemm_2args_vm.cpp
+++ b/test/verify/test_gemm_2args_vm.cpp
@@ -51,6 +51,7 @@ struct test_gemm_2args_vm : verify_program<test_gemm_2args_vm<DType>>
 template struct test_gemm_2args_vm<migraphx::shape::float_type>;
 template struct test_gemm_2args_vm<migraphx::shape::half_type>;
 template struct test_gemm_2args_vm<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_gemm_2args_vm<migraphx::shape::fp8e5m2fnuz_type>;
 // TODO need hipblaslt support
 // template struct test_gemm_2args_vm<migraphx::shape::fp8e4m3fn_type>;
 // template struct test_gemm_2args_vm<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_gemm_2args_vv.cpp b/test/verify/test_gemm_2args_vv.cpp
index 8a7a0dccfd7..47f9d284df9 100644
--- a/test/verify/test_gemm_2args_vv.cpp
+++ b/test/verify/test_gemm_2args_vv.cpp
@@ -54,5 +54,6 @@ struct test_gemm_2args_vv : verify_program<test_gemm_2args_vv<DType>>
 template struct test_gemm_2args_vv<migraphx::shape::float_type>;
 template struct test_gemm_2args_vv<migraphx::shape::half_type>;
 template struct test_gemm_2args_vv<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_gemm_2args_vv<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_gemm_2args_vv<migraphx::shape::fp8e4m3fn_type>;
 template struct test_gemm_2args_vv<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_gemm_add.cpp b/test/verify/test_gemm_add.cpp
index 1d577793e9d..cd4231f6871 100644
--- a/test/verify/test_gemm_add.cpp
+++ b/test/verify/test_gemm_add.cpp
@@ -58,3 +58,6 @@ struct test_gemm_add : verify_program<test_gemm_add<DType>>
 template struct test_gemm_add<migraphx::shape::float_type>;
 template struct test_gemm_add<migraphx::shape::half_type>;
 // TODO template struct test_gemm_add<migraphx::shape::fp8e4m3fnuz_type>;
+// TODO template struct test_gemm_add<migraphx::shape::fp8e5m2fnuz_type>;
+// TODO template struct test_gemm_add<migraphx::shape::fp8e4m3fn_type>;
+// TODO template struct test_gemm_add<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_gemm_add_broadcast1.cpp b/test/verify/test_gemm_add_broadcast1.cpp
index b363458c503..1106442aba7 100644
--- a/test/verify/test_gemm_add_broadcast1.cpp
+++ b/test/verify/test_gemm_add_broadcast1.cpp
@@ -54,5 +54,6 @@ struct test_gemm_add_broadcast1 : verify_program<test_gemm_add_broadcast1<DType>>
 template struct test_gemm_add_broadcast1<migraphx::shape::float_type>;
 template struct test_gemm_add_broadcast1<migraphx::shape::half_type>;
 template struct test_gemm_add_broadcast1<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_gemm_add_broadcast1<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_gemm_add_broadcast1<migraphx::shape::fp8e4m3fn_type>;
 template struct test_gemm_add_broadcast1<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_gemm_add_broadcast2.cpp b/test/verify/test_gemm_add_broadcast2.cpp
index 36ec675ee44..88674901d0f 100644
--- a/test/verify/test_gemm_add_broadcast2.cpp
+++ b/test/verify/test_gemm_add_broadcast2.cpp
@@ -55,5 +55,6 @@ template struct test_gemm_add_broadcast2;
 // template struct test_gemm_add_broadcast2<migraphx::shape::half_type>; // fails with CK,
 // issue#2514
 template struct test_gemm_add_broadcast2<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_gemm_add_broadcast2<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_gemm_add_broadcast2<migraphx::shape::fp8e4m3fn_type>;
 template struct test_gemm_add_broadcast2<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_gemm_add_broadcast3.cpp b/test/verify/test_gemm_add_broadcast3.cpp
new file mode 100644
index 00000000000..fec376c8b94
--- /dev/null
+++ b/test/verify/test_gemm_add_broadcast3.cpp
@@ -0,0 +1,59 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+#include "verify_program.hpp"
+#include <migraphx/program.hpp>
+#include <migraphx/generate.hpp>
+#include <migraphx/make_op.hpp>
+#include <migraphx/instruction.hpp>
+
+template <migraphx::shape::type_t DType>
+struct test_gemm_add_broadcast3 : verify_program<test_gemm_add_broadcast3<DType>>
+{
+    migraphx::program create_program() const
+    {
+        migraphx::program p;
+        auto* mm = p.get_main_module();
+        migraphx::shape m1_shape{DType, {1, 2}};
+        migraphx::shape m2_shape{DType, {2, 4}};
+        migraphx::shape m3_shape{DType, {4}};
+        auto l1 = mm->add_parameter("1", m1_shape);
+        auto l2 = mm->add_parameter("2", m2_shape);
+        auto l3 = mm->add_parameter("3", m3_shape);
+        auto l3_b =
+            mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l3);
+
+        auto dot = mm->add_instruction(migraphx::make_op("dot"), l1, l2);
+        mm->add_instruction(migraphx::make_op("add"), l3_b, dot);
+        return p;
+    }
+    std::string section() const { return "gemm"; }
+};
+
+template struct test_gemm_add_broadcast3<migraphx::shape::float_type>;
+template struct test_gemm_add_broadcast3<migraphx::shape::half_type>;
+// template struct test_gemm_add_broadcast3<migraphx::shape::fp8e4m3fnuz_type>;
+// template struct test_gemm_add_broadcast3<migraphx::shape::fp8e5m2fnuz_type>;
+// template struct test_gemm_add_broadcast3<migraphx::shape::fp8e4m3fn_type>;
+// template struct test_gemm_add_broadcast3<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_gemm_copy.cpp b/test/verify/test_gemm_copy.cpp
index 494f5fd0f62..c960c6c17da 100644
--- a/test/verify/test_gemm_copy.cpp
+++ b/test/verify/test_gemm_copy.cpp
@@ -52,5 +52,6 @@ struct test_gemm_copy : verify_program<test_gemm_copy<DType>>
 template struct test_gemm_copy<migraphx::shape::float_type>;
 template struct test_gemm_copy<migraphx::shape::half_type>;
 template struct test_gemm_copy<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_gemm_copy<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_gemm_copy<migraphx::shape::fp8e4m3fn_type>;
 template struct test_gemm_copy<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_gemm_ex.cpp b/test/verify/test_gemm_ex.cpp
index 111ca3311a8..f269c80f702 100644
--- a/test/verify/test_gemm_ex.cpp
+++ b/test/verify/test_gemm_ex.cpp
@@ -44,5 +44,6 @@ struct test_gemm_ex : verify_program<test_gemm_ex<DType>>
 template struct test_gemm_ex<migraphx::shape::float_type>;
 template struct test_gemm_ex<migraphx::shape::half_type>;
 template struct test_gemm_ex<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_gemm_ex<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_gemm_ex<migraphx::shape::fp8e4m3fn_type>;
 template struct test_gemm_ex<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_gemm_literal.cpp b/test/verify/test_gemm_literal.cpp
index 18099032eb3..62d32588166 100644
--- a/test/verify/test_gemm_literal.cpp
+++ b/test/verify/test_gemm_literal.cpp
@@ -49,5 +49,6 @@ struct test_gemm_literal : verify_program<test_gemm_literal<DType>>
 template struct test_gemm_literal<migraphx::shape::float_type>;
 template struct test_gemm_literal<migraphx::shape::half_type>;
 template struct test_gemm_literal<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_gemm_literal<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_gemm_literal<migraphx::shape::fp8e4m3fn_type>;
 template struct test_gemm_literal<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_gemm_multi_3args.cpp b/test/verify/test_gemm_multi_3args.cpp
index f6575ed42f3..f61eee5ed03 100644
--- a/test/verify/test_gemm_multi_3args.cpp
+++ b/test/verify/test_gemm_multi_3args.cpp
@@ -53,5 +53,6 @@ struct test_gemm_multi_3args : verify_program<test_gemm_multi_3args<DType>>
 template struct test_gemm_multi_3args<migraphx::shape::float_type>;
 template struct test_gemm_multi_3args<migraphx::shape::half_type>;
 template struct test_gemm_multi_3args<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_gemm_multi_3args<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_gemm_multi_3args<migraphx::shape::fp8e4m3fn_type>;
 template struct test_gemm_multi_3args<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_gemm_multi_3args_alpha0.cpp b/test/verify/test_gemm_multi_3args_alpha0.cpp
index 38987f3741a..a66f55a93fc 100644
--- a/test/verify/test_gemm_multi_3args_alpha0.cpp
+++ b/test/verify/test_gemm_multi_3args_alpha0.cpp
@@ -53,5 +53,6 @@ struct test_gemm_multi_3args_alpha0 : verify_program<test_gemm_multi_3args_alpha0<DType>>
 template struct test_gemm_multi_3args_alpha0<migraphx::shape::float_type>;
 template struct test_gemm_multi_3args_alpha0<migraphx::shape::half_type>;
 template struct test_gemm_multi_3args_alpha0<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_gemm_multi_3args_alpha0<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_gemm_multi_3args_alpha0<migraphx::shape::fp8e4m3fn_type>;
 template struct test_gemm_multi_3args_alpha0<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_gemm_multi_3args_beta0.cpp b/test/verify/test_gemm_multi_3args_beta0.cpp
index 4b3853d8b69..6c18d54fdf5 100644
--- a/test/verify/test_gemm_multi_3args_beta0.cpp
+++ b/test/verify/test_gemm_multi_3args_beta0.cpp
@@ -53,5 +53,6 @@ struct test_gemm_multi_3args_beta0 : verify_program<test_gemm_multi_3args_beta0<DType>>
 template struct test_gemm_multi_3args_beta0<migraphx::shape::float_type>;
 template struct test_gemm_multi_3args_beta0<migraphx::shape::half_type>;
 template struct test_gemm_multi_3args_beta0<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_gemm_multi_3args_beta0<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_gemm_multi_3args_beta0<migraphx::shape::fp8e4m3fn_type>;
 template struct test_gemm_multi_3args_beta0<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_gemm_multi_3args_c25.cpp b/test/verify/test_gemm_multi_3args_c25.cpp
index f4034098bb0..b5003bfccfd 100644
--- a/test/verify/test_gemm_multi_3args_c25.cpp
+++ b/test/verify/test_gemm_multi_3args_c25.cpp
@@ -53,5 +53,6 @@ struct test_gemm_multi_3args_c25 : verify_program<test_gemm_multi_3args_c25<DType>>
 template struct test_gemm_multi_3args_c25<migraphx::shape::float_type>;
 template struct test_gemm_multi_3args_c25<migraphx::shape::half_type>;
 template struct test_gemm_multi_3args_c25<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_gemm_multi_3args_c25<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_gemm_multi_3args_c25<migraphx::shape::fp8e4m3fn_type>;
 template struct test_gemm_multi_3args_c25<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_gemm_multi_dim_2_3.cpp b/test/verify/test_gemm_multi_dim_2_3.cpp
index 83c2ff42185..1c9f936d332 100644
--- a/test/verify/test_gemm_multi_dim_2_3.cpp
+++ b/test/verify/test_gemm_multi_dim_2_3.cpp
@@ -49,5 +49,6 @@ struct test_gemm_multi_dim_2_3 : verify_program<test_gemm_multi_dim_2_3<DType>>
 template struct test_gemm_multi_dim_2_3<migraphx::shape::float_type>;
 template struct test_gemm_multi_dim_2_3<migraphx::shape::half_type>;
 template struct test_gemm_multi_dim_2_3<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_gemm_multi_dim_2_3<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_gemm_multi_dim_2_3<migraphx::shape::fp8e4m3fn_type>;
 template struct test_gemm_multi_dim_2_3<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_gemm_multi_transpose.cpp b/test/verify/test_gemm_multi_transpose.cpp
index 5d4a7eb436b..c34cf437d3d 100644
--- a/test/verify/test_gemm_multi_transpose.cpp
+++ b/test/verify/test_gemm_multi_transpose.cpp
@@ -53,5 +53,6 @@ struct test_gemm_multi_transpose : verify_program<test_gemm_multi_transpose<DType>>
 template struct test_gemm_multi_transpose<migraphx::shape::float_type>;
 template struct test_gemm_multi_transpose<migraphx::shape::half_type>;
 template struct test_gemm_multi_transpose<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_gemm_multi_transpose<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_gemm_multi_transpose<migraphx::shape::fp8e4m3fn_type>;
 template struct test_gemm_multi_transpose<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_gemm_multibroadcast.cpp b/test/verify/test_gemm_multibroadcast.cpp
index 3be403bfa89..bd771d30e04 100644
--- a/test/verify/test_gemm_multibroadcast.cpp
+++ b/test/verify/test_gemm_multibroadcast.cpp
@@ -47,5 +47,6 @@ struct test_gemm_multibroadcast : verify_program<test_gemm_multibroadcast<DType>>
 template struct test_gemm_multibroadcast<migraphx::shape::float_type>;
 template struct test_gemm_multibroadcast<migraphx::shape::half_type>;
 template struct test_gemm_multibroadcast<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_gemm_multibroadcast<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_gemm_multibroadcast<migraphx::shape::fp8e4m3fn_type>;
 template struct test_gemm_multibroadcast<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_gemm_pointwise.cpp b/test/verify/test_gemm_pointwise.cpp
index a9f5be2496d..572ffbfaf27 100644
--- a/test/verify/test_gemm_pointwise.cpp
+++ b/test/verify/test_gemm_pointwise.cpp
@@ -56,5 +56,6 @@ struct test_gemm_pointwise : verify_program<test_gemm_pointwise<DType>>
 template struct test_gemm_pointwise<migraphx::shape::float_type>;
 template struct test_gemm_pointwise<migraphx::shape::half_type>;
 template struct test_gemm_pointwise<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_gemm_pointwise<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_gemm_pointwise<migraphx::shape::fp8e4m3fn_type>;
 template struct test_gemm_pointwise<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_gemm_transpose_add_pooling_sub.cpp b/test/verify/test_gemm_transpose_add_pooling_sub.cpp
index 8760538571b..7ca7ba931b9 100644
--- a/test/verify/test_gemm_transpose_add_pooling_sub.cpp
+++ b/test/verify/test_gemm_transpose_add_pooling_sub.cpp
@@ -66,5 +66,6 @@ struct test_gemm_transpose_add_pooling_sub
 template struct test_gemm_transpose_add_pooling_sub<migraphx::shape::float_type>;
 template struct test_gemm_transpose_add_pooling_sub<migraphx::shape::half_type>;
 template struct test_gemm_transpose_add_pooling_sub<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_gemm_transpose_add_pooling_sub<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_gemm_transpose_add_pooling_sub<migraphx::shape::fp8e4m3fn_type>;
 template struct test_gemm_transpose_add_pooling_sub<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_gemm_transposea.cpp b/test/verify/test_gemm_transposea.cpp
index 939fa648923..527a537ea0c 100644
--- a/test/verify/test_gemm_transposea.cpp
+++ b/test/verify/test_gemm_transposea.cpp
@@ -46,6 +46,7 @@ struct test_gemm_transposea : verify_program<test_gemm_transposea<DType>>
 template struct test_gemm_transposea<migraphx::shape::float_type>;
 template struct test_gemm_transposea<migraphx::shape::half_type>;
 template struct test_gemm_transposea<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_gemm_transposea<migraphx::shape::fp8e5m2fnuz_type>;
 // TODO need hipblaslt support
 // template struct test_gemm_transposea<migraphx::shape::fp8e4m3fn_type>;
 // template struct test_gemm_transposea<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_gemm_transposea_ex.cpp b/test/verify/test_gemm_transposea_ex.cpp
index 9e107c3d42c..7244fc2c544 100644
--- a/test/verify/test_gemm_transposea_ex.cpp
+++ b/test/verify/test_gemm_transposea_ex.cpp
@@ -47,5 +47,6 @@ struct test_gemm_transposea_ex : verify_program<test_gemm_transposea_ex<DType>>
 template struct test_gemm_transposea_ex<migraphx::shape::float_type>;
 template struct test_gemm_transposea_ex<migraphx::shape::half_type>;
 template struct test_gemm_transposea_ex<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_gemm_transposea_ex<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_gemm_transposea_ex<migraphx::shape::fp8e4m3fn_type>;
 template struct test_gemm_transposea_ex<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_gemm_transposeab.cpp b/test/verify/test_gemm_transposeab.cpp
index 1eaf2c2d360..63baa4903c8 100644
--- a/test/verify/test_gemm_transposeab.cpp
+++ b/test/verify/test_gemm_transposeab.cpp
@@ -47,5 +47,6 @@ struct test_gemm_transposeab : verify_program<test_gemm_transposeab<DType>>
 template struct test_gemm_transposeab<migraphx::shape::float_type>;
 template struct test_gemm_transposeab<migraphx::shape::half_type>;
 template struct test_gemm_transposeab<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_gemm_transposeab<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_gemm_transposeab<migraphx::shape::fp8e4m3fn_type>;
 template struct test_gemm_transposeab<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_gemm_transposeb.cpp b/test/verify/test_gemm_transposeb.cpp
index 98060c7614e..fef29432136 100644
--- a/test/verify/test_gemm_transposeb.cpp
+++ b/test/verify/test_gemm_transposeb.cpp
@@ -46,5 +46,6 @@ struct test_gemm_transposeb : verify_program<test_gemm_transposeb<DType>>
 template struct test_gemm_transposeb<migraphx::shape::float_type>;
 template struct test_gemm_transposeb<migraphx::shape::half_type>;
 template struct test_gemm_transposeb<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_gemm_transposeb<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_gemm_transposeb<migraphx::shape::fp8e4m3fn_type>;
 template struct test_gemm_transposeb<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_gemm_transposeb_detect.cpp b/test/verify/test_gemm_transposeb_detect.cpp
index 30ace4dee5e..4b83e1c83b7 100644
--- a/test/verify/test_gemm_transposeb_detect.cpp
+++ b/test/verify/test_gemm_transposeb_detect.cpp
@@ -47,5 +47,6 @@ struct test_gemm_transposeb_detect : verify_program<test_gemm_transposeb_detect<DType>>
 template struct test_gemm_transposeb_detect<migraphx::shape::float_type>;
 template struct test_gemm_transposeb_detect<migraphx::shape::half_type>;
 template struct test_gemm_transposeb_detect<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_gemm_transposeb_detect<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_gemm_transposeb_detect<migraphx::shape::fp8e4m3fn_type>;
 template struct test_gemm_transposeb_detect<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_gemm_transposeb_ex.cpp b/test/verify/test_gemm_transposeb_ex.cpp
index 175e57b7600..ef75d796de1 100644
--- a/test/verify/test_gemm_transposeb_ex.cpp
+++ b/test/verify/test_gemm_transposeb_ex.cpp
@@ -47,5 +47,6 @@ struct test_gemm_transposeb_ex : verify_program<test_gemm_transposeb_ex<DType>>
 template struct test_gemm_transposeb_ex<migraphx::shape::float_type>;
 template struct test_gemm_transposeb_ex<migraphx::shape::half_type>;
 template struct test_gemm_transposeb_ex<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_gemm_transposeb_ex<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_gemm_transposeb_ex<migraphx::shape::fp8e4m3fn_type>;
 template struct test_gemm_transposeb_ex<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_instancenorm.cpp b/test/verify/test_instancenorm.cpp
index 5f4a551b1a1..77a5541fbae 100644
--- a/test/verify/test_instancenorm.cpp
+++ b/test/verify/test_instancenorm.cpp
@@ -1,7 +1,7 @@
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -75,6 +75,8 @@ struct test_instancenorm : verify_program<test_instancenorm<DType>>
         add_instancenorm(*mm, x, {1, 2, 1, 1});
         return p;
     }
+
+    std::string section() const { return "reduce"; }
 };
 template struct test_instancenorm<migraphx::shape::float_type>;
 template struct test_instancenorm<migraphx::shape::half_type>;
@@ -91,6 +93,8 @@ struct test_instancenorm_large_3d : verify_program<test_instancenorm_large_3d<DType>>
         return p;
     }
+
+    std::string section() const { return "reduce"; }
 };
diff --git a/test/verify/test_isnan.cpp b/test/verify/test_isnan.cpp
index f1701bcd4aa..786fd87f798 100644
--- a/test/verify/test_isnan.cpp
+++ b/test/verify/test_isnan.cpp
@@ -45,5 +45,6 @@ struct test_isnan : verify_program<test_isnan<DType>>
 template struct test_isnan<migraphx::shape::float_type>;
 template struct test_isnan<migraphx::shape::half_type>;
 template struct test_isnan<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_isnan<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_isnan<migraphx::shape::fp8e4m3fn_type>;
 template struct test_isnan<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_layernorm.cpp b/test/verify/test_layernorm.cpp
index eef8acd5de7..725834bad1e 100644
--- a/test/verify/test_layernorm.cpp
+++ b/test/verify/test_layernorm.cpp
@@ -40,6 +40,8 @@ struct test_layernorm : verify_program<test_layernorm>
         add_layernorm(*mm, x, dims);
         return p;
     }
+
+    std::string section() const { return "reduce"; }
 };
 
 struct test_layernorm2 : verify_program<test_layernorm2>
@@ -53,6 +55,8 @@ struct test_layernorm2 : verify_program<test_layernorm2>
         add_layernorm(*mm, x, dims);
         return p;
     }
+
+    std::string section() const { return "reduce"; }
 };
 
 struct test_layernorm_large : verify_program<test_layernorm_large>
@@ -66,6 +70,8 @@ struct test_layernorm_large : verify_program<test_layernorm_large>
         add_layernorm(*mm, x, dims);
         return p;
     }
+
+    std::string section() const { return "reduce"; }
 };
 
 struct test_layernorm_fp16 : verify_program<test_layernorm_fp16>
@@ -79,6 +85,8 @@ struct test_layernorm_fp16 : verify_program<test_layernorm_fp16>
         add_layernorm(*mm, x, dims);
         return p;
     }
+
+    std::string section() const { return "reduce"; }
 };
 
 struct test_layernorm_fp8_1 : verify_program<test_layernorm_fp8_1>
@@ -92,6 +100,8 @@ struct test_layernorm_fp8_1 : verify_program<test_layernorm_fp8_1>
         add_layernorm(*mm, x, dims);
         return p;
     }
+
+    std::string section() const { return "reduce"; }
 };
 
 struct test_layernorm_fp8_2 : verify_program<test_layernorm_fp8_2>
@@ -101,13 +111,30 @@ struct test_layernorm_fp8_2 : verify_program<test_layernorm_fp8_2>
         migraphx::program p;
         auto* mm = p.get_main_module();
         std::vector<size_t> dims = {1, 24, 64};
-        auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::fp8e4m3fn_type, dims});
+        auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::fp8e5m2fnuz_type, dims});
         add_layernorm(*mm, x, dims);
         return p;
     }
+
+    std::string section() const { return "reduce"; }
 };
 
 struct test_layernorm_fp8_3 : verify_program<test_layernorm_fp8_3>
+{
+    migraphx::program create_program() const
+    {
+        migraphx::program p;
+        auto* mm = p.get_main_module();
+        std::vector<size_t> dims = {1, 24, 64};
+        auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::fp8e4m3fn_type, dims});
+        add_layernorm(*mm, x, dims);
+        return p;
+    }
+
+    std::string section() const { return "reduce"; }
+};
+
+struct test_layernorm_fp8_4 : verify_program<test_layernorm_fp8_4>
 {
     migraphx::program create_program() const
     {
@@ -118,6 +145,8 @@ struct test_layernorm_fp8_3 : verify_program<test_layernorm_fp8_3>
         add_layernorm(*mm, x, dims);
         return p;
     }
+
+    std::string section() const { return "reduce"; }
 };
 
 struct test_layernorm_eps : verify_program<test_layernorm_eps>
@@ -131,6 +160,8 @@ struct test_layernorm_eps : verify_program<test_layernorm_eps>
         add_layernorm(*mm, x, dims, 1e-5f);
         return p;
     }
+
+    std::string section() const { return "reduce"; }
 };
 
 struct test_layernorm_triadd : verify_program<test_layernorm_triadd>
@@ -148,6 +179,8 @@ struct test_layernorm_triadd : verify_program<test_layernorm_triadd>
         add_layernorm(*mm, add2, dims);
         return p;
     }
+
+    std::string section() const { return "reduce"; }
 };
 
 struct test_layernorm_triadd_large : verify_program<test_layernorm_triadd_large>
@@ -165,6 +198,8 @@ struct test_layernorm_triadd_large : verify_program<test_layernorm_triadd_large>
         add_layernorm(*mm, add2, dims);
         return p;
     }
+
+    std::string section() const { return "reduce"; }
 };
 
 struct test_add_layernorm_add_gemm_nonstd : verify_program<test_add_layernorm_add_gemm_nonstd>
@@ -185,3 +220,18 @@ struct test_add_layernorm_add_gemm_nonstd : verify_program<test_add_layernorm_add_gemm_nonstd>
         return p;
     }
 };
+
+struct test_pointwise_layernorm : verify_program<test_pointwise_layernorm>
+{
+    migraphx::program create_program() const
+    {
+        migraphx::program p;
+        auto* mm = p.get_main_module();
+        std::vector<size_t> dims = {1, 9, 6};
+        auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims});
+        add_pointwise_layernorm(*mm, x, dims);
+        return p;
+    }
+
+    std::string section() const { return "reduce"; }
+};
diff --git a/test/verify/test_literal_limits.cpp b/test/verify/test_literal_limits.cpp
index 25cac9107bc..9dac1bb7565 100644
--- a/test/verify/test_literal_limits.cpp
+++ b/test/verify/test_literal_limits.cpp
@@ -58,5 +58,6 @@ template struct test_literal_limits;
 template struct test_literal_limits<migraphx::shape::float_type>;
 template struct test_literal_limits<migraphx::shape::half_type>;
 template struct test_literal_limits<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_literal_limits<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_literal_limits<migraphx::shape::fp8e4m3fn_type>;
 template struct test_literal_limits<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_log.cpp b/test/verify/test_log.cpp
index 03027236bc4..366e5c7988a 100644
--- a/test/verify/test_log.cpp
+++ b/test/verify/test_log.cpp
@@ -44,5 +44,6 @@ struct test_log : verify_program<test_log<DType>>
 template struct test_log<migraphx::shape::float_type>;
 template struct test_log<migraphx::shape::half_type>;
 template struct test_log<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_log<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_log<migraphx::shape::fp8e4m3fn_type>;
 template struct test_log<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_log2.cpp b/test/verify/test_log2.cpp
index a28d847891f..bcbf149e4e3 100644
--- a/test/verify/test_log2.cpp
+++ b/test/verify/test_log2.cpp
@@ -44,5 +44,6 @@ struct test_log2 : verify_program<test_log2<DType>>
 template struct test_log2<migraphx::shape::float_type>;
 template struct test_log2<migraphx::shape::half_type>;
 template struct test_log2<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_log2<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_log2<migraphx::shape::fp8e4m3fn_type>;
 template struct test_log2<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_logsoftmax.cpp b/test/verify/test_logsoftmax.cpp
index 0ebed420318..afa1ec43aa1 100644
--- a/test/verify/test_logsoftmax.cpp
+++ b/test/verify/test_logsoftmax.cpp
@@ -40,6 +40,8 @@ struct test_logsoftmax : verify_program<test_logsoftmax<Axis, DType>>
         return p;
     }
+
+    std::string section() const { return "reduce"; }
 };
 
 template struct test_logsoftmax<0, migraphx::shape::float_type>;
@@ -52,17 +54,14 @@ template struct test_logsoftmax<0, migraphx::shape::half_type>;
 template struct test_logsoftmax<2, migraphx::shape::half_type>;
 template struct test_logsoftmax<3, migraphx::shape::half_type>;
 
-template struct test_logsoftmax<0, migraphx::shape::fp8e4m3fnuz_type>;
 template struct test_logsoftmax<1, migraphx::shape::fp8e4m3fnuz_type>;
-template struct test_logsoftmax<2, migraphx::shape::fp8e4m3fnuz_type>;
 template struct test_logsoftmax<3, migraphx::shape::fp8e4m3fnuz_type>;
-template struct test_logsoftmax<0, migraphx::shape::fp8e4m3fn_type>;
+template struct test_logsoftmax<1, migraphx::shape::fp8e5m2fnuz_type>;
+template struct test_logsoftmax<3, migraphx::shape::fp8e5m2fnuz_type>;
+
 template struct test_logsoftmax<1, migraphx::shape::fp8e4m3fn_type>;
-template struct test_logsoftmax<2, migraphx::shape::fp8e4m3fn_type>;
 template struct test_logsoftmax<3, migraphx::shape::fp8e4m3fn_type>;
-template struct test_logsoftmax<0, migraphx::shape::fp8e5m2_type>;
 template struct test_logsoftmax<1, migraphx::shape::fp8e5m2_type>;
-template struct test_logsoftmax<2, migraphx::shape::fp8e5m2_type>;
 template struct test_logsoftmax<3, migraphx::shape::fp8e5m2_type>;
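All of the layernorm and logsoftmax variants above reduce along one axis, which is why they are now tagged with the "reduce" section. As a reminder of the reference computation the layernorm verify tests are compared against, a minimal sketch in standard C++ (the epsilon default here is illustrative; test_layernorm_eps passes 1e-5f explicitly):

    #include <cmath>
    #include <vector>

    // Normalize one row: x -> (x - mean) / sqrt(var + eps).
    std::vector<float> layernorm_row(const std::vector<float>& x, float eps = 1e-5f)
    {
        float mean = 0;
        for(float v : x)
            mean += v;
        mean /= x.size();

        float var = 0;
        for(float v : x)
            var += (v - mean) * (v - mean);
        var /= x.size();

        std::vector<float> out(x.size());
        for(std::size_t i = 0; i < x.size(); ++i)
            out[i] = (x[i] - mean) / std::sqrt(var + eps);
        return out;
    }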
diff --git a/test/verify/test_logsoftmax1.cpp b/test/verify/test_logsoftmax1.cpp
index a51a83e832b..d6a9daf83b5 100644
--- a/test/verify/test_logsoftmax1.cpp
+++ b/test/verify/test_logsoftmax1.cpp
@@ -1,7 +1,7 @@
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -40,4 +40,6 @@ struct test_logsoftmax1 : verify_program<test_logsoftmax1>
         mm->add_return({r});
         return p;
     }
+
+    std::string section() const { return "reduce"; }
 };
diff --git a/test/verify/test_min_max.cpp b/test/verify/test_min_max.cpp
index c81b654d142..12310c2c7a8 100644
--- a/test/verify/test_min_max.cpp
+++ b/test/verify/test_min_max.cpp
@@ -47,6 +47,7 @@ template struct test_min_max;
 template struct test_min_max<migraphx::shape::float_type, migraphx::op::min>;
 template struct test_min_max<migraphx::shape::half_type, migraphx::op::min>;
 template struct test_min_max<migraphx::shape::fp8e4m3fnuz_type, migraphx::op::min>;
+template struct test_min_max<migraphx::shape::fp8e5m2fnuz_type, migraphx::op::min>;
 template struct test_min_max<migraphx::shape::fp8e4m3fn_type, migraphx::op::min>;
 template struct test_min_max<migraphx::shape::fp8e5m2_type, migraphx::op::min>;
 
@@ -54,5 +55,6 @@ template struct test_min_max;
 template struct test_min_max<migraphx::shape::float_type, migraphx::op::max>;
 template struct test_min_max<migraphx::shape::half_type, migraphx::op::max>;
 template struct test_min_max<migraphx::shape::fp8e4m3fnuz_type, migraphx::op::max>;
+template struct test_min_max<migraphx::shape::fp8e5m2fnuz_type, migraphx::op::max>;
 template struct test_min_max<migraphx::shape::fp8e4m3fn_type, migraphx::op::max>;
 template struct test_min_max<migraphx::shape::fp8e5m2_type, migraphx::op::max>;
diff --git a/test/verify/test_mul_dot_a.cpp b/test/verify/test_mul_dot_a.cpp
index 92f82ecc0b9..d46a613d646 100644
--- a/test/verify/test_mul_dot_a.cpp
+++ b/test/verify/test_mul_dot_a.cpp
@@ -52,5 +52,6 @@ struct test_mul_dot_a : verify_program<test_mul_dot_a<DType>>
 template struct test_mul_dot_a<migraphx::shape::float_type>;
 template struct test_mul_dot_a<migraphx::shape::half_type>;
 template struct test_mul_dot_a<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_mul_dot_a<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_mul_dot_a<migraphx::shape::fp8e4m3fn_type>;
 template struct test_mul_dot_a<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_mul_dot_b.cpp b/test/verify/test_mul_dot_b.cpp
index 250ddcc957a..6941baa9e36 100644
--- a/test/verify/test_mul_dot_b.cpp
+++ b/test/verify/test_mul_dot_b.cpp
@@ -53,5 +53,6 @@ struct test_mul_dot_b : verify_program<test_mul_dot_b<DType>>
 template struct test_mul_dot_b<migraphx::shape::float_type>;
 template struct test_mul_dot_b<migraphx::shape::half_type>;
 template struct test_mul_dot_b<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_mul_dot_b<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_mul_dot_b<migraphx::shape::fp8e4m3fn_type>;
 template struct test_mul_dot_b<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_multinomial.cpp b/test/verify/test_multinomial.cpp
index d82b7ba7d47..0559d0ec0a5 100644
--- a/test/verify/test_multinomial.cpp
+++ b/test/verify/test_multinomial.cpp
@@ -64,3 +64,6 @@ template struct test_multinomial;
 template struct test_multinomial<migraphx::shape::half_type>;
 // TODO This fails, need to figure out why
 // template struct test_multinomial<migraphx::shape::fp8e4m3fnuz_type>;
+// template struct test_multinomial<migraphx::shape::fp8e5m2fnuz_type>;
+// template struct test_multinomial<migraphx::shape::fp8e4m3fn_type>;
+// template struct test_multinomial<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_nearbyint.cpp b/test/verify/test_nearbyint.cpp
index dbcc5a38767..2f6b3163c4a 100644
--- a/test/verify/test_nearbyint.cpp
+++ b/test/verify/test_nearbyint.cpp
@@ -47,5 +47,6 @@ struct test_nearbyint : verify_program<test_nearbyint<DType>>
 template struct test_nearbyint<migraphx::shape::float_type>;
 template struct test_nearbyint<migraphx::shape::half_type>;
 template struct test_nearbyint<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_nearbyint<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_nearbyint<migraphx::shape::fp8e4m3fn_type>;
 template struct test_nearbyint<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_nonzero.cpp b/test/verify/test_nonzero.cpp
index f68d174c463..c689c1e7485 100644
--- a/test/verify/test_nonzero.cpp
+++ b/test/verify/test_nonzero.cpp
@@ -46,5 +46,6 @@ struct test_nonzero : verify_program<test_nonzero<DType>>
 template struct test_nonzero<migraphx::shape::float_type>;
 template struct test_nonzero<migraphx::shape::half_type>;
 template struct test_nonzero<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_nonzero<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_nonzero<migraphx::shape::fp8e4m3fn_type>;
 template struct test_nonzero<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_pad.cpp b/test/verify/test_pad.cpp
index 720bf2e11b2..d4351d32fd4 100644
--- a/test/verify/test_pad.cpp
+++ b/test/verify/test_pad.cpp
@@ -51,6 +51,3 @@ struct test_pad : verify_program<test_pad<DType>>
 template struct test_pad<migraphx::shape::float_type>;
 template struct test_pad<migraphx::shape::half_type>;
 template struct test_pad<migraphx::shape::int32_type>;
-template struct test_pad<migraphx::shape::fp8e4m3fnuz_type>;
-template struct test_pad<migraphx::shape::fp8e4m3fn_type>;
-template struct test_pad<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_pointwise_broadcast_reduce.cpp b/test/verify/test_pointwise_broadcast_reduce.cpp
index ffcc658838b..c0269625457 100644
--- a/test/verify/test_pointwise_broadcast_reduce.cpp
+++ b/test/verify/test_pointwise_broadcast_reduce.cpp
@@ -52,4 +52,6 @@ struct test_pointwise_broadcast_reduce : verify_program<test_pointwise_broadcast_reduce>
         mm->add_return({reshape});
         return p;
     };
+
+    std::string section() const { return "reduce"; }
 };
diff --git a/test/verify/test_pointwise_conv_nhwc.cpp b/test/verify/test_pointwise_conv_nhwc.cpp
index 16e5baff687..18cff72bf1b 100644
--- a/test/verify/test_pointwise_conv_nhwc.cpp
+++ b/test/verify/test_pointwise_conv_nhwc.cpp
@@ -53,5 +53,6 @@ struct test_pointwise_conv_nhwc : verify_program<test_pointwise_conv_nhwc<DType>>
 template struct test_pointwise_conv_nhwc<migraphx::shape::half_type>;
 template struct test_pointwise_conv_nhwc<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_pointwise_conv_nhwc<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_pointwise_conv_nhwc<migraphx::shape::fp8e4m3fn_type>;
 template struct test_pointwise_conv_nhwc<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_pow.cpp b/test/verify/test_pow.cpp
index 21700d08924..3f1626a08e5 100644
--- a/test/verify/test_pow.cpp
+++ b/test/verify/test_pow.cpp
@@ -46,5 +46,6 @@ struct test_pow : verify_program<test_pow<DType>>
 template struct test_pow<migraphx::shape::float_type>;
 template struct test_pow<migraphx::shape::half_type>;
 template struct test_pow<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_pow<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_pow<migraphx::shape::fp8e4m3fn_type>;
 template struct test_pow<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_prefix_scan_sum_2d.cpp b/test/verify/test_prefix_scan_sum_2d.cpp
index c5a8f936877..a7530e08009 100644
--- a/test/verify/test_prefix_scan_sum_2d.cpp
+++ b/test/verify/test_prefix_scan_sum_2d.cpp
@@ -47,6 +47,7 @@ struct test_prefix_scan_sum_2d_small : verify_program<test_prefix_scan_sum_2d_small<DType>>
 template struct test_prefix_scan_sum_2d_small<migraphx::shape::float_type>;
 template struct test_prefix_scan_sum_2d_small<migraphx::shape::half_type>;
 template struct test_prefix_scan_sum_2d_small<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_prefix_scan_sum_2d_small<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_prefix_scan_sum_2d_small<migraphx::shape::fp8e4m3fn_type>;
 template struct test_prefix_scan_sum_2d_small<migraphx::shape::fp8e5m2_type>;
 
@@ -67,6 +68,3 @@ struct test_prefix_scan_sum_2d_large : verify_program<test_prefix_scan_sum_2d_large>
 template struct test_prefix_scan_sum_2d_large<migraphx::shape::float_type>;
 template struct test_prefix_scan_sum_2d_large<migraphx::shape::half_type>;
-template struct test_prefix_scan_sum_2d_large<migraphx::shape::fp8e4m3fnuz_type>;
-template struct test_prefix_scan_sum_2d_large<migraphx::shape::fp8e4m3fn_type>;
-template struct test_prefix_scan_sum_2d_large<migraphx::shape::fp8e5m2_type>;
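The prefix-scan tests above verify a running (inclusive) sum along one axis. For reference, a minimal standalone sketch of the operation in plain C++ (not the MIGraphX kernel):

    #include <iostream>
    #include <vector>

    // Inclusive prefix sum along a row: out[i] = x[0] + ... + x[i].
    std::vector<int> prefix_scan_sum(const std::vector<int>& x)
    {
        std::vector<int> out(x.size());
        int running = 0;
        for(std::size_t i = 0; i < x.size(); ++i)
        {
            running += x[i];
            out[i] = running;
        }
        return out;
    }

    int main()
    {
        for(int v : prefix_scan_sum({1, 2, 3, 4}))
            std::cout << v << " "; // prints: 1 3 6 10
    }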
diff --git a/test/verify/test_quant_conv.cpp b/test/verify/test_quant_conv.cpp
index c632ef1787f..994b3bdcd51 100644
--- a/test/verify/test_quant_conv.cpp
+++ b/test/verify/test_quant_conv.cpp
@@ -46,5 +46,6 @@ struct test_quant_conv : verify_program<test_quant_conv<DType>>
 template struct test_quant_conv<migraphx::shape::int8_type>;
 template struct test_quant_conv<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_quant_conv<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_quant_conv<migraphx::shape::fp8e4m3fn_type>;
 template struct test_quant_conv<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_quant_conv_1.cpp b/test/verify/test_quant_conv_1.cpp
index b13d54ddfb0..1c09cf48764 100644
--- a/test/verify/test_quant_conv_1.cpp
+++ b/test/verify/test_quant_conv_1.cpp
@@ -45,5 +45,6 @@ struct test_quant_conv_1 : verify_program<test_quant_conv_1<DType>>
 template struct test_quant_conv_1<migraphx::shape::int8_type>;
 template struct test_quant_conv_1<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_quant_conv_1<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_quant_conv_1<migraphx::shape::fp8e4m3fn_type>;
 template struct test_quant_conv_1<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_quant_conv_2.cpp b/test/verify/test_quant_conv_2.cpp
index ad8d4077c22..d518c3bbc6a 100644
--- a/test/verify/test_quant_conv_2.cpp
+++ b/test/verify/test_quant_conv_2.cpp
@@ -45,5 +45,6 @@ struct test_quant_conv_2 : verify_program<test_quant_conv_2<DType>>
 template struct test_quant_conv_2<migraphx::shape::int8_type>;
 template struct test_quant_conv_2<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_quant_conv_2<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_quant_conv_2<migraphx::shape::fp8e4m3fn_type>;
 template struct test_quant_conv_2<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_quant_conv_padding.cpp b/test/verify/test_quant_conv_padding.cpp
index 6edbe9ecd72..d6c05e078bf 100644
--- a/test/verify/test_quant_conv_padding.cpp
+++ b/test/verify/test_quant_conv_padding.cpp
@@ -49,5 +49,6 @@ struct test_quant_conv_padding : verify_program<test_quant_conv_padding<DType>>
 template struct test_quant_conv_padding<migraphx::shape::int8_type>;
 template struct test_quant_conv_padding<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_quant_conv_padding<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_quant_conv_padding<migraphx::shape::fp8e4m3fn_type>;
 template struct test_quant_conv_padding<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_quant_conv_padding_stride.cpp b/test/verify/test_quant_conv_padding_stride.cpp
index d0d05698c68..3c721ed44a2 100644
--- a/test/verify/test_quant_conv_padding_stride.cpp
+++ b/test/verify/test_quant_conv_padding_stride.cpp
@@ -49,5 +49,6 @@ struct test_quant_conv_padding_stride : verify_program<test_quant_conv_padding_stride<DType>>
 template struct test_quant_conv_padding_stride<migraphx::shape::int8_type>;
 template struct test_quant_conv_padding_stride<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_quant_conv_padding_stride<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_quant_conv_padding_stride<migraphx::shape::fp8e4m3fn_type>;
 template struct test_quant_conv_padding_stride<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_quant_dot_3args_1.cpp b/test/verify/test_quant_dot_3args_1.cpp
index 3ac166cc8e0..f5d7f40e4fd 100644
--- a/test/verify/test_quant_dot_3args_1.cpp
+++ b/test/verify/test_quant_dot_3args_1.cpp
@@ -54,5 +54,6 @@ struct test_quant_dot_3args_1 : verify_program<test_quant_dot_3args_1<DType>>
 template struct test_quant_dot_3args_1<migraphx::shape::int8_type>;
 template struct test_quant_dot_3args_1<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_quant_dot_3args_1<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_quant_dot_3args_1<migraphx::shape::fp8e4m3fn_type>;
 template struct test_quant_dot_3args_1<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_quant_dot_3args_2.cpp b/test/verify/test_quant_dot_3args_2.cpp
index cd55c6c046e..c22aa39aebe 100644
--- a/test/verify/test_quant_dot_3args_2.cpp
+++ b/test/verify/test_quant_dot_3args_2.cpp
@@ -55,5 +55,6 @@ struct test_quant_dot_3args_2 : verify_program<test_quant_dot_3args_2<DType>>
 template struct test_quant_dot_3args_2<migraphx::shape::int8_type>;
 template struct test_quant_dot_3args_2<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_quant_dot_3args_2<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_quant_dot_3args_2<migraphx::shape::fp8e4m3fn_type>;
 template struct test_quant_dot_3args_2<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_quant_dot_3args_3.cpp b/test/verify/test_quant_dot_3args_3.cpp
index 2e9fe4e9584..854ca22fdf5 100644
--- a/test/verify/test_quant_dot_3args_3.cpp
+++ b/test/verify/test_quant_dot_3args_3.cpp
@@ -54,5 +54,6 @@ struct test_quant_dot_3args_3 : verify_program<test_quant_dot_3args_3<DType>>
 template struct test_quant_dot_3args_3<migraphx::shape::int8_type>;
 template struct test_quant_dot_3args_3<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_quant_dot_3args_3<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_quant_dot_3args_3<migraphx::shape::fp8e4m3fn_type>;
 template struct test_quant_dot_3args_3<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_quant_dot_3args_4.cpp b/test/verify/test_quant_dot_3args_4.cpp
index 344e63e8aa8..abfaf0e10ff 100644
--- a/test/verify/test_quant_dot_3args_4.cpp
+++ b/test/verify/test_quant_dot_3args_4.cpp
@@ -57,5 +57,6 @@ struct test_quant_dot_3args_4 : verify_program<test_quant_dot_3args_4<DType>>
 template struct test_quant_dot_3args_4<migraphx::shape::int8_type>;
 template struct test_quant_dot_3args_4<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_quant_dot_3args_4<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_quant_dot_3args_4<migraphx::shape::fp8e4m3fn_type>;
 template struct test_quant_dot_3args_4<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_quant_dot_3args_5.cpp b/test/verify/test_quant_dot_3args_5.cpp
index b7197826c73..51b123e9026 100644
--- a/test/verify/test_quant_dot_3args_5.cpp
+++ b/test/verify/test_quant_dot_3args_5.cpp
@@ -54,5 +54,6 @@ struct test_quant_dot_3args_5 : verify_program<test_quant_dot_3args_5<DType>>
 template struct test_quant_dot_3args_5<migraphx::shape::int8_type>;
 template struct test_quant_dot_3args_5<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_quant_dot_3args_5<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_quant_dot_3args_5<migraphx::shape::fp8e4m3fn_type>;
 template struct test_quant_dot_3args_5<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_reduce_add.cpp b/test/verify/test_reduce_add.cpp
index c2680e1492e..f1d8d6f3fd3 100644
--- a/test/verify/test_reduce_add.cpp
+++ b/test/verify/test_reduce_add.cpp
@@ -47,9 +47,8 @@ struct test_reduce_add : verify_program<test_reduce_add<DType>>
         mm->add_return({add});
         return p;
     };
+
+    std::string section() const { return "reduce"; }
 };
 
 template struct test_reduce_add<migraphx::shape::half_type>;
-template struct test_reduce_add<migraphx::shape::fp8e4m3fnuz_type>;
-template struct test_reduce_add<migraphx::shape::fp8e4m3fn_type>;
-template struct test_reduce_add<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_reduce_mean_bias_half.cpp b/test/verify/test_reduce_mean_bias_half.cpp
index c83d7565de7..e70d9dc1dc4 100644
--- a/test/verify/test_reduce_mean_bias_half.cpp
+++ b/test/verify/test_reduce_mean_bias_half.cpp
@@ -1,7 +1,7 @@
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -45,4 +45,6 @@ struct test_reduce_mean_bias_half : verify_program<test_reduce_mean_bias_half>
         mm->add_return({sqrt});
         return p;
     };
+
+    std::string section() const { return "reduce"; }
 };
diff --git a/test/verify/test_reduce_mean_large_half.cpp b/test/verify/test_reduce_mean_large_half.cpp
index f8925dbe577..5e40f9e9229 100644
--- a/test/verify/test_reduce_mean_large_half.cpp
+++ b/test/verify/test_reduce_mean_large_half.cpp
@@ -1,7 +1,7 @@
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -38,4 +38,6 @@ struct test_reduce_mean_large_half : verify_program<test_reduce_mean_large_half>
         mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), x);
         return p;
     };
+
+    std::string section() const { return "reduce"; }
 };
diff --git a/test/verify/test_reduce_mean_nhwc.cpp b/test/verify/test_reduce_mean_nhwc.cpp
index b1b6e147bb0..275c7d218d3 100644
--- a/test/verify/test_reduce_mean_nhwc.cpp
+++ b/test/verify/test_reduce_mean_nhwc.cpp
@@ -43,10 +43,9 @@ struct test_reduce_mean_nhwc : verify_program<test_reduce_mean_nhwc<DType>>
         mm->add_return({sqrt});
         return p;
     };
+
+    std::string section() const { return "reduce"; }
 };
 
 template struct test_reduce_mean_nhwc<migraphx::shape::float_type>;
 template struct test_reduce_mean_nhwc<migraphx::shape::half_type>;
-template struct test_reduce_mean_nhwc<migraphx::shape::fp8e4m3fnuz_type>;
-template struct test_reduce_mean_nhwc<migraphx::shape::fp8e4m3fn_type>;
-template struct test_reduce_mean_nhwc<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_reduce_mean_reduce_sum.cpp b/test/verify/test_reduce_mean_reduce_sum.cpp
index 9dd29f8cffe..b5df0f0147a 100644
--- a/test/verify/test_reduce_mean_reduce_sum.cpp
+++ b/test/verify/test_reduce_mean_reduce_sum.cpp
@@ -51,4 +51,6 @@ struct test_reduce_mean_reduce_sum : verify_program<test_reduce_mean_reduce_sum>
         mm->add_return({mean_div2});
         return p;
     };
+
+    std::string section() const { return "reduce"; }
 };
diff --git a/test/verify/test_reduce_mean_variance.cpp b/test/verify/test_reduce_mean_variance.cpp
index dccd74a84c3..6843ae7e459 100644
--- a/test/verify/test_reduce_mean_variance.cpp
+++ b/test/verify/test_reduce_mean_variance.cpp
@@ -45,4 +45,6 @@ struct test_reduce_mean_variance : verify_program<test_reduce_mean_variance>
         mm->add_return({add});
        return p;
    };
+
+    std::string section() const { return "reduce"; }
 };
diff --git a/test/verify/test_reduce_noop_add.cpp b/test/verify/test_reduce_noop_add.cpp
index 9e642da2f71..c72d7f967f0 100644
--- a/test/verify/test_reduce_noop_add.cpp
+++ b/test/verify/test_reduce_noop_add.cpp
@@ -1,7 +1,7 @@
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -45,4 +45,6 @@ struct test_reduce_noop_add : verify_program<test_reduce_noop_add>
         mm->add_return({add});
         return p;
     };
+
+    std::string section() const { return "reduce"; }
 };
diff --git a/test/verify/test_reduce_op_large.cpp b/test/verify/test_reduce_op_large.cpp
index 4a64091bec6..edcd550ef66 100644
--- a/test/verify/test_reduce_op_large.cpp
+++ b/test/verify/test_reduce_op_large.cpp
@@ -45,6 +45,8 @@ struct test_reduce_op_large : verify_program<test_reduce_op_large<Op, Axis, T>>
         mm->add_instruction(Op{{Axis}}, x);
         return p;
     };
+
+    std::string section() const { return "reduce"; }
 };
 
 template struct test_reduce_op_large
@@ -58,32 +60,13 @@ template struct test_reduce_op_large
 template struct test_reduce_op_large;
-template struct test_reduce_op_large;
-template struct test_reduce_op_large;
-template struct test_reduce_op_large;
-template struct test_reduce_op_large;
 template struct test_reduce_op_large;
-
-template struct test_reduce_op_large;
-template struct test_reduce_op_large;
-template struct test_reduce_op_large;
-template struct test_reduce_op_large;
+                                     migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_reduce_op_large;
-
-template struct test_reduce_op_large;
-template struct test_reduce_op_large;
-template struct test_reduce_op_large;
-template struct test_reduce_op_large;
 template struct test_reduce_op_large;
 
 struct test_reduce_mean_1 : verify_program<test_reduce_mean_1>
diff --git a/test/verify/test_reduce_op_small.cpp b/test/verify/test_reduce_op_small.cpp
index 792029509f2..b1d824f1c8a 100644
--- a/test/verify/test_reduce_op_small.cpp
+++ b/test/verify/test_reduce_op_small.cpp
@@ -45,6 +45,8 @@ struct test_reduce_op_small : verify_program<test_reduce_op_small<Op, Axis, T>>
         mm->add_instruction(Op{{Axis}}, x);
         return p;
     };
+
+    std::string section() const { return "reduce"; }
 };
 
 template struct test_reduce_op_small
@@ -91,6 +93,31 @@ template struct test_reduce_op_small
+template struct test_reduce_op_small;
+template struct test_reduce_op_small;
+template struct test_reduce_op_small;
+template struct test_reduce_op_small;
+template struct test_reduce_op_small;
+template struct test_reduce_op_small;
+template struct test_reduce_op_small;
+template struct test_reduce_op_small;
+
 template struct test_reduce_op_small;
 template struct test_reduce_op_small;
 template struct test_reduce_op_small;
diff --git a/test/verify/test_reverse.cpp b/test/verify/test_reverse.cpp
index bcfa8e218cd..c2c2f037c55 100644
--- a/test/verify/test_reverse.cpp
+++ b/test/verify/test_reverse.cpp
@@ -44,5 +44,6 @@ struct test_reverse : verify_program<test_reverse<DType>>
 template struct test_reverse<migraphx::shape::float_type>;
 template struct test_reverse<migraphx::shape::half_type>;
 template struct test_reverse<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_reverse<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_reverse<migraphx::shape::fp8e4m3fn_type>;
 template struct test_reverse<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_rnn_sql_1.cpp b/test/verify/test_rnn_sql_1.cpp
index 3621bbd25fa..87218a81b74 100644
--- a/test/verify/test_rnn_sql_1.cpp
+++ b/test/verify/test_rnn_sql_1.cpp
@@ -85,5 +85,6 @@ struct test_rnn_sql_1 : verify_program<test_rnn_sql_1<DType>>
 template struct test_rnn_sql_1<migraphx::shape::float_type>;
 template struct test_rnn_sql_1<migraphx::shape::half_type>;
 template struct test_rnn_sql_1<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_rnn_sql_1<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_rnn_sql_1<migraphx::shape::fp8e4m3fn_type>;
 template struct test_rnn_sql_1<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_roialign.cpp b/test/verify/test_roialign.cpp
index 6314491e10d..9f5271f1c5d 100644
--- a/test/verify/test_roialign.cpp
+++ b/test/verify/test_roialign.cpp
@@ -60,6 +60,3 @@ struct test_roialign : verify_program<test_roialign<DType>>
 template struct test_roialign<migraphx::shape::float_type>;
 template struct test_roialign<migraphx::shape::half_type>;
-template struct test_roialign<migraphx::shape::fp8e4m3fnuz_type>;
-template struct test_roialign<migraphx::shape::fp8e4m3fn_type>;
-template struct test_roialign<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_rsqrt.cpp b/test/verify/test_rsqrt.cpp
index 097a34f1d32..cc10878788a 100644
--- a/test/verify/test_rsqrt.cpp
+++ b/test/verify/test_rsqrt.cpp
@@ -55,5 +55,6 @@ struct test_rsqrt : verify_program<test_rsqrt<DType>>
 template struct test_rsqrt<migraphx::shape::float_type>;
 template struct test_rsqrt<migraphx::shape::half_type>;
 template struct test_rsqrt<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_rsqrt<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_rsqrt<migraphx::shape::fp8e4m3fn_type>;
 template struct test_rsqrt<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_scatter_elements_none_axis_neg_1.cpp b/test/verify/test_scatter_elements_none_axis_neg_1.cpp
index 2791b47c792..354e9f85332 100644
--- a/test/verify/test_scatter_elements_none_axis_neg_1.cpp
+++ b/test/verify/test_scatter_elements_none_axis_neg_1.cpp
@@ -53,5 +53,6 @@ struct test_scatter_elements_none_axis_neg_1
 template struct test_scatter_elements_none_axis_neg_1<migraphx::shape::float_type>;
 template struct test_scatter_elements_none_axis_neg_1<migraphx::shape::half_type>;
 template struct test_scatter_elements_none_axis_neg_1<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_scatter_elements_none_axis_neg_1<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_scatter_elements_none_axis_neg_1<migraphx::shape::fp8e4m3fn_type>;
 template struct test_scatter_elements_none_axis_neg_1<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_select_module_reduce.cpp b/test/verify/test_select_module_reduce.cpp
index aeffd90417e..5e3d8ad1d63 100644
--- a/test/verify/test_select_module_reduce.cpp
+++ b/test/verify/test_select_module_reduce.cpp
@@ -1,7 +1,7 @@
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -66,4 +66,6 @@ struct test_select_module_reduce : verify_program<test_select_module_reduce>
         return p;
     }
+
+    std::string section() const { return "reduce"; }
 };
diff --git a/test/verify/test_sin.cpp b/test/verify/test_sin.cpp
index e2d160eacc7..90d0e6da26b 100644
--- a/test/verify/test_sin.cpp
+++ b/test/verify/test_sin.cpp
@@ -44,5 +44,6 @@ struct test_sin : verify_program<test_sin<DType>>
 template struct test_sin<migraphx::shape::float_type>;
 template struct test_sin<migraphx::shape::half_type>;
 template struct test_sin<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_sin<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_sin<migraphx::shape::fp8e4m3fn_type>;
 template struct test_sin<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_sinh.cpp b/test/verify/test_sinh.cpp
index f84cd80ad1a..91bceea24ca 100644
--- a/test/verify/test_sinh.cpp
+++ b/test/verify/test_sinh.cpp
@@ -44,5 +44,6 @@ struct test_sinh : verify_program<test_sinh<DType>>
 template struct test_sinh<migraphx::shape::float_type>;
 template struct test_sinh<migraphx::shape::half_type>;
 template struct test_sinh<migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_sinh<migraphx::shape::fp8e5m2fnuz_type>;
 template struct test_sinh<migraphx::shape::fp8e4m3fn_type>;
 template struct test_sinh<migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_softmax.cpp b/test/verify/test_softmax.cpp
index 82b65accc9e..3d666a240ad 100644
--- a/test/verify/test_softmax.cpp
+++ b/test/verify/test_softmax.cpp
@@ -40,6 +40,8 @@ struct test_softmax : verify_program<test_softmax<Axis, DType>>
         return p;
     }
+
+    std::string section() const { return "reduce"; }
 };
 
 template struct test_softmax<0, migraphx::shape::float_type>;
@@ -50,17 +52,14 @@ template struct test_softmax<1, migraphx::shape::half_type>;
 template struct test_softmax<2, migraphx::shape::half_type>;
 template struct test_softmax<3, migraphx::shape::half_type>;
 
-template struct test_softmax<0, migraphx::shape::fp8e4m3fnuz_type>;
 template struct test_softmax<1, migraphx::shape::fp8e4m3fnuz_type>;
-template struct test_softmax<2, migraphx::shape::fp8e4m3fnuz_type>;
 template struct test_softmax<3, migraphx::shape::fp8e4m3fnuz_type>;
-template struct test_softmax<0, migraphx::shape::fp8e4m3fn_type>;
+template struct test_softmax<1, migraphx::shape::fp8e5m2fnuz_type>;
+template struct test_softmax<3, migraphx::shape::fp8e5m2fnuz_type>;
+
 template struct test_softmax<1, migraphx::shape::fp8e4m3fn_type>;
-template struct test_softmax<2, migraphx::shape::fp8e4m3fn_type>;
 template struct test_softmax<3, migraphx::shape::fp8e4m3fn_type>;
-template struct test_softmax<0, migraphx::shape::fp8e5m2_type>;
 template struct test_softmax<1, migraphx::shape::fp8e5m2_type>;
-template struct test_softmax<2, migraphx::shape::fp8e5m2_type>;
 template struct test_softmax<3, migraphx::shape::fp8e5m2_type>;
diff --git a/test/verify/test_softmax1.cpp b/test/verify/test_softmax1.cpp
index 4742e03607a..ef8c71871bb 100644
--- a/test/verify/test_softmax1.cpp
+++ b/test/verify/test_softmax1.cpp
@@ -1,7 +1,7 @@
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -37,4 +37,6 @@ struct test_softmax1 : verify_program<test_softmax1>
         mm->add_instruction(migraphx::make_op("softmax", {{"axis", 0}}), x);
         return p;
     }
+
+    std::string section() const { return "reduce"; }
 };
diff --git a/test/verify/test_softmax2.cpp b/test/verify/test_softmax2.cpp
index 804f8f04f0b..c39dcec6c35 100644
--- a/test/verify/test_softmax2.cpp
+++ b/test/verify/test_softmax2.cpp
@@ -1,7 +1,7 @@
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -38,4 +38,6 @@ struct test_softmax2 : verify_program<test_softmax2>
         mm->add_instruction(migraphx::make_op("softmax"), x);
         return p;
     }
+
+    std::string section() const { return "reduce"; }
 };
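The softmax variants here (including the fp8 ones) all reduce along a single axis, hence the "reduce" section tag. For reference, a sketch of the numerically stable row softmax such tests are typically checked against (standard C++; the max-subtraction is the usual stabilization trick, not MIGraphX-specific code):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    std::vector<float> softmax_row(const std::vector<float>& x)
    {
        float m = *std::max_element(x.begin(), x.end());
        std::vector<float> out(x.size());
        float sum = 0;
        for(std::size_t i = 0; i < x.size(); ++i)
        {
            out[i] = std::exp(x[i] - m); // subtract the max to avoid overflow
            sum += out[i];
        }
        for(float& v : out)
            v /= sum; // normalize so the row sums to 1
        return out;
    }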
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -38,4 +38,6 @@ struct test_softmax4 : verify_program
         mm->add_instruction(migraphx::make_op("softmax", {{"axis", 3}}), x);
         return p;
     }
+
+    std::string section() const { return "reduce"; }
 };
diff --git a/test/verify/test_softmax_large1.cpp b/test/verify/test_softmax_large1.cpp
index e8d98194da8..9669b095807 100644
--- a/test/verify/test_softmax_large1.cpp
+++ b/test/verify/test_softmax_large1.cpp
@@ -1,7 +1,7 @@
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -40,4 +40,6 @@ struct test_softmax_large1 : verify_program
         mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), add);
         return p;
     }
+
+    std::string section() const { return "reduce"; }
 };
diff --git a/test/verify/test_softmax_large2.cpp b/test/verify/test_softmax_large2.cpp
index c4311f6f819..12874f647b9 100644
--- a/test/verify/test_softmax_large2.cpp
+++ b/test/verify/test_softmax_large2.cpp
@@ -1,7 +1,7 @@
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -40,4 +40,6 @@ struct test_softmax_large2 : verify_program
         mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), add);
         return p;
     }
+
+    std::string section() const { return "reduce"; }
 };
diff --git a/test/verify/test_softmax_large3.cpp b/test/verify/test_softmax_large3.cpp
index d91bece8048..900d52ca90d 100644
--- a/test/verify/test_softmax_large3.cpp
+++ b/test/verify/test_softmax_large3.cpp
@@ -1,7 +1,7 @@
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -40,4 +40,6 @@ struct test_softmax_large3 : verify_program
         mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), add);
         return p;
     }
+
+    std::string section() const { return "reduce"; }
 };
diff --git a/test/verify/test_softmaxcrossentropyloss_2d.cpp b/test/verify/test_softmaxcrossentropyloss_2d.cpp
index 5839a25020b..684975ea119 100644
--- a/test/verify/test_softmaxcrossentropyloss_2d.cpp
+++ b/test/verify/test_softmaxcrossentropyloss_2d.cpp
@@ -81,6 +81,8 @@ struct test_softmaxcrossentropyloss_2d
         return p;
     }
+
+    std::string section() const { return "reduce"; }
 };
 
 // template struct test_softmaxcrossentropyloss_2d>
 template struct test_sqrt;
 template struct test_sqrt;
 template struct test_sqrt;
+template struct test_sqrt;
 template struct test_sqrt;
 template struct test_sqrt;
diff --git a/test/verify/test_tan.cpp b/test/verify/test_tan.cpp
index 029591da081..b870c1c43d3 100644
--- a/test/verify/test_tan.cpp
+++ b/test/verify/test_tan.cpp
@@ -44,5 +44,6 @@ struct test_tan : verify_program>
 template struct test_tan;
 template struct test_tan;
 template struct test_tan;
+template struct test_tan;
 template struct test_tan;
 template struct test_tan;
diff --git a/test/verify/test_tanh.cpp b/test/verify/test_tanh.cpp
index 94087024d1e..8458ea80a38 100644
--- a/test/verify/test_tanh.cpp
+++ b/test/verify/test_tanh.cpp
@@ -43,5 +43,6 @@ struct test_tanh : verify_program>
 template struct test_tanh;
 template struct test_tanh;
 template struct test_tanh;
+template struct test_tanh;
 template struct test_tanh;
 template struct test_tanh;
diff --git a/test/verify/test_topk_0.cpp b/test/verify/test_topk_0.cpp
index 1cce3dd7775..625e19b6cac 100644
--- a/test/verify/test_topk_0.cpp
+++ b/test/verify/test_topk_0.cpp
@@ -47,6 +47,3 @@ struct test_topk_0 : verify_program>
 template struct test_topk_0;
 template struct test_topk_0;
-template struct test_topk_0;
-template struct test_topk_0;
-template struct test_topk_0;
 
diff --git a/test/verify/test_trans_convert_gemm.cpp b/test/verify/test_trans_convert_gemm.cpp
index 187a85e3b47..45701988598 100644
--- a/test/verify/test_trans_convert_gemm.cpp
+++ b/test/verify/test_trans_convert_gemm.cpp
@@ -48,5 +48,6 @@ struct test_trans_convert_gemm : verify_program>
 template struct test_trans_convert_gemm;
 template struct test_trans_convert_gemm;
 template struct test_trans_convert_gemm;
+template struct test_trans_convert_gemm;
 template struct test_trans_convert_gemm;
 template struct test_trans_convert_gemm;
diff --git a/test/verify/test_transpose_reshape_add_sub_mul.cpp b/test/verify/test_transpose_reshape_add_sub_mul.cpp
index c7a0dff59ec..203d8916fd6 100644
--- a/test/verify/test_transpose_reshape_add_sub_mul.cpp
+++ b/test/verify/test_transpose_reshape_add_sub_mul.cpp
@@ -57,5 +57,6 @@ struct test_transpose_reshape_add_sub_mul
 template struct test_transpose_reshape_add_sub_mul;
 template struct test_transpose_reshape_add_sub_mul;
 template struct test_transpose_reshape_add_sub_mul;
+template struct test_transpose_reshape_add_sub_mul;
 template struct test_transpose_reshape_add_sub_mul;
 template struct test_transpose_reshape_add_sub_mul;
diff --git a/test/verify/test_unbatched_gemm_1.cpp b/test/verify/test_unbatched_gemm_1.cpp
index 0dffd9f81ab..05aa138efcf 100644
--- a/test/verify/test_unbatched_gemm_1.cpp
+++ b/test/verify/test_unbatched_gemm_1.cpp
@@ -63,5 +63,6 @@ struct test_unbatched_gemm_1 : verify_program>
 template struct test_unbatched_gemm_1;
 template struct test_unbatched_gemm_1;
 template struct test_unbatched_gemm_1;
+template struct test_unbatched_gemm_1;
 template struct test_unbatched_gemm_1;
 template struct test_unbatched_gemm_1;
diff --git a/test/verify/test_unbatched_gemm_2.cpp b/test/verify/test_unbatched_gemm_2.cpp
index 04496746168..1b9640a694b 100644
--- a/test/verify/test_unbatched_gemm_2.cpp
+++ b/test/verify/test_unbatched_gemm_2.cpp
@@ -51,5 +51,6 @@ struct test_unbatched_gemm_2 : verify_program>
 template struct test_unbatched_gemm_2;
 template struct test_unbatched_gemm_2;
 template struct test_unbatched_gemm_2;
+template struct test_unbatched_gemm_2;
 template struct test_unbatched_gemm_2;
 template struct test_unbatched_gemm_2;
diff --git a/test/verify/test_where.cpp b/test/verify/test_where.cpp
index cf3fdafa748..86c8c2cc942 100644
--- a/test/verify/test_where.cpp
+++ b/test/verify/test_where.cpp
@@ -49,5 +49,6 @@ struct test_where : verify_program>
 template struct test_where;
 template struct test_where;
 template struct test_where;
+template struct test_where;
 template struct test_where;
 template struct test_where;
diff --git a/tools/api/migraphx.h b/tools/api/migraphx.h
index 4ce7ea1d07b..0fed8fd6cca 100644
--- a/tools/api/migraphx.h
+++ b/tools/api/migraphx.h
@@ -47,7 +47,9 @@
     m(uint64_type, uint64_t) \
     m(fp8e4m3fnuz_type, migraphx::fp8::fp8e4m3fnuz) \
     m(fp8e4m3fn_type, migraphx::fp8::fp8e4m3fn) \
-    m(fp8e5m2_type, migraphx::fp8::fp8e5m2)
+    m(fp8e5m2_type, migraphx::fp8::fp8e5m2) \
+    m(bf16_type, bf16) \
+    m(fp8e5m2fnuz_type, migraphx::fp8::fp8e5m2fnuz)
 // clang-format on
 
 #ifdef __cplusplus
diff --git a/tools/cppcheck/migraphx.py b/tools/cppcheck/migraphx.py
index 3be73d5b384..787279044a1 100644
--- a/tools/cppcheck/migraphx.py
+++ b/tools/cppcheck/migraphx.py
@@ -436,6 +436,8 @@ def MatcherNestedParentheses(cfg, data):
     for tok2 in token.tokAt(4).forward(token.linkAt(4)):
         if not simpleMatch(tok2, ") ) ) )"):
             continue
+        if simpleMatch(tok2.link.previous, "bind"):
+            continue
         cppcheck.reportError(
             tok2, "style",
             "Too many nested parentheses can affect readability; consider using variables instead."