diff --git a/.github/workflows/amd-offline-tests.yml b/.github/workflows/amd-offline-tests.yml index 2e75bf418ad9..fd87f8f41150 100644 --- a/.github/workflows/amd-offline-tests.yml +++ b/.github/workflows/amd-offline-tests.yml @@ -48,6 +48,7 @@ jobs: - name: Install Triton run: | cd python + pip3 install ninja # Install in system, because need to override system triton. Otherwise lit tests will use wrong version DEBUG=TRUE TRITON_USE_ROCM=TRUE TRITON_USE_ASSERT_ENABLED_LLVM=TRUE python3 -m pip install --no-build-isolation -vvv -e . diff --git a/.github/workflows/compare-artifacts.yml b/.github/workflows/compare-artifacts.yml new file mode 100644 index 000000000000..e2f050260baf --- /dev/null +++ b/.github/workflows/compare-artifacts.yml @@ -0,0 +1,85 @@ +name: Compare Artifacts +on: + workflow_run: + workflows: + - Integration Tests + types: + - completed + +jobs: + Compare-artifacts: + runs-on: ubuntu-latest + if: ${{ github.event.workflow_run.conclusion == 'success' }} + + steps: + - name: Download PR number artifact + uses: actions/github-script@v6 + with: + script: | + let allArtifacts = await github.rest.actions.listWorkflowRunArtifacts({ + owner: context.repo.owner, + repo: context.repo.repo, + run_id: context.payload.workflow_run.id, + }); + let matchArtifact = allArtifacts.data.artifacts.filter((artifact) => { + return artifact.name == "pr_number" + })[0]; + let download = await github.rest.actions.downloadArtifact({ + owner: context.repo.owner, + repo: context.repo.repo, + artifact_id: matchArtifact.id, + archive_format: 'zip', + }); + let fs = require('fs'); + fs.writeFileSync(`${process.env.GITHUB_WORKSPACE}/pr_number.zip`, Buffer.from(download.data)); + - name: Download comparison result artifact + uses: actions/github-script@v6 + with: + script: | + let allArtifacts = await github.rest.actions.listWorkflowRunArtifacts({ + owner: context.repo.owner, + repo: context.repo.repo, + run_id: context.payload.workflow_run.id, + }); + let matchArtifact = allArtifacts.data.artifacts.filter((artifact) => { + return artifact.name == "comparison_result" + })[0]; + let download = await github.rest.actions.downloadArtifact({ + owner: context.repo.owner, + repo: context.repo.repo, + artifact_id: matchArtifact.id, + archive_format: 'zip', + }); + let fs = require('fs'); + fs.writeFileSync(`${process.env.GITHUB_WORKSPACE}/comparison_result.zip`, Buffer.from(download.data)); + - name: Unzip artifacts + run: | + unzip pr_number.zip + unzip comparison_result.zip + - name: Print artifacts + uses: actions/github-script@v6 + with: + script: | + let fs = require('fs'); + let pr_number = Number(fs.readFileSync('./pr_number')); + let comparison_result = fs.readFileSync('./comparison_result', 'utf8'); + console.log("PR number = ", pr_number); + console.log("Comparison result = ", comparison_result); + - name: Comment on PR + uses: actions/github-script@v6 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + let fs = require('fs'); + let run_id = context.payload.workflow_run.id; + let issue_number = Number(fs.readFileSync('./pr_number')); + let comparison_result = fs.readFileSync('./comparison_result', 'utf8'); + const message = `:warning: **This PR does not produce bitwise identical kernels as the branch it's merged against.** Please check artifacts for details. [Download the output file here](https://github.com/${{ github.repository }}/actions/runs/${run_id}).`; + if (comparison_result.trim() !== 'SUCCESS') { + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: issue_number, + body: message + }); + } diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml index e2a2f66c9184..d413a3dca171 100644 --- a/.github/workflows/documentation.yml +++ b/.github/workflows/documentation.yml @@ -24,6 +24,7 @@ jobs: run: | pip3 install tabulate pip3 install cmake + pip3 install sphinx #- name: Fetch dependent branches # run: | @@ -33,7 +34,7 @@ jobs: run: | cd docs export PATH=$(python3 -c "import cmake; print(cmake.CMAKE_BIN_DIR)"):$PATH - python3 -m sphinx_multiversion . _build/html/ + python3 -m sphinx . _build/html/main - name: Update docs run: | diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index c44737a312af..b647fdfa73c3 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -27,7 +27,7 @@ jobs: run: | if [ x"${{ github.repository }}" == x"openai/triton" ]; then echo '::set-output name=matrix-required::[["self-hosted", "A100"], ["self-hosted", "H100"]]' - echo '::set-output name=matrix-optional::[["self-hosted", "gfx908"], ["self-hosted", "arc770"]]' + echo '::set-output name=matrix-optional::[]' else echo '::set-output name=matrix-required::["ubuntu-latest"]' echo '::set-output name=matrix-optional::["ubuntu-latest"]' @@ -50,6 +50,9 @@ jobs: if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100' || matrix.runner[1] == 'H100')}} run: | echo "BACKEND=CUDA" >> "${GITHUB_ENV}" + echo "ENABLE_TMA=0" >> "${GITHUB_ENV}" + echo "ENABLE_MMA_V3=0" >> "${GITHUB_ENV}" + echo "TRITON_DISABLE_LINE_INFO=1" >> "${GITHUB_ENV}" - name: Clear cache run: | @@ -59,12 +62,18 @@ jobs: run: | echo "PATH=${HOME}/.local/bin:${PATH}" >> "${GITHUB_ENV}" + - name: Check pre-commit + run: | + python3 -m pip install --upgrade pre-commit + python3 -m pre_commit run --all-files --verbose + - name: Install Triton if: ${{ env.BACKEND == 'CUDA'}} run: | cd python python3 -m pip install --upgrade pip python3 -m pip install cmake==3.24 + python3 -m pip install ninja python3 -m pip install --no-build-isolation -vvv '.[tests]' python3 -m pip install pytest-xdist @@ -79,19 +88,53 @@ jobs: fi lit -v "${LIT_TEST_DIR}" - - name: Run python tests on CUDA - if: ${{ env.BACKEND == 'CUDA'}} + - name: Enable MMAV3 and TMA + if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'H100')}} + run: | + echo "ENABLE_TMA=1" >> "${GITHUB_ENV}" + echo "ENABLE_MMA_V3=1" >> "${GITHUB_ENV}" + + - name: Run python tests on CUDA with ENABLE_TMA=1 and ENABLE_MMA_V3=1 + if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '1' && env.ENABLE_MMA_V3 == '1'}} run: | cd python/test/unit - python3 -m pytest -n 8 --ignore=runtime + python3 -m pytest -n 8 --ignore=runtime --ignore=operators --ignore=language/test_line_info.py # run runtime tests serially to avoid race condition with cache handling. python3 -m pytest runtime/ + # run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0 + TRITON_DISABLE_LINE_INFO=0 python3 -m pytest language/test_line_info.py + + - name: Run python tests on CUDA with ENABLE_TMA=0 and ENABLE_MMA_V3=0 + if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '0' && env.ENABLE_MMA_V3 == '0'}} + run: | + cd python/test/unit + python3 -m pytest -n 8 --ignore=runtime --ignore=hopper --ignore=operators --ignore=language/test_line_info.py + # run runtime tests serially to avoid race condition with cache handling. + python3 -m pytest runtime/ + # run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0 + TRITON_DISABLE_LINE_INFO=0 python3 -m pytest language/test_line_info.py + + - name: Clear cache + run: | + rm -rf ~/.triton + + - name: Run partial tests on CUDA with ENABLE_TMA=1 and ENABLE_MMA_V3=1 + if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '1' && env.ENABLE_MMA_V3 == '1'}} + run: | + cd python/test/unit + python3 -m pytest -n 8 operators + + - name: Run partial tests on CUDA with ENABLE_TMA=0 and ENABLE_MMA_V3=0 + if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '0' && env.ENABLE_MMA_V3 == '0'}} + run: | + cd python/test/unit + python3 -m pytest -n 8 operators - name: Create artifacts archive if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100' || matrix.runner[1] == 'H100')}} run: | cd ~/.triton - tar -czvf artifacts.tar.gz cache + tar -czf artifacts.tar.gz cache - name: Upload artifacts archive if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100' || matrix.runner[1] == 'H100')}} @@ -119,6 +162,7 @@ jobs: Integration-Tests-Third-Party: needs: Runner-Preparation + if: false runs-on: ${{ matrix.runner }} @@ -218,10 +262,22 @@ jobs: sudo apt update sudo apt install gh + - name: Save PR number to a file + env: + PR_NUMBER: ${{ github.event.number }} + run: | + echo $PR_NUMBER > pr_number + - name: Upload PR number to artifacts + uses: actions/upload-artifact@v3 + with: + name: pr_number + path: pr_number + - name: Download latest main artifacts env: ARTIFACT_NAME: artifacts A100 ARTIFACT_JOB_NAME: Integration-Tests-Nvidia + MAX_NUM_ACTIONS_PAGES: 30 GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} run: | OWNER_REPO="${{ github.repository }}" @@ -238,18 +294,27 @@ jobs: USER_ID=$(gh api repos/$OWNER_REPO/pulls/$PR_NUMBER --jq '.user.id') echo "USER_ID: $USER_ID" + run_id_found=false page=1 while true; do + if [ "$page" -gt $MAX_NUM_ACTIONS_PAGES ]; then + break + fi + run_id=$(gh api --method GET "repos/$OWNER_REPO/actions/runs?page=$page&per_page=100" | jq --arg branch_name "$BRANCH_NAME" --arg run_name "Integration Tests" --arg user_id "$USER_ID" '.workflow_runs[] | select(.head_branch == $branch_name and .name == $run_name and .actor.id == ($user_id | tonumber))' | jq '.id' | head -1) if [ "$run_id" != "" ]; then echo "First run ID on branch $BRANCH_NAME is: $run_id" WORKFLOW_RUN_ID=$run_id + run_id_found=true break fi ((page++)) done - + if ! $run_id_found; then + echo "No run_id found for PR ${PR_NUMBER}, moving to the next PR." + continue + fi echo "WORKFLOW_RUN_ID: $WORKFLOW_RUN_ID" ARTIFACT_URL=$(gh api repos/$OWNER_REPO/actions/runs/$WORKFLOW_RUN_ID/artifacts | jq --arg artifact_name "$ARTIFACT_NAME" '.artifacts[] | select(.name == $artifact_name).archive_download_url' --raw-output) echo "ARTIFACT_URL: $ARTIFACT_URL" @@ -289,7 +354,7 @@ jobs: - name: Compare artifacts run: | set +e - python3 python/test/tools/compare_files.py --path1 reference --path2 current --kernels python/test/kernel_comparison/kernels.yml + python3 python/test/tools/compare_files.py --path1 reference --path2 current exit_code=$? set -e echo $exit_code @@ -303,34 +368,20 @@ jobs: echo "Error while comparing artifacts" echo "COMPARISON_RESULT=error" >> $GITHUB_ENV fi - echo "COMPARISON_RESULT=env.COMPARISON_RESULT" - - name: Check exit code and handle failure - if: ${{ env.COMPARISON_RESULT == 'error' }} + - name: Check comparison result and write to file run: | - echo "Error while comparing artifacts" - exit 1 - - name: Fetch Run ID - id: get_run_id - run: echo "RUN_ID=${{ github.run_id }}" >> $GITHUB_ENV - + if [ "${{ env.COMPARISON_RESULT }}" = "true" ]; then + echo "SUCCESS" > comparison_result + else + echo "FAILED" > comparison_result + fi + - name: Upload comparison result to artifacts + uses: actions/upload-artifact@v3 + with: + name: comparison_result + path: comparison_result - name: Upload results as artifact uses: actions/upload-artifact@v2 with: name: kernels-reference-check path: kernels_reference_check.txt - - - name: Check output and comment on PR - if: ${{ env.COMPARISON_RESULT == 'false' }} - uses: actions/github-script@v5 - with: - github-token: ${{ secrets.CI_ACCESS_TOKEN }} - script: | - const run_id = ${{ env.RUN_ID }}; - const issue_number = context.payload.pull_request.number; - const message = `:warning: **This PR does not produce bitwise identical kernels as the branch it's merged against.** Please check artifacts for details. [Download the output file here](https://github.com/${{ github.repository }}/actions/runs/${run_id}).`; - await github.rest.issues.createComment({ - owner: context.repo.owner, - repo: context.repo.repo, - issue_number: issue_number, - body: message - }); diff --git a/.gitignore b/.gitignore index c4b5ecceac20..ef7867cbde86 100644 --- a/.gitignore +++ b/.gitignore @@ -25,9 +25,5 @@ venv.bak/ .idea cmake-build-* -# cache dumps -triton_cache* -log_* - -# -python/triton/third_party/cuda/bin/ptxas +# Third-party binaries +ptxas diff --git a/CMakeLists.txt b/CMakeLists.txt index 085af05c23b5..88aa1ae7c289 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -26,6 +26,8 @@ set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends") # Force TRITON_USE_ROCM for ROCm support set(TRITON_USE_ROCM ON) +set(ROCM_DEFAULT_DIR "/opt/rocm") +add_definitions( -DROCM_DEFAULT_DIR="${ROCM_DEFAULT_DIR}") # Ensure Python3 vars are set correctly # used conditionally in this file and by lit tests @@ -200,6 +202,10 @@ include_directories(${LLVM_INCLUDE_DIRS}) include_directories(${PROJECT_SOURCE_DIR}/include) include_directories(${PROJECT_BINARY_DIR}/include) # Tablegen'd files +set(ROCM_LIBRARIES + ${CMAKE_CURRENT_SOURCE_DIR}/lib/rocm/libhsa-runtime64.so +) + # link_directories(${LLVM_LIBRARY_DIR}) add_subdirectory(include) add_subdirectory(lib) @@ -218,6 +224,7 @@ if(TRITON_BUILD_PYTHON_MODULE) TritonAnalysis TritonTransforms TritonGPUTransforms + TritonNvidiaGPUTransforms TritonLLVMIR TritonPTX TritonHSACO @@ -238,9 +245,21 @@ if(TRITON_BUILD_PYTHON_MODULE) MLIRIR ) - set(ROCM_LIBRARIES - ${CMAKE_CURRENT_SOURCE_DIR}/lib/rocm/libhsa-runtime64.so - ) + if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/python/triton/third_party/rocm/lib/libhsa-runtime64.so) + set(ROCM_LIBRARIES + ${CMAKE_CURRENT_SOURCE_DIR}/python/triton/third_party/rocm/lib/libhsa-runtime64.so + ) + elseif(EXISTS "$ENV{ROCM_PATH}/lib/libhsa-runtime64.so" ) + set(ROCM_LIBRARIES + "$ENV{ROCM_PATH}/lib/libhsa-runtime64.so" + ) + elseif(EXISTS "${ROCM_DEFAULT_DIR}/lib/libhsa-runtime64.so" ) + set(ROCM_LIBRARIES + "${ROCM_DEFAULT_DIR}/lib/libhsa-runtime64.so" + ) + else() + message(STATUS "WARNING: Can't find libhsa-runtime64.so") + endif() if(WIN32) target_link_libraries(triton PRIVATE ${ROCM_LIBRARIES} ${LLVM_LIBRARIES} ${CMAKE_DL_LIBS} diff --git a/README.md b/README.md index 675a41ce4d6e..5f88299c91ec 100644 --- a/README.md +++ b/README.md @@ -4,11 +4,38 @@ [![Wheels](https://github.com/openai/triton/actions/workflows/wheels.yml/badge.svg?branch=release/2.0.x)](https://github.com/openai/triton/actions/workflows/wheels.yml) +We're hiring! If you are interested in working on Triton at OpenAI, we have roles open for [Compiler Engineers](https://openai.com/careers/software-engineer-triton-compiler) and [Kernel Engineers](https://openai.com/careers/kernel-engineer). **`Documentation`** | ------------------- | [![Documentation](https://github.com/openai/triton/actions/workflows/documentation.yml/badge.svg)](https://triton-lang.org/) +# Triton Developer Conference Registration Open +The Triton Developer Conference will be held in a hybrid mode at the Microsoft Silicon Valley Campus in Mountain View, California. The conference will be held on September 20th from 10am to 4pm, followed by a reception till 5:30 pm. Please use the link below to register to attend either in-person or virtually online. + +Registration Link for Triton Developer Conference is [here](https://forms.office.com/r/m4jQXShDts) + +Tentative Agenda for the conference (subject to change): + +|Time |Title |Speaker +|--------|-------|-------| +|10:00 AM|Welcome|Kevin Scott (Microsoft)| +|10:20 AM|The Triton Compiler: Past, Present and Future|Phil Tillet (OpenAI)| +|11:00 AM|**Break**|| +|11:20 AM|Hopper support in Triton|Gustav Zhu (Nvidia)| +|11:40 AM|Bringing Triton to AMD GPUs|Jason Furmanek, Lixun Zhang (AMD)| +|12:00 PM|Intel XPU Backend for Triton|Eikan Wang (Intel)| +|12:20 PM|Vectorization of Triton Kernels for Qualcomm Hexagon Backend|Javed Absar (Qualcomm)| +|12:30 PM|**Lunch**|| +|1:40 PM |Triton for MTIA|Roman Levenstein et al, (Meta)| +|2:00 PM |Using Triton IR for high-performance fusions in XLA|George Karpenkov (Google)| +|2:20 PM |Triton for All: Triton as a device-independent language|Ian Bearman (Microsoft)| +|2:40 PM|**Break**|| +|3:00 PM|PyTorch 2.0 and TorchInductor|Jason Ansel, Horace He (Meta)| +|3:20 PM|Pallas: A JAX Kernel Language|Sharad Vikram (Google)| +|3:40 PM|Writing Grouped GEMMs in Triton|Vinod Grover (Nvidia)| +|4:00 PM|**Reception**|| + # Triton @@ -41,7 +68,7 @@ git checkout triton-mlir # Build ``` cd python -pip3 install cmake; # build time dependency +pip3 install ninja cmake; # build time dependencies pip3 install -e . ``` # Run tests: @@ -60,11 +87,49 @@ lit -v test ``` git clone https://github.com/openai/triton.git; cd triton/python; -pip install cmake; # build-time dependency +pip install ninja cmake; # build-time dependencies pip install -e . ``` +# Building with a custom LLVM + +Triton uses LLVM to generate code for GPUs and CPUs. Normally, the Triton build +downloads a prebuilt LLVM, but you can also build LLVM from source and use that. + +LLVM does not have a stable API, so the Triton build will not work at an +arbitrary LLVM version. + +1. Find the version of LLVM that Triton builds against. Check `python/setup.py` + for a line like + + version = "llvm-17.0.0-c5dede880d17" + + This means that the version of Triton you have builds against + [LLVM](https://github.com/llvm/llvm-project) c5dede880d17. + +2. `git checkout` LLVM at this revision. Optionally, make additional + modifications to LLVM. + +3. [Build LLVM](https://llvm.org/docs/CMake.html). For example, you might run + $ cd $HOME/llvm-project # your clone of LLVM. + $ mkdir build + $ cd build + $ cmake -G Ninja -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=ON ../llvm -DLLVM_ENABLE_PROJECTS="mlir;llvm" + $ ninja + +4. Grab a snack, this will take a while. + +5. Build Triton as above, but set the following environment variables. + + # Modify as appropriate to point to your LLVM build. + $ export LLVM_BUILD_DIR=$HOME/llvm-project/build + + $ cd /python + $ LLVM_INCLUDE_DIRS=$LLVM_BUILD_DIR/include \ + LLVM_LIBRARY_DIR=$LLVM_BUILD_DIR/lib \ + LLVM_SYSPATH=$LLVM_BUILD_DIR \ + pip install -e . # Changelog @@ -78,10 +143,6 @@ Version 2.0 is out! New features include: Community contributions are more than welcome, whether it be to fix bugs or to add new features at [github](https://github.com/openai/triton/). For more detailed instructions, please visit our [contributor's guide](CONTRIBUTING.md). -If you’re interested in joining our team and working on Triton & GPU kernels, [we’re hiring](https://openai.com/jobs/#acceleration)! - - - # Compatibility diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt index 30f26b8ae18e..a8966c5e77f4 100644 --- a/bin/CMakeLists.txt +++ b/bin/CMakeLists.txt @@ -9,6 +9,7 @@ target_link_libraries(triton-opt PRIVATE TritonAnalysis TritonTransforms TritonGPUTransforms + TritonNvidiaGPUTransforms ${dialect_libs} ${conversion_libs} # tests @@ -29,6 +30,7 @@ target_link_libraries(triton-reduce PRIVATE TritonAnalysis TritonTransforms TritonGPUTransforms + TritonNvidiaGPUTransforms ${dialect_libs} ${conversion_libs} # tests @@ -48,6 +50,7 @@ llvm_update_compile_flags(triton-translate) TritonAnalysis TritonTransforms TritonGPUTransforms + TritonNvidiaGPUTransforms TritonLLVMIR TritonPTX TritonHSACO diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index 8083cb49f1c9..5cf1c3a25707 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -1,10 +1,13 @@ #pragma once #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "triton/Dialect/Triton/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" +#include "triton/Conversion/NVGPUToLLVM/Passes.h" #include "triton/Conversion/TritonGPUToLLVM/Passes.h" #include "triton/Conversion/TritonToTritonGPU/Passes.h" @@ -23,15 +26,18 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { mlir::registerAllPasses(); mlir::registerTritonPasses(); mlir::registerTritonGPUPasses(); + mlir::registerTritonNvidiaGPUPasses(); mlir::test::registerTestAliasPass(); mlir::test::registerTestAlignmentPass(); mlir::test::registerTestAllocationPass(); mlir::test::registerTestMembarPass(); mlir::triton::registerConvertTritonToTritonGPUPass(); mlir::triton::registerConvertTritonGPUToLLVMPass(); + mlir::triton::registerConvertNVGPUToLLVMPass(); // TODO: register Triton & TritonGPU passes registry.insert(); diff --git a/bin/triton-translate.cpp b/bin/triton-translate.cpp index 698f032b6e96..3ac0f23fcb5c 100644 --- a/bin/triton-translate.cpp +++ b/bin/triton-translate.cpp @@ -14,6 +14,7 @@ #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "triton/Target/HSACO/HSACOTranslation.h" #include "triton/Target/LLVMIR/LLVMIRTranslation.h" #include "triton/Target/PTX/PTXTranslation.h" @@ -39,6 +40,7 @@ OwningOpRef loadMLIRModule(llvm::StringRef inputFilename, mlir::DialectRegistry registry; registry .insert(); context.appendDialectRegistry(registry); @@ -122,12 +124,15 @@ LogicalResult tritonTranslateMain(int argc, char **argv, } llvm::LLVMContext llvmContext; + mlir::triton::gpu::TMAMetadataTy tmaInfos; #ifdef USE_ROCM - auto llvmir = translateTritonGPUToLLVMIR(&llvmContext, *module, - SMArch.getValue(), true /*isRocm*/); + auto llvmir = + translateTritonGPUToLLVMIR(&llvmContext, *module, SMArch.getValue(), + tmaInfos, Target::ROCDL, 0 /*wavesPerEU*/); #else - auto llvmir = translateTritonGPUToLLVMIR(&llvmContext, *module, - SMArch.getValue(), false /*isRocm*/); + auto llvmir = + translateTritonGPUToLLVMIR(&llvmContext, *module, SMArch.getValue(), + tmaInfos, Target::Default, 0 /*wavesPerEU*/); #endif if (!llvmir) { llvm::errs() << "Translate to LLVM IR failed"; diff --git a/docs/conf.py b/docs/conf.py index c64dcffbf988..23ff8ecc9e54 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -101,11 +101,12 @@ def documenter(app, obj, parent): 'gallery_dirs': 'getting-started/tutorials', 'filename_pattern': '', # XXX: Temporarily disable fused attention tutorial on V100 - 'ignore_pattern': r'__init__\.py', + 'ignore_pattern': r'(__init__\.py|09.*\.py|10.*\.py)', 'within_subsection_order': FileNameSortKey, 'reference_url': { 'sphinx_gallery': None, - } + }, + 'abort_on_example_error': True, } # Add any paths that contain templates here, relative to this directory. @@ -144,7 +145,7 @@ def documenter(app, obj, parent): # # This is also used if you do content translation via gettext catalogs. # Usually you set "language" from the command line for these cases. -language = None +language = 'en' # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. diff --git a/docs/meetups/08-22-2023.md b/docs/meetups/08-22-2023.md new file mode 100644 index 000000000000..03648402541e --- /dev/null +++ b/docs/meetups/08-22-2023.md @@ -0,0 +1,41 @@ +#### Agenda: + +##### Announcements: +1. Triton conference registration opening soon. Conference on 20th September at the Microsoft Silicon Valley Campus. + +##### Items: +1. H100 updates +2. Triton release plan update +3. Linalg updates +4. Intel GPU Backend status update. +5. Intel working on the CPU backend for Triton. +6. AMD updates +7. Open discussion + +##### Minutes: +Recording link [here](https://drive.google.com/file/d/19Nnc0i7zUyn-ni2RSFHbPHHiPkYU96Mz/view) + +1. H100 updates: + - Preliminary support is merged, disabled by default, can be enabled with env variables + - Supports latest tensor cores, FP8s. Support for Flash Attention on the main branch coming soon. + - Performance is very good on Matmuls, 80-90% of cublas on large Matmuls right now, will eventually reach parity with cublas. Above 600 teraflops on fp16 on xxm card, cublas is 670 on random input data. FP8 is twice that, around 1.2 petaflops. + - Hopper support includes the full FP8 support for compute. +2. Triton release plan update + - No specific dates for now, plan is to release before end of 2023. + - Will move to 3.0 release due to minor backward compatibility breaking changes. For eg. Will move compiler options in the indexing operators as hardcoded operators in the kernel, will bump the major version. + - Functionally the main goal will be to have 3rd party plugins for Intel and AMD gpus. + - May synchronise with a PyTorch release so that PyTorch can benefit from the latest features, however continuous integration workflow is the default release cadence expected. + - Will switch the default behavior to optimized mode for the release, needs more discussion with Nvidia. + - Will expose flags for a user to enable kernel selection themselves. + - Open question: Pytorch hasn’t rebased to latest triton, it is close to PyTorch code freeze – will PyTorch still sync with Triton 2.0? Will we have another release to support triton 2.0? + - Community can start with the latest stable branch and rebase 3rd party plugin on top of that. OAI has no resources to commit to, but community can contribute. +3. Linalg updates + - Discussion on Github for Linalg as a middle layer between the language and target hardware. Includes support for block pointers and modulo operators. + - Please join the conversation [here](https://github.com/openai/triton/discussions/1842) + - Branch pushed is behind the tip, will work on getting it caught up on the tip. +4. Intel GPU Backend status update. + - Please refer to slides [here](https://github.com/openai/triton/blob/main/docs/meetups/Intel%20XPU%20Backend%20for%20Triton%20-%20Update%20-%200823.pptx) +5. Intel working on the CPU backend for Triton. + - Please refer to slides [here](https://github.com/openai/triton/blob/main/docs/meetups/Intel%20XPU%20Backend%20for%20Triton%20-%20Update%20-%200823.pptx) +6. AMD updates + - Please refer to slides [here](https://github.com/openai/triton/blob/main/docs/meetups/Triton_AMD_update_0823.pdf). diff --git a/docs/meetups/Intel XPU Backend for Triton - Update - 0823.pptx b/docs/meetups/Intel XPU Backend for Triton - Update - 0823.pptx new file mode 100644 index 000000000000..d9c61dfaa162 Binary files /dev/null and b/docs/meetups/Intel XPU Backend for Triton - Update - 0823.pptx differ diff --git a/docs/meetups/Triton_AMD_update_0823.pdf b/docs/meetups/Triton_AMD_update_0823.pdf new file mode 100644 index 000000000000..e0178355ca25 Binary files /dev/null and b/docs/meetups/Triton_AMD_update_0823.pdf differ diff --git a/docs/python-api/triton.language.rst b/docs/python-api/triton.language.rst index ab4fd115fb56..410d02faba74 100644 --- a/docs/python-api/triton.language.rst +++ b/docs/python-api/triton.language.rst @@ -192,3 +192,4 @@ Iterators :nosignatures: static_range + multiple_of diff --git a/include/triton/Analysis/Allocation.h b/include/triton/Analysis/Allocation.h index 9b9d9dace72f..22106f3e167e 100644 --- a/include/triton/Analysis/Allocation.h +++ b/include/triton/Analysis/Allocation.h @@ -9,6 +9,7 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include #include @@ -29,6 +30,7 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, template class Interval { public: Interval() {} + Interval(T S) : Start(S), End(S+1) {} Interval(T S, T E) : Start(S), End(E) { assert(Start <= End); } T start() const { return Start; } T end() const { return End; } @@ -44,6 +46,16 @@ template class Interval { bool operator<(const Interval &R) const { return std::make_pair(Start, End) < std::make_pair(R.Start, R.End); } + bool adjacent(T Addr) const { + return Addr+1 == Start || Addr == End; + } + bool adjacent(const Interval &R) const { + return adjacent(R.Start) || adjacent(R.End-1); + } + + Interval merge(const Interval &R) const { + return Interval(std::min(Start, R.Start), std::max(End, R.End)); + } private: T Start = std::numeric_limits::min(); @@ -147,17 +159,17 @@ class Allocation { BufferKind kind; BufferId id; size_t size; + size_t alignment; size_t offset; bool operator==(const BufferT &other) const { return id == other.id; } bool operator<(const BufferT &other) const { return id < other.id; } - BufferT() : BufferT(BufferKind::Explicit) {} - BufferT(BufferKind kind) - : kind(kind), id(InvalidBufferId), size(0), offset(0) {} - BufferT(BufferKind kind, size_t size) : BufferT(kind, size, 0) {} - BufferT(BufferKind kind, size_t size, size_t offset) - : kind(kind), id(nextId++), size(size), offset(offset) {} + BufferT() : BufferT(BufferKind::Explicit, 0) {} + BufferT(BufferKind kind, size_t size, size_t alignment = 4, + size_t offset = 0) + : kind(kind), id(nextId++), size(size), alignment(alignment), + offset(offset) {} }; /// Op -> Scratch Buffer diff --git a/include/triton/Analysis/AxisInfo.h b/include/triton/Analysis/AxisInfo.h index af5b04ad2781..8d28f46aab1d 100644 --- a/include/triton/Analysis/AxisInfo.h +++ b/include/triton/Analysis/AxisInfo.h @@ -21,7 +21,7 @@ namespace mlir { /// This lattice value represents known information on the axes of a lattice. class AxisInfo { public: - typedef SmallVector DimVectorT; + typedef SmallVector DimVectorT; public: /// Default constructor diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h index bb3017d5a67c..af0f0961bc9c 100644 --- a/include/triton/Analysis/Utility.h +++ b/include/triton/Analysis/Utility.h @@ -3,6 +3,7 @@ #include "mlir/Analysis/DataFlowFramework.h" #include "mlir/Analysis/SliceAnalysis.h" +#include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include #include @@ -114,7 +115,7 @@ bool maybeSharedAllocationOp(Operation *op); bool maybeAliasOp(Operation *op); #ifdef USE_ROCM -bool supportMFMA(triton::DotOp op); +bool supportMFMA(triton::DotOp op, int64_t nonKDim); #endif bool supportMMA(triton::DotOp op, int version); @@ -125,7 +126,11 @@ bool isSingleValue(Value value); bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy); -Type getElementType(Value value); +bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy); + +// TODO: Move utility functions that belong to ConvertLayoutOp to class +// ConvertLayoutOpHelper in the future +bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout); template inline SmallVector convertType(ArrayRef in) { @@ -332,6 +337,10 @@ template class CallGraph { FuncDataMapT funcMap; SmallVector roots; }; +// Create a basic DataFlowSolver with constant and dead code analysis included. +std::unique_ptr createDataFlowSolver(); + +triton::MakeTensorPtrOp getMakeTensorPtrOp(Value v); } // namespace mlir diff --git a/include/triton/Conversion/CMakeLists.txt b/include/triton/Conversion/CMakeLists.txt index 143a4375a811..c5dcec8c0e86 100644 --- a/include/triton/Conversion/CMakeLists.txt +++ b/include/triton/Conversion/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(TritonToTritonGPU) add_subdirectory(TritonGPUToLLVM) +add_subdirectory(NVGPUToLLVM) diff --git a/include/triton/Conversion/NVGPUToLLVM/CMakeLists.txt b/include/triton/Conversion/NVGPUToLLVM/CMakeLists.txt new file mode 100644 index 000000000000..f89521768f06 --- /dev/null +++ b/include/triton/Conversion/NVGPUToLLVM/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name NVGPUToLLVM) +add_public_tablegen_target(NVGPUConversionPassIncGen) diff --git a/include/triton/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.h b/include/triton/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.h new file mode 100644 index 000000000000..e4b91550ca5b --- /dev/null +++ b/include/triton/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.h @@ -0,0 +1,19 @@ +#ifndef TRITON_CONVERSION_NVGPU_TO_LLVM_PASS_H +#define TRITON_CONVERSION_NVGPU_TO_LLVM_PASS_H + +#include + +namespace mlir { + +class ModuleOp; +template class OperationPass; + +namespace triton { + +std::unique_ptr> createConvertNVGPUToLLVMPass(); + +} // namespace triton + +} // namespace mlir + +#endif diff --git a/include/triton/Conversion/NVGPUToLLVM/Passes.h b/include/triton/Conversion/NVGPUToLLVM/Passes.h new file mode 100644 index 000000000000..5e34b7693f6a --- /dev/null +++ b/include/triton/Conversion/NVGPUToLLVM/Passes.h @@ -0,0 +1,16 @@ +#ifndef NVGPU_CONVERSION_PASSES_H +#define NVGPU_CONVERSION_PASSES_H + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "triton/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.h" + +namespace mlir { +namespace triton { + +#define GEN_PASS_REGISTRATION +#include "triton/Conversion/NVGPUToLLVM/Passes.h.inc" + +} // namespace triton +} // namespace mlir + +#endif diff --git a/include/triton/Conversion/NVGPUToLLVM/Passes.td b/include/triton/Conversion/NVGPUToLLVM/Passes.td new file mode 100644 index 000000000000..364ed4601f94 --- /dev/null +++ b/include/triton/Conversion/NVGPUToLLVM/Passes.td @@ -0,0 +1,20 @@ +#ifndef NVGPU_CONVERSION_PASSES +#define NVGPU_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + + +def ConvertNVGPUToLLVM : Pass<"convert-nv-gpu-to-llvm", "mlir::ModuleOp"> { + let summary = "Convert NVGPU to LLVM"; + let description = [{ + + }]; + let constructor = "mlir::triton::createConvertNVGPUToLLVMPass()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::LLVM::LLVMDialect", + "mlir::NVVM::NVVMDialect", + "mlir::triton::nvgpu::NVGPUDialect"]; +} + +#endif diff --git a/include/triton/Conversion/TritonGPUToLLVM/PTXAsmFormat.h b/include/triton/Conversion/TritonGPUToLLVM/PTXAsmFormat.h index e801196f5800..c3029eb79d18 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/PTXAsmFormat.h +++ b/include/triton/Conversion/TritonGPUToLLVM/PTXAsmFormat.h @@ -151,6 +151,12 @@ struct PTXBuilder { // aggressive optimizations that may lead to incorrect results. Operand *newOperand(StringRef constraint, bool init = false); + // Create a new operand that is tied to a previous operand. In this case the + // asm would be permitted to write to an input register. Instead of providing + // constraint code for this operand, the constraint code of the tied operand + // is used. + Operand *newOperand(unsigned operandIndex); + // Create a constant integer operand. Operand *newConstantOperand(int64_t v); // Create a constant operand with explicit code specified. diff --git a/include/triton/Conversion/TritonGPUToLLVM/Passes.td b/include/triton/Conversion/TritonGPUToLLVM/Passes.td index 47873f114be6..f94b8d30ae73 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Passes.td +++ b/include/triton/Conversion/TritonGPUToLLVM/Passes.td @@ -19,6 +19,7 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp" "mlir::tensor::TensorDialect", "mlir::triton::TritonDialect", "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", "mlir::ROCDL::ROCDLDialect", "mlir::NVVM::NVVMDialect"]; @@ -26,9 +27,16 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp" Option<"computeCapability", "compute-capability", "int32_t", /*default*/"80", "device compute capability">, - Option<"isROCM", "is-rocm", - "bool", /*default*/"false", - "compile for ROCM-compatible LLVM">, + Option<"tmaMetadata", "tma-metadata", + "mlir::triton::gpu::TMAMetadataTy*", /*default*/"nullptr", + "tma metadata to the runtime">, + Option<"target", "target", "enum Target", "mlir::triton::Target::Default", + "compile for target compatible LLVM", + "llvm::cl::values(" + "clEnumValN(mlir::triton::Target::NVVM, \"nvvm\", \"compile for " + "NVVM-compatible LLVM\"), " + "clEnumValN(mlir::triton::Target::ROCDL, \"rocdl\", \"compile for " + "ROCDL-compatible LLVM\"))">, ]; } diff --git a/include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h b/include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h index 6755e7e3169c..3be5c9009014 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h +++ b/include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h @@ -3,6 +3,8 @@ #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Transforms/DialectConversion.h" +#include "triton/Target/PTX/TmaMetadata.h" + #include namespace mlir { @@ -12,14 +14,15 @@ template class OperationPass; namespace triton { +enum Target { NVVM, ROCDL, Default = NVVM }; + +#define GEN_PASS_DECL +#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc" + +std::unique_ptr> createConvertTritonGPUToLLVMPass(); std::unique_ptr> -#ifdef USE_ROCM -createConvertTritonGPUToLLVMPass(int computeCapability = 80, - bool isROCM = true); -#else -createConvertTritonGPUToLLVMPass(int computeCapability = 80, - bool isROCM = false); -#endif +createConvertTritonGPUToLLVMPass(const ConvertTritonGPUToLLVMOptions &options); + } // namespace triton } // namespace mlir diff --git a/include/triton/Conversion/TritonToTritonGPU/Passes.h b/include/triton/Conversion/TritonToTritonGPU/Passes.h index e159406b3ed4..cb11537c53af 100644 --- a/include/triton/Conversion/TritonToTritonGPU/Passes.h +++ b/include/triton/Conversion/TritonToTritonGPU/Passes.h @@ -2,6 +2,7 @@ #define TRITON_CONVERSION_PASSES_H #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h" +#include "triton/Target/PTX/TmaMetadata.h" namespace mlir { namespace triton { diff --git a/include/triton/Conversion/TritonToTritonGPU/Passes.td b/include/triton/Conversion/TritonToTritonGPU/Passes.td index f32b75c01024..018a4a9e0f3a 100644 --- a/include/triton/Conversion/TritonToTritonGPU/Passes.td +++ b/include/triton/Conversion/TritonToTritonGPU/Passes.td @@ -25,6 +25,12 @@ def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleO Option<"threadsPerWarp", "threads-per-warp", "int32_t", /*default*/"TRITONGPU_DEFAULT_WARPSIZE", "number of threads per warp">, + Option<"numCTAs", "num-ctas", + "int32_t", /*default*/"1", + "number of ctas in a cga">, + Option<"computeCapability", "compute-capability", + "int32_t", /*default*/"80", + "compute capability"> ]; } diff --git a/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h b/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h index 07e7178cce0b..9ed7fcad27d5 100644 --- a/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h +++ b/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h @@ -11,6 +11,9 @@ template class OperationPass; namespace triton { constexpr static char AttrNumWarpsName[] = "triton_gpu.num-warps"; +constexpr static char AttrNumCTAsName[] = "triton_gpu.num-ctas"; +constexpr static char AttrComputeCapabilityName[] = + "triton_gpu.compute-capability"; constexpr static char AttrNumThreadsPerWarp[] = "triton_gpu.threads-per-warp"; @@ -19,7 +22,8 @@ std::unique_ptr> createConvertTritonToTritonGPUPass(); // Create the pass with numWarps set explicitly. std::unique_ptr> -createConvertTritonToTritonGPUPass(int numWarps, int threadsPerWarp = 32); +createConvertTritonToTritonGPUPass(int numWarps, int threadsPerWarp = 32, + int numCTAs = 1, int computeCapability = 80); } // namespace triton } // namespace mlir diff --git a/include/triton/Dialect/CMakeLists.txt b/include/triton/Dialect/CMakeLists.txt index 27cb65ce5101..02d764601056 100644 --- a/include/triton/Dialect/CMakeLists.txt +++ b/include/triton/Dialect/CMakeLists.txt @@ -1,2 +1,4 @@ add_subdirectory(Triton) add_subdirectory(TritonGPU) +add_subdirectory(TritonNvidiaGPU) +add_subdirectory(NVGPU) diff --git a/include/triton/Dialect/NVGPU/CMakeLists.txt b/include/triton/Dialect/NVGPU/CMakeLists.txt new file mode 100644 index 000000000000..218c20c8819f --- /dev/null +++ b/include/triton/Dialect/NVGPU/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +#add_subdirectory(Transforms) diff --git a/include/triton/Dialect/NVGPU/IR/CMakeLists.txt b/include/triton/Dialect/NVGPU/IR/CMakeLists.txt new file mode 100644 index 000000000000..aa965dac6284 --- /dev/null +++ b/include/triton/Dialect/NVGPU/IR/CMakeLists.txt @@ -0,0 +1,14 @@ +set(LLVM_TARGET_DEFINITIONS NVGPUOps.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=nvgpu) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=nvgpu) +mlir_tablegen(OpsConversions.inc -gen-llvmir-conversions) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(NVGPUTableGen) + +set(LLVM_TARGET_DEFINITIONS NVGPUAttrDefs.td) +mlir_tablegen(NVGPUAttrDefs.h.inc -gen-attrdef-decls) +mlir_tablegen(NVGPUAttrDefs.cpp.inc -gen-attrdef-defs) +add_public_tablegen_target(NVGPUAttrDefsIncGen) diff --git a/include/triton/Dialect/NVGPU/IR/Dialect.h b/include/triton/Dialect/NVGPU/IR/Dialect.h new file mode 100644 index 000000000000..a27b556fed60 --- /dev/null +++ b/include/triton/Dialect/NVGPU/IR/Dialect.h @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#ifndef TRITON_DIALECT_NVGPU_IR_DIALECT_H_ +#define TRITON_DIALECT_NVGPU_IR_DIALECT_H_ + +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" +#include "triton/Dialect/NVGPU/IR/Dialect.h.inc" +#include "triton/Dialect/NVGPU/IR/OpsEnums.h.inc" + +#define GET_ATTRDEF_CLASSES +#include "triton/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc" + +#define GET_OP_CLASSES +#include "triton/Dialect/NVGPU/IR/Ops.h.inc" + +namespace mlir { +namespace triton { +namespace nvgpu {} // namespace nvgpu +} // namespace triton +} // namespace mlir + +#endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ diff --git a/include/triton/Dialect/NVGPU/IR/NVGPUAttrDefs.td b/include/triton/Dialect/NVGPU/IR/NVGPUAttrDefs.td new file mode 100644 index 000000000000..20229f1e02c9 --- /dev/null +++ b/include/triton/Dialect/NVGPU/IR/NVGPUAttrDefs.td @@ -0,0 +1,33 @@ +// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files +// (the "Software"), to deal in the Software without restriction, +// including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, +// and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#ifndef NVGPU_ATTRDEFS +#define NVGPU_ATTRDEFS + +include "triton/Dialect/NVGPU/IR/NVGPUDialect.td" +include "mlir/IR/AttrTypeBase.td" + +class NVGPU_Attr traits = [], + string baseCppClass = "::mlir::Attribute"> + : AttrDef { +} + +#endif diff --git a/include/triton/Dialect/NVGPU/IR/NVGPUDialect.td b/include/triton/Dialect/NVGPU/IR/NVGPUDialect.td new file mode 100644 index 000000000000..6978173d4982 --- /dev/null +++ b/include/triton/Dialect/NVGPU/IR/NVGPUDialect.td @@ -0,0 +1,40 @@ +// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files +// (the "Software"), to deal in the Software without restriction, +// including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, +// and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#ifndef NVGPU_DIALECT +#define NVGPU_DIALECT + +include "mlir/IR/OpBase.td" + +def NVGPU_Dialect : Dialect { + let name = "nvgpu"; + let cppNamespace = "::mlir::triton::nvgpu"; + + let description = [{ + NVGPU Dialect. + }]; + + let dependentDialects = [ + "mlir::LLVM::LLVMDialect" + ]; +} + +#endif diff --git a/include/triton/Dialect/NVGPU/IR/NVGPUOps.td b/include/triton/Dialect/NVGPU/IR/NVGPUOps.td new file mode 100644 index 000000000000..a9451984c7ce --- /dev/null +++ b/include/triton/Dialect/NVGPU/IR/NVGPUOps.td @@ -0,0 +1,248 @@ +// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files +// (the "Software"), to deal in the Software without restriction, +// including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, +// and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#ifndef NVGPU_OPS +#define NVGPU_OPS + +include "triton/Dialect/NVGPU/IR/NVGPUDialect.td" +include "triton/Dialect/NVGPU/IR/NVGPUAttrDefs.td" +include "mlir/IR/OpBase.td" +include "mlir/IR/EnumAttr.td" +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" + +def I8Ptr_global : LLVM_IntPtrBase<8, 1>; +def I8Ptr_shared : LLVM_IntPtrBase<8, 3>; +def I64Ptr_shared : LLVM_IntPtrBase<64, 3>; + +class NVGPU_Op traits = []> : + LLVM_OpBase; + +def NVGPU_WGMMAFenceOp : NVGPU_Op<"wgmma_fence", []> { + let assemblyFormat = "attr-dict"; +} + + +def NVGPU_WGMMACommitGroupOp : NVGPU_Op<"wgmma_commit_group", []> { + let assemblyFormat = "attr-dict"; +} + +def NVGPU_WGMMAWaitGroupOp : NVGPU_Op<"wgmma_wait_group", []> { + let arguments = (ins I32Attr:$pendings); + let assemblyFormat = "attr-dict"; +} + +def NVGPU_MBarrierInitOp : NVGPU_Op<"mbarrier_init", [MemoryEffects<[MemWrite]>]> { + let arguments = (ins I64Ptr_shared:$mbarrier, I1:$pred, I32Attr:$count); + let assemblyFormat = "$mbarrier `,` $pred attr-dict `:` type($mbarrier)"; +} + +def MBarrier_ArriveTypeAttr : I32EnumAttr<"MBarriveType", + "mbarrier arrive type, either 'normal', 'expect_tx', 'cp_async'", + [ + I32EnumAttrCase<"normal", 0>, + I32EnumAttrCase<"cp_async", 1>, + I32EnumAttrCase<"expect_tx", 2>, + I32EnumAttrCase<"remote", 3>, + ]>{ + let cppNamespace = "::mlir::triton::nvgpu"; +} + +def NVGPU_MBarrierArriveOp : NVGPU_Op<"mbarrier_arrive", []> { + let arguments = (ins I64Ptr_shared:$mbarrier, I1:$pred, Optional:$ctaId, MBarrier_ArriveTypeAttr:$arriveType, DefaultValuedAttr:$txCount); + let assemblyFormat = "$mbarrier `,` $pred (`,` $ctaId^)? attr-dict `:` type($mbarrier)"; +} + +def NVGPU_MBarrierWaitOp : NVGPU_Op<"mbarrier_wait", []> { + let arguments = (ins I64Ptr_shared:$mbarrier, I1:$phase); + let assemblyFormat = "$mbarrier `,` $phase attr-dict `:` type(operands)"; +} + +def NVGPU_NamedBarrierArriveOp : NVGPU_Op<"bar_arrive", []> { + let arguments = (ins I32:$bar, I32:$numThreads); + let assemblyFormat = "$bar `,` $numThreads attr-dict `:` type(operands)"; +} + +def NVGPU_NamedBarrierWaitOp : NVGPU_Op<"bar_wait", []> { + let arguments = (ins I32:$bar, I32:$numThreads); + let assemblyFormat = "$bar `,` $numThreads attr-dict `:` type(operands)"; +} + +def WGMMADesc_ModeAttr : I32EnumAttr<"WGMMADescMode", + "wgmma desc mode, either 'none', 'swizzle128', 'swizzle64', or 'swizzle32'", + [ + I32EnumAttrCase<"none", 0>, + I32EnumAttrCase<"swizzle128", 1>, + I32EnumAttrCase<"swizzle64", 2>, + I32EnumAttrCase<"swizzle32", 3> + ]>{ + let cppNamespace = "::mlir::triton::nvgpu"; +} + +def NVGPU_WGMMADescCreateOp : NVGPU_Op<"wgmma_desc_create", []> { + let arguments = (ins LLVM_AnyPointer:$buffer, I32:$height, WGMMADesc_ModeAttr:$mode); + let results = (outs I64:$res); + let assemblyFormat = "$buffer `,` $height attr-dict `:` functional-type(operands, results)"; +} + +def NVGPU_TMALoadTiledOp : NVGPU_Op<"tma_load_tiled", [AttrSizedOperandSegments]> { + let arguments = (ins I8Ptr_shared:$dst, I64Ptr_shared:$mbarrier, I8Ptr_global:$tmaDesc, I64:$l2Desc, + I1:$pred, Variadic:$coords, Optional:$mcastMask); + let assemblyFormat = "operands attr-dict `:` type(operands)"; +} + +def NVGPU_TMALoadIm2colOp : NVGPU_Op<"tma_load_im2col", []> { + let arguments = (ins I8Ptr_shared:$dst, I64Ptr_shared:$mbarrier, I8Ptr_global:$tmaDesc, I64:$l2Desc, LLVM_AnyStruct:$im2colOffsets, I1:$pred, Variadic:$coords, I16Attr:$mcastMask); + let assemblyFormat = "operands attr-dict `:` type(operands)"; +} + +def WGMMA_LayoutAttr : I32EnumAttr<"WGMMALayout", + "wgmma layout, either 'row' or 'col'", + [ + I32EnumAttrCase<"row", 0>, + I32EnumAttrCase<"col", 1> + ]>{ + let cppNamespace = "::mlir::triton::nvgpu"; +} + +def WGMMA_EltTypeAttr : I32EnumAttr<"WGMMAEltType", + "wgmma operand type, either 's8', 's32', 'e4m3', 'e5m2', 'f16', 'bf16', 'tf32', or 'f32'", + [ + I32EnumAttrCase<"s8", 0>, + I32EnumAttrCase<"s32", 1>, + I32EnumAttrCase<"e4m3", 2>, + I32EnumAttrCase<"e5m2", 3>, + I32EnumAttrCase<"f16", 4>, + I32EnumAttrCase<"bf16", 5>, + I32EnumAttrCase<"tf32", 6>, + I32EnumAttrCase<"f32", 7> + ]>{ + let cppNamespace = "::mlir::triton::nvgpu"; +} + +def WGMMA_OperandType : AnyTypeOf<[LLVM_AnyStruct, I64], "wgmma operand A/B type">; + +def NVGPU_WGMMAOp : NVGPU_Op<"wgmma", []> { + let arguments = (ins WGMMA_OperandType:$opA, WGMMA_OperandType:$opB, LLVM_AnyStruct:$opC, + I32Attr:$m, I32Attr:$n, I32Attr:$k, + WGMMA_EltTypeAttr:$eltTypeC, WGMMA_EltTypeAttr:$eltTypeA, WGMMA_EltTypeAttr:$eltTypeB, + WGMMA_LayoutAttr:$layoutA, WGMMA_LayoutAttr:$layoutB); + let results = (outs LLVM_AnyStruct:$res); + let assemblyFormat = "$opA `,` $opB `,` $opC attr-dict `:` functional-type(operands, $res)"; +} + +def NVGPU_CGABarrierSyncOp : NVGPU_Op<"cga_barrier_sync", []> { + let assemblyFormat = "attr-dict"; +} + +def NVGPU_CGABarrierArriveOp : NVGPU_Op<"cga_barrier_arrive", []> { + let assemblyFormat = "attr-dict"; +} + +def NVGPU_CGABarrierWaitOp : NVGPU_Op<"cga_barrier_wait", []> { + let assemblyFormat = "attr-dict"; +} + +def NVGPU_LoadDSmemOp : NVGPU_Op<"load_dsmem", [MemoryEffects<[MemRead]>]> { + let arguments = (ins LLVM_AnyPointer:$addr, I32:$ctaId, I32Attr:$bitwidth, I32Attr:$vec); + let builders = [ + OpBuilder<(ins "Type":$resultTy, "Value":$addr, "Value":$ctaId)>, + OpBuilder<(ins "Value":$addr, "Value":$ctaId, "unsigned":$bitwidth, "unsigned":$vec)>, + OpBuilder<(ins "Value":$addr, "Value":$ctaId, "unsigned":$bitwidth)> + ]; + let results = (outs LLVM_LoadableType:$result); + let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; +} + +def NVGPU_StoreDSmemOp : NVGPU_Op<"store_dsmem", [MemoryEffects<[MemWrite]>]> { + let arguments = (ins LLVM_AnyPointer:$addr, I32:$ctaId, + Variadic:$values, I1:$pred); + let builders = [ + OpBuilder<(ins "Value":$addr, "Value":$ctaId, "Value":$value, "Value":$pred)>, + ]; + let assemblyFormat = "operands attr-dict `:` type(operands)"; + let extraClassDeclaration = [{ + unsigned getBitwidth(); + unsigned getVec(); + }]; +} + +def NVGPU_FenceAsyncSharedOp : NVGPU_Op<"fence_async_shared", []> { + let arguments = (ins BoolAttr:$bCluster); + let assemblyFormat = "attr-dict"; +} + +def NVGPU_FenceMBarrierInitOp : NVGPU_Op<"fence_mbarrier_init", []> { + let assemblyFormat = "attr-dict"; +} + +def NVGPU_ClusterArriveOp : NVGPU_Op<"cluster_arrive", []> { + let arguments = (ins I1Attr:$relaxed); + + let assemblyFormat = "attr-dict"; +} + +def NVGPU_ClusterWaitOp : NVGPU_Op<"cluster_wait", []> { + let assemblyFormat = "attr-dict"; +} + +def NVGPU_TMAStoreTiledOp : NVGPU_Op<"tma_store_tiled", [MemoryEffects<[MemWrite]>]> { + let arguments = (ins I8Ptr_global:$tmaDesc, I8Ptr_shared:$src, I1:$pred, Variadic:$coords); + let assemblyFormat = "operands attr-dict `:` type(operands)"; +} + +def NVGPU_StoreMatrixOp : NVGPU_Op<"stmatrix", [MemoryEffects<[MemWrite]>]> { + let arguments = (ins I8Ptr_shared:$addr, Variadic:$datas); + let assemblyFormat = "operands attr-dict `:` type(operands)"; +} + +def NVGPU_OffsetOfStmatrixV4Op : NVGPU_Op<"offset_of_stmatrix_v4", []> { + let arguments = (ins I32:$threadId, I32:$rowOfWarp, I32:$elemIdx, I32Attr:$leadingDimOffset, I32Attr:$rowStride, I1Attr:$swizzleEnabled); + let results = (outs I32:$offset); + let assemblyFormat = "operands attr-dict `:` type(operands) `->` type($offset)"; +} + +def NVGPU_OffsetOfSts64Op : NVGPU_Op<"offset_of_sts64", []> { + let arguments = (ins I32:$threadId, I32:$rowOfWarp, I32:$elemIdx, I32Attr:$leadingDimOffset, I32Attr:$rowStride, I1Attr:$swizzleEnabled); + let results = (outs I32:$offset); + let assemblyFormat = "operands attr-dict `:` type(operands) `->` type($offset)"; +} + +def NVGPU_Sts64Op : NVGPU_Op<"sts64", [MemoryEffects<[MemWrite]>]> { + let arguments = (ins I32:$offset, AnyTypeOf<[F32, I32]>:$d0, AnyTypeOf<[F32, I32]>:$d1); + let assemblyFormat = "operands attr-dict `:` type(operands)"; +} + +def NVGPU_ClusterCTAIdOp : NVGPU_Op<"cluster_id", [Pure]> { + let results = (outs I32:$result); + let assemblyFormat = "attr-dict"; +} + +def NVGPU_RegAllocOp : NVGPU_Op<"reg_alloc", []> { + let arguments = (ins I32Attr: $regCount); + let assemblyFormat = "operands attr-dict `:` type(operands)"; +} + +def NVGPU_RegDeallocOp : NVGPU_Op<"reg_dealloc", []> { + let arguments = (ins I32Attr: $regCount); + let assemblyFormat = "operands attr-dict `:` type(operands)"; +} + +#endif diff --git a/include/triton/Dialect/Triton/IR/Dialect.h b/include/triton/Dialect/Triton/IR/Dialect.h index 16879f8e9910..dd763d3454b4 100644 --- a/include/triton/Dialect/Triton/IR/Dialect.h +++ b/include/triton/Dialect/Triton/IR/Dialect.h @@ -9,6 +9,8 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Dialect.h" +#include "mlir/IR/FunctionInterfaces.h" +#include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "triton/Dialect/Triton/IR/Dialect.h.inc" #include "triton/Dialect/Triton/IR/OpsEnums.h.inc" diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index 112f164074b2..69cad2bcf2eb 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -9,6 +9,8 @@ include "mlir/IR/OpBase.td" include "mlir/IR/FunctionInterfaces.td" // FunctionOpInterface include "mlir/IR/SymbolInterfaces.td" // SymbolUserOpInterface include "mlir/IR/OpAsmInterface.td" // OpAsmOpInterface +include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface +include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface include "mlir/Interfaces/SideEffectInterfaces.td" // Pure include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType @@ -135,7 +137,7 @@ def TT_LoadOp : TT_Op<"load", [SameLoadStoreOperandsAndResultShape, SameLoadStoreOperandsAndResultEncoding, AttrSizedOperandSegments, - MemoryEffects<[MemRead]>, + DeclareOpInterfaceMethods, TypesMatchWith<"infer ptr type from result type", "result", "ptr", "$_self", "mlir::OpTrait::impl::verifyLoadStorePointerAndValueType">, @@ -461,33 +463,23 @@ def TT_ScanReturnOp: TT_Op<"scan.return", // // External Elementwise op // -class TT_ExternElementwiseOpBase traits = []> : - TT_Op { +def TT_ExternElementwiseOp : TT_Op<"extern_elementwise", [Elementwise, + SameOperandsAndResultEncoding, + SameVariadicOperandSize, + DeclareOpInterfaceMethods]> { let description = [{ call an external function $symbol implemented in $libpath/$libname with $args return $libpath/$libname:$symbol($args...) }]; - let arguments = (ins Variadic:$args, StrAttr:$libname, StrAttr:$libpath, StrAttr:$symbol); + let arguments = (ins Variadic:$args, StrAttr:$libname, StrAttr:$libpath, StrAttr:$symbol, BoolAttr:$pure); let results = (outs TT_Type:$result); let assemblyFormat = "operands attr-dict `:` functional-type(operands, $result)"; } -def TT_PureExternElementwiseOp : TT_ExternElementwiseOpBase<"pure_extern_elementwise", [Pure, Elementwise]> { - let summary = "FFI for pure element-wise extern LLVM bitcode functions"; -} - -def TT_ImpureExternElementwiseOp : TT_ExternElementwiseOpBase<"impure_extern_elementwise", [MemoryEffects<[MemRead]>, - MemoryEffects<[MemWrite]>]> { - let summary = "FFI for impure element-wise extern LLVM bitcode functions"; -} - // // Make Range Op // @@ -506,6 +498,30 @@ def TT_MakeRangeOp : TT_Op<"make_range", [Pure]> { let results = (outs TT_IntTensor:$result); let assemblyFormat = "attr-dict `:` type($result)"; + + let hasFolder = 1; +} + +// +// ElementwiseInlineAsm Op +// +def TT_ElementwiseInlineAsmOp : TT_Op<"elementwise_inline_asm", [Elementwise, + SameOperandsAndResultEncoding, + DeclareOpInterfaceMethods]> { + let summary = "inline assembly applying elementwise operation to a group of packed element."; + let description = [{ + This will apply the given in inline assembly to `packed_element` number of + elements of the inputs. The elements packed together is unknown and will + depend on the backend implementation. + }]; + + let arguments = (ins StrAttr:$asm_string, StrAttr:$constraints, BoolAttr:$pure, I32Attr:$packed_element, Variadic>:$args); + let results = (outs TT_Type:$result); + + + let assemblyFormat = [{ + $asm_string attr-dict ($args^ `:` type($args))? `->` type($result) + }]; } // @@ -563,6 +579,7 @@ def TT_MakeTensorPtrOp : TT_Op<"make_tensor_ptr", let results = (outs TT_TensorPtr:$result); + // TODO(Keren): define a custom assembly format for this op because the result type cannot be printed correctly // Add additional `[]` to increase readability and split variadic lists let assemblyFormat = "$base `,` `[` $shape `]` `,` `[` $strides `]` `,` `[` $offsets `]` attr-dict `:` type($result)"; diff --git a/include/triton/Dialect/Triton/IR/TritonTypes.td b/include/triton/Dialect/Triton/IR/TritonTypes.td index 77bbdaf5a24c..5ab1f99072fa 100644 --- a/include/triton/Dialect/Triton/IR/TritonTypes.td +++ b/include/triton/Dialect/Triton/IR/TritonTypes.td @@ -14,7 +14,7 @@ class TritonTypeDef } // Floating-point Type -def TT_Float : AnyTypeOf<[F8E4M3FNUZ, F8E4M3B11FNUZ, F8E5M2, F16, BF16, F32, F64], "floating-point">; +def TT_Float : AnyTypeOf<[F8E4M3FNUZ, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2, F16, BF16, F32, F64], "floating-point">; def TT_FloatTensor : TensorOf<[TT_Float]>; def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>; @@ -74,7 +74,7 @@ def TT_PtrType : TritonTypeDef<"Pointer", "ptr"> { // Scalar Pointer Type: `ptr<>` def TT_Ptr : TT_PtrOf<[AnyType]>; -// Tensor of Pointer Type +// Tensor of Pointer Type: `tensor>` def TT_PtrTensor : TensorOf<[TT_Ptr]>; // Tensor of Pointer Type or Pointer type: `tensor>` or `ptr<>` diff --git a/include/triton/Dialect/Triton/IR/Types.h b/include/triton/Dialect/Triton/IR/Types.h index 6c70c966d854..6a5d3124108c 100644 --- a/include/triton/Dialect/Triton/IR/Types.h +++ b/include/triton/Dialect/Triton/IR/Types.h @@ -14,6 +14,8 @@ namespace triton { bool isTensorPointerType(Type type); +bool isTensorOrTensorPointerType(Type type); + unsigned getPointeeBitWidth(Type type); Type getPointeeType(Type type); diff --git a/include/triton/Dialect/Triton/Transforms/Passes.h b/include/triton/Dialect/Triton/Transforms/Passes.h index 31c4887fde64..eb215ac11ca4 100644 --- a/include/triton/Dialect/Triton/Transforms/Passes.h +++ b/include/triton/Dialect/Triton/Transforms/Passes.h @@ -9,7 +9,6 @@ namespace triton { std::unique_ptr createCombineOpsPass(); std::unique_ptr createReorderBroadcastPass(); - std::unique_ptr createRewriteTensorPointerPass(int computeCapability = 80, bool isROCM = false); diff --git a/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt b/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt index 1d2faad405e6..d32192749f25 100644 --- a/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt +++ b/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt @@ -3,9 +3,13 @@ mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=triton_gpu) mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=triton_gpu) mlir_tablegen(Ops.h.inc -gen-op-decls) mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=triton_gpu) +mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=triton_gpu) add_public_tablegen_target(TritonGPUTableGen) set(LLVM_TARGET_DEFINITIONS TritonGPUAttrDefs.td) mlir_tablegen(TritonGPUAttrDefs.h.inc -gen-attrdef-decls) mlir_tablegen(TritonGPUAttrDefs.cpp.inc -gen-attrdef-defs) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) add_public_tablegen_target(TritonGPUAttrDefsIncGen) diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index d73ad3aaec83..10192b0f44fc 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -7,6 +7,7 @@ #include "mlir/IR/Dialect.h" // TritonGPU depends on Triton +#include "triton/Dialect/NVGPU/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h.inc" @@ -73,17 +74,41 @@ getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef tensorShape); SmallVector getThreadsPerCTA(Attribute layout); +SmallVector getOrder(Attribute layout); + +CTALayoutAttr getCTALayout(Attribute layout); + +SmallVector getCTAsPerCGA(Attribute layout); + +SmallVector getCTASplitNum(Attribute layout); + +SmallVector getCTAOrder(Attribute layout); + +/* The difference between ShapePerCTATile and ShapePerCTA: + * (1) ShapePerCTATile is defined by SizePerThread * ThreadsPerWarp * + * WarpsPerCTA in each dimension and is independent from the tensor shape. + * (2) ShapePerCTA is defined by shape / CTASplitNum in each dimension. + * (3) In the implementation of emitIndices, ShapePerCTATile will + * be replicated or wraped to fit ShapePerCTA. + */ SmallVector -getShapePerCTA(Attribute layout, - ArrayRef tensorShape = ArrayRef()); +getShapePerCTATile(Attribute layout, + ArrayRef tensorShape = ArrayRef()); -SmallVector getOrder(Attribute layout); +SmallVector getShapePerCTA(ArrayRef CTASplitNum, + ArrayRef shape); +SmallVector getShapePerCTA(Attribute layout, ArrayRef shape); +SmallVector getShapePerCTA(Type type); + +unsigned getNumWarpsPerCTA(Attribute layout); + +unsigned getNumCTAs(Attribute layout); bool isaDistributedLayout(Attribute layout); bool isSharedEncoding(Value value); -bool isExpensiveCat(CatOp cat, Attribute &targetEncoding); +bool isExpensiveCat(CatOp cat, Attribute targetEncoding); } // namespace gpu } // namespace triton diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 921d91cff6cf..947aa3f4dac9 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -41,6 +41,19 @@ Right now, Triton implements two classes of layouts: shared, and distributed. }]; } +//===----------------------------------------------------------------------===// +// CTA Layout +//===----------------------------------------------------------------------===// + +def CTALayoutAttr : TritonGPU_Attr<"CTALayout"> { + let parameters = ( + ins + ArrayRefParameter<"unsigned">:$CTAsPerCGA, + ArrayRefParameter<"unsigned">:$CTASplitNum, + ArrayRefParameter<"unsigned">:$CTAOrder + ); +} + //===----------------------------------------------------------------------===// // Shared Layout Encoding //===----------------------------------------------------------------------===// @@ -64,19 +77,41 @@ are stored contiguously _ _ _ _ /\_ _ _ _ A_{2, 2} A_{2, 3} A_{2, 0} A_{2, 1} ... [phase 1] \ per phase = 2 A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] / + +For MMAv3 eg Hopper GMMA, hasLeadingOffset should be true. In this case, +when the matrix is stored in shared memory, there will be an offset not +only in the stride dimension, but also in the leading dimension. For example, +a matrix of size 16x128 and data type I8 is stored in the shared memory with +64B-swizzle mode. The offset of the element with index (0, 64) will be 16*64, +compared to 1*64 when the hasLeadingOffset is false. }]; + // swizzle info: vec, perPhase, maxPhase + // order: the fastest-changing axis first let parameters = ( ins - // swizzle info - "unsigned":$vec, "unsigned":$perPhase, "unsigned":$maxPhase, - ArrayRefParameter<"unsigned", "order of axes by the rate of changing">:$order + "unsigned":$vec, + "unsigned":$perPhase, + "unsigned":$maxPhase, + ArrayRefParameter<"unsigned">:$order, + "CTALayoutAttr":$CTALayout, + "bool":$hasLeadingOffset ); let builders = [ + AttrBuilder<(ins "unsigned":$vec, + "unsigned":$perPhase, + "unsigned":$maxPhase, + "ArrayRef":$order, + "CTALayoutAttr":$CTALayout), [{ + bool hasLeadingOffset = false; // default value + return $_get(context, vec, perPhase, maxPhase, order, CTALayout, hasLeadingOffset); + }]>, + AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, "ArrayRef":$shape, "ArrayRef":$order, + "CTALayoutAttr":$CTALayout, "unsigned":$typeWidthInBit), [{ #ifdef USE_ROCM @@ -112,20 +147,21 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] / int vecSize = ((typeWidthInBit == 16) ? 64 : 32 ) / typeWidthInBit; int maxPhase = SIMDWidth / perPhase; - return $_get(context, vecSize, perPhase, maxPhase, order); + return get(context, vecSize, perPhase, maxPhase, order, CTALayout); } else { // Do not swizzle in case k dimension is not innermost. // In this case accesses will go in different banks even without swizzling. - return $_get(context, 1, 1, 1, order); + return get(context, 1, 1, 1, order, CTALayout); } } #endif auto mmaEnc = dotOpEnc.getParent().dyn_cast(); if(!mmaEnc) - return $_get(context, 1, 1, 1, order); + return get(context, 1, 1, 1, order, CTALayout); int opIdx = dotOpEnc.getOpIdx(); + auto shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape); // number of rows per phase @@ -134,34 +170,34 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] / // ---- begin Volta ---- if (mmaEnc.isVolta()) { - int perPhase = 128 / (shape[order[0]] * (typeWidthInBit / 8)); + int perPhase = 128 / (shapePerCTA[order[0]] * (typeWidthInBit / 8)); perPhase = std::max(perPhase, 1); bool is_row = order[0] != 0; - bool is_vec4 = opIdx == 0 ? !is_row && (shape[order[0]] <= 16) : - is_row && (shape[order[0]] <= 16); + bool is_vec4 = opIdx == 0 ? !is_row && (shapePerCTA[order[0]] <= 16) : + is_row && (shapePerCTA[order[0]] <= 16); int pack_size = opIdx == 0 ? ((is_row || is_vec4) ? 1 : 2) : ((is_row && !is_vec4) ? 2 : 1); int rep = 2 * pack_size; int maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase; int vec = 2 * rep; - return $_get(context, vec, perPhase, maxPhase, order); + return get(context, vec, perPhase, maxPhase, order, CTALayout); } // ---- begin Ampere ---- if (mmaEnc.isAmpere()) { - int perPhase = 128 / (shape[order[0]] * 4 / dotOpEnc.getKWidth()); + int perPhase = 128 / (shapePerCTA[order[0]] * 4 / dotOpEnc.getKWidth()); perPhase = std::max(perPhase, 1); std::vector matShape = {8, 8, 4 * dotOpEnc.getKWidth()}; // for now, disable swizzle when using transposed int8 tensor cores if ((32 / typeWidthInBit != dotOpEnc.getKWidth()) && order[0] == inner) - return $_get(context, 1, 1, 1, order); + return get(context, 1, 1, 1, order, CTALayout); // --- handle A operand --- if (opIdx == 0) { // compute swizzling for A operand int vec = (order[0] == 1) ? matShape[2] : matShape[0]; // k : m int mmaStride = (order[0] == 1) ? matShape[0] : matShape[2]; int maxPhase = mmaStride / perPhase; - return $_get(context, vec, perPhase, maxPhase, order); + return get(context, vec, perPhase, maxPhase, order, CTALayout); } // --- handle B operand --- @@ -169,12 +205,19 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] / int vec = (order[0] == 1) ? matShape[1] : matShape[2]; // n : k int mmaStride = (order[0] == 1) ? matShape[2] : matShape[1]; int maxPhase = mmaStride / perPhase; - return $_get(context, vec, perPhase, maxPhase, order); + return get(context, vec, perPhase, maxPhase, order, CTALayout); } llvm_unreachable("invalid operand index"); } + // ---- begin version 3 ---- + if (mmaEnc.isHopper()) { + llvm_unreachable("SharedEncodingAttr builder when the MMAEncodingAttr" + " is Hopper has not been implemented yet"); + return $_get(context, 1, 1, 1, order, CTALayout, true); + } + // ---- not implemented ---- llvm_unreachable("unsupported swizzling for provided MMA version"); }]>, @@ -182,9 +225,38 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] / AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, "ArrayRef":$shape, "ArrayRef":$order, + "CTALayoutAttr":$CTALayout, "Type":$eltTy), [{ unsigned bitwidth = eltTy.getIntOrFloatBitWidth(); - return get(context, dotOpEnc, shape, order, bitwidth); + return get(context, dotOpEnc, shape, order, CTALayout, bitwidth); + }]>, + + AttrBuilder<(ins "ArrayRef":$shape, + "ArrayRef":$order, + "CTALayoutAttr":$CTALayout, + "Type":$eltTy), [{ + auto shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape); + + int32_t eleBitWidth = eltTy.getIntOrFloatBitWidth(); + int32_t vec = 128 / eleBitWidth, perPhase = 1, maxPhase = 1; + + // get proper shared memory swizzling mode from the contiguous dimension + // size of the origin blocked layout. + auto contigDimSizeInByte = shapePerCTA[order[0]] * eleBitWidth / 8; + if (contigDimSizeInByte >= 128 && contigDimSizeInByte % 128 == 0) { + perPhase = 1; + maxPhase = 8; + } else if (contigDimSizeInByte >= 64 && contigDimSizeInByte % 64 == 0) { + perPhase = 2; + maxPhase = 4; + } else if (contigDimSizeInByte >= 32 && contigDimSizeInByte % 32 == 0) { + perPhase = 4; + maxPhase = 2; + } else { + llvm_unreachable("unsupported shared memory layout for MMAv3"); + } + + return $_get(context, vec, perPhase, maxPhase, order, CTALayout, true); }]> ]; @@ -236,7 +308,7 @@ used to promote memory coalescing in LoadInst and StoreInst. It is characterized by three tuples -- thread tile size, warp tile size, and block tile size -- which specify the amount of elements owned by each CUDA thread, warp and CTA respectively. -For example, a row-major coalesced layout may partition a 16x16 tensor over 2 warps (i.e. 64 threads) as follows. +Example 1, a row-major coalesced layout may partition a 16x16 tensor over 2 warps (i.e. 64 threads) as follows: [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] @@ -252,60 +324,136 @@ for sizePerThread = {2, 2} threadsPerWarp = {8, 4} warpsPerCTA = {1, 2} + CTAsPerCGA = {1, 1} +}> + +Example 2, a row-major coalesced layout may partition a 32x32 tensor over 2 warps (i.e. 64 threads) as follows: + +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +... ... +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +... ... +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +for + +#triton_gpu.blocked_layout<{ + sizePerThread = {2, 2} + threadsPerWarp = {8, 4} + warpsPerCTA = {1, 2} + CTAsPerCGA = {1, 1} +}> + +Example 3, A row-major coalesced layout may partition a 32x32 tensor over 2 warps (i.e. 64 threads) and +4 CTAs (taking 2x2 for example) as follows: + +CTA [0,0] CTA [0,1] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +... ... +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] + +CTA [1,0] CTA [1,1] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +... ... +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +for + +#triton_gpu.blocked_layout<{ + sizePerThread = {2, 2} + threadsPerWarp = {8, 4} + warpsPerCTA = {1, 2} + CTAsPerCGA = {2, 2} }> }]; + let parameters = ( + ins + ArrayRefParameter<"unsigned">:$sizePerThread, + ArrayRefParameter<"unsigned">:$threadsPerWarp, + ArrayRefParameter<"unsigned">:$warpsPerCTA, + ArrayRefParameter<"unsigned">:$order, // the fastest-changing axis first + "CTALayoutAttr":$CTALayout + ); let builders = [ - // Custom builder initializes sizePerWarp and sizePerCTA automatically - // TODO: compiles on MacOS but not linux? - // AttrBuilder<(ins "ArrayRef":$sizePerThread, - // "ArrayRef":$threadsPerWarp, - // "ArrayRef":$warpsPerCTA, - // "ArrayRef":$order), [{ - // int rank = threadsPerWarp.size(); - // SmallVector sizePerWarp(rank); - // SmallVector sizePerCTA(rank); - // for (unsigned i = 0; i < rank; i++) { - // sizePerWarp.push_back(sizePerThread[i] * threadsPerWarp[i]); - // sizePerCTA.push_back(sizePerWarp[i] * warpsPerCTA[i]); - // } - // return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order, sizePerWarp, sizePerCTA); - // }]>, - // Custom builder initializes sizePerWarp and sizePerCTA automatically - // Default builder takes sizePerThread, order and numWarps, and tries to - // pack numWarps*32 threads in the provided order for use in a type - // of the given shape. AttrBuilder<(ins "ArrayRef":$shape, "ArrayRef":$sizePerThread, "ArrayRef":$order, "unsigned":$numWarps, - "unsigned":$threadsPerWarp), [{ - int rank = sizePerThread.size(); - unsigned remainingLanes = threadsPerWarp; - unsigned remainingThreads = numWarps*threadsPerWarp; + "unsigned":$numThreadsPerWarp, + "CTALayoutAttr":$CTALayout), [{ + unsigned rank = sizePerThread.size(); + SmallVector threadsPerWarp(rank); + SmallVector warpsPerCTA(rank); + SmallVector shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape); + + unsigned remainingLanes = numThreadsPerWarp; + unsigned remainingThreads = numWarps * numThreadsPerWarp; unsigned remainingWarps = numWarps; unsigned prevLanes = 1; unsigned prevWarps = 1; - SmallVector rankedThreadsPerWarp(rank); - SmallVector warpsPerCTA(rank); - for (int _dim = 0; _dim < rank - 1; ++_dim) { - int i = order[_dim]; - unsigned threadsPerCTA = std::clamp(remainingThreads, 1, shape[i] / sizePerThread[i]); - rankedThreadsPerWarp[i] = std::clamp(threadsPerCTA, 1, remainingLanes); - warpsPerCTA[i] = std::clamp(threadsPerCTA / rankedThreadsPerWarp[i], 1, remainingWarps); + + // starting from the contiguous dimension + for (unsigned d = 0; d < rank - 1; ++d) { + unsigned i = order[d]; + unsigned threadsPerCTA = std::clamp(remainingThreads, 1, shapePerCTA[i] / sizePerThread[i]); + threadsPerWarp[i] = std::clamp(threadsPerCTA, 1, remainingLanes); + warpsPerCTA[i] = std::clamp(threadsPerCTA / threadsPerWarp[i], 1, remainingWarps); remainingWarps /= warpsPerCTA[i]; - remainingLanes /= rankedThreadsPerWarp[i]; + remainingLanes /= threadsPerWarp[i]; remainingThreads /= threadsPerCTA; - prevLanes *= rankedThreadsPerWarp[i]; + prevLanes *= threadsPerWarp[i]; prevWarps *= warpsPerCTA[i]; } + // Expand the last dimension to fill the remaining lanes and warps - rankedThreadsPerWarp[order[rank-1]] = threadsPerWarp / prevLanes; - warpsPerCTA[order[rank-1]] = numWarps / prevWarps; + threadsPerWarp[order[rank - 1]] = numThreadsPerWarp / prevLanes; + warpsPerCTA[order[rank - 1]] = numWarps / prevWarps; - return $_get(context, sizePerThread, rankedThreadsPerWarp, warpsPerCTA, order); + return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order, CTALayout); + }]>, + AttrBuilder<(ins "ArrayRef":$shape, + "ArrayRef":$sizePerThread, + "ArrayRef":$order, + "unsigned":$numWarps, + "unsigned":$numThreadsPerWarp, + "unsigned":$numCTAs), [{ + unsigned rank = sizePerThread.size(); + SmallVector CTAsPerCGA(rank); + SmallVector CTASplitNum(rank); + ArrayRef CTAOrder = order; + + unsigned remainingCTAs = numCTAs; + + // starting from the most strided dimension + for (int d = rank - 1; d >= 0; --d) { + unsigned i = order[d]; + CTAsPerCGA[i] = std::clamp(remainingCTAs, 1, shape[i] / sizePerThread[i]); + CTASplitNum[i] = CTAsPerCGA[i]; + remainingCTAs /= CTAsPerCGA[i]; + } + + CTAsPerCGA[rank - 1] *= remainingCTAs; // wrap at CTA level + + CTALayoutAttr CTALayout = CTALayoutAttr::get(context, CTAsPerCGA, CTASplitNum, CTAOrder); + return get(context, shape, sizePerThread, order, numWarps, numThreadsPerWarp, CTALayout); }]> ]; @@ -313,21 +461,6 @@ for SliceEncodingAttr squeeze(int axis); }]; - let parameters = ( - ins - ArrayRefParameter<"unsigned">:$sizePerThread, - ArrayRefParameter<"unsigned">:$threadsPerWarp, - ArrayRefParameter<"unsigned">:$warpsPerCTA, - // fastest-changing axis first - ArrayRefParameter< - "unsigned", - "order of axes by the rate of changing" - >:$order - // These attributes can be inferred from the rest - // ArrayRefParameter<"unsigned">:$sizePerWarp, - // ArrayRefParameter<"unsigned">:$sizePerCTA - ); - let hasCustomAssemblyFormat = 1; } @@ -423,13 +556,17 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is: ins "unsigned":$versionMajor, "unsigned":$versionMinor, - ArrayRefParameter<"unsigned">:$warpsPerCTA + ArrayRefParameter<"unsigned">:$warpsPerCTA, + "CTALayoutAttr":$CTALayout, + ArrayRefParameter<"unsigned">:$instrShape ); let builders = [ // Specially for MMAV1(Volta) AttrBuilder<(ins "int":$versionMajor, "int":$numWarps, + "CTALayoutAttr":$CTALayout, + "ArrayRef":$instrShape, "ArrayRef":$shapeC, "bool":$isARow, "bool":$isBRow, @@ -443,7 +580,6 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is: (isAVec4 * (1<<2)) |\ (isBVec4 * (1<<3)); - // TODO: Share code with // DotOpMmaV1ConversionHelper::AParam/BParam, since same code to compute the // rep,spw and fpw. @@ -468,11 +604,13 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is: wpt[1] = std::clamp(wpt[1] * 2, 1, shapeC[1] / spw[1]); } while (wpt_nm1 != wpt); - return $_get(context, versionMajor, versionMinor, wpt); + return $_get(context, versionMajor, versionMinor, wpt, CTALayout, instrShape); }]>, AttrBuilder<(ins "int":$versionMajor, "int":$numWarps, + "CTALayoutAttr":$CTALayout, + "ArrayRef":$instrShape, "ArrayRef":$shapeA, "ArrayRef":$shapeB, "ArrayRef":$shapeC, @@ -482,15 +620,21 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is: assert(versionMajor == 1 && "This builder is specially for versionMajor==1"); bool isAVec4 = !isARow && (shapeA[isARow] <= 16); bool isBVec4 = isBRow && (shapeB[isBRow] <= 16); - return get(context, versionMajor, numWarps, shapeC, isARow, isBRow, isAVec4, isBVec4, id); + return get(context, versionMajor, numWarps, CTALayout, instrShape, shapeC, isARow, isBRow, isAVec4, isBVec4, id); }]> ]; let extraClassDeclaration = extraBaseClassDeclaration # [{ bool isVolta() const; + bool isTuring() const; bool isAmpere() const; + bool isHopper() const; + + unsigned getElemsPerThreadOfOperand(int opIdx, ArrayRef shape) const; + // Get [isARow, isBRow, isAVec4, isBVec4, id] from versionMinor std::tuple decodeVoltaLayoutStates() const; + // Number of bits in versionMinor to hold the ID of the MMA encoding instance. // Here 5 bits can hold 32 IDs in a single module. static constexpr int numBitsToHoldMmaV1ID{5}; @@ -574,7 +718,8 @@ The data will be distributed between threads as follows: ins "unsigned":$nonKDim, ArrayRefParameter<"unsigned">:$warpsPerCTA, - "bool":$isTransposed + "bool":$isTransposed, + "CTALayoutAttr":$CTALayout ); let hasCustomAssemblyFormat = 1; @@ -670,6 +815,4 @@ section 9.7.13.4.1 for more details. }]; } - - #endif diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td index e35ee2b576c3..a4cc9eca0ff5 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td @@ -16,6 +16,7 @@ def TritonGPU_Dialect : Dialect { let dependentDialects = [ "triton::TritonDialect", + "mlir::triton::nvgpu::NVGPUDialect", "mlir::gpu::GPUDialect", "tensor::TensorDialect", ]; @@ -23,14 +24,27 @@ def TritonGPU_Dialect : Dialect { let extraClassDeclaration = [{ static std::string getNumWarpsAttrName() { return "triton_gpu.num-warps"; } static int getNumWarps(ModuleOp mod) { - Attribute numWarps = mod->getDiscardableAttr("triton_gpu.num-warps"); - if(!numWarps) + if(!mod->hasAttr("triton_gpu.num-warps")) llvm::report_fatal_error( "TritonGPU module should contain a triton_gpu.num-warps attribute"); - return numWarps.cast().getInt(); + return mod->getAttr("triton_gpu.num-warps").cast().getInt(); } + static int getNumCTAs(ModuleOp mod) { + if(!mod->hasAttr("triton_gpu.num-ctas")) + llvm::report_fatal_error( + "TritonGPU module should contain a triton_gpu.num-ctas attribute"); + return mod->getAttr("triton_gpu.num-ctas").cast().getInt(); + } + static int getComputeCapability(ModuleOp mod) { + if(!mod->hasAttr("triton_gpu.compute-capability")) + llvm::report_fatal_error( + "TritonGPU module should contain a triton_gpu.compute-capability attribute"); + return mod->getAttrOfType("triton_gpu.compute-capability").getInt(); + } + void registerTypes(); static std::string getThreadsPerWarpAttrName() { return "triton_gpu.threads-per-warp"; } + static int getThreadsPerWarp(ModuleOp mod) { Attribute threadsPerWarp = mod->getDiscardableAttr("triton_gpu.threads-per-warp"); if(!threadsPerWarp) { @@ -38,6 +52,13 @@ def TritonGPU_Dialect : Dialect { } return threadsPerWarp.cast().getInt(); } + static int getSharedSize(ModuleOp mod) { + Attribute sharedAttr = mod->getDiscardableAttr("triton_gpu.shared"); + if(!sharedAttr) { + return 0; + } + return sharedAttr.cast().getInt(); + } }]; diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index b774fd54684d..d91fa076479b 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -2,6 +2,7 @@ #define TRITONGPU_OPS include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td" include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td" include "mlir/Dialect/Arith/IR/ArithBase.td" include "triton/Dialect/Triton/IR/TritonTypes.td" @@ -46,6 +47,20 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait"> { }]; } +def TTG_AsyncBulkWaitOp : TTG_Op<"async_bulk_wait"> { + let summary = "async bulk wait"; + + let arguments = (ins I32Attr:$num); + + let assemblyFormat = "attr-dict"; + + let extraClassDeclaration = [{ + static bool isSupported(int computeCapability) { + return computeCapability >= 90; + } + }]; +} + def TTG_AsyncCommitGroupOp : TTG_Op<"async_commit_group"> { let summary = "async commit group"; @@ -58,6 +73,18 @@ def TTG_AsyncCommitGroupOp : TTG_Op<"async_commit_group"> { }]; } +def TTG_AsyncBulkCommitGroupOp : TTG_Op<"async_bulk_commit_group"> { + let summary = "async bulk commit group"; + + let assemblyFormat = "attr-dict"; + + let extraClassDeclaration = [{ + static bool isSupported(int computeCapability) { + return computeCapability >= 90; + } + }]; +} + // Port Arith_CmpIOp & Arith_CmpFOp & Std_SelectOp to TritonGPU. // This is needed because these ops don't @@ -106,6 +133,98 @@ def TTG_SelectOp : TTG_Op<"select", [Pure, Elementwise, let results = (outs TT_Type:$result); } +// TODO[goostavz]: extract a base class for InsertSlice & InsertSliceAsync once the op definition is verified +def TTG_InsertSliceOp : TTG_Op<"insert_slice", + [AttrSizedOperandSegments, + ResultsAreSharedEncoding, + MemoryEffects<[MemRead, MemWrite]>, + TypesMatchWith<"infer mask type from src type", + "src", "mask", "getI1SameShape($_self)", + "($_op.getOperands().size() <= 3) || std::equal_to<>()">, + TypesMatchWith<"infer other type from src type", + "src", "other", "getPointeeType($_self)", + "($_op.getOperands().size() <= 4) || std::equal_to<>()">]> { + let summary = "insert slice"; + + let description = [{ + This operation inserts a tensor `$src` into another tensor `$dst` as specified by the operation’s + `$index` argument and `$axis` attribute. + + It returns a copy of `$dst` with the proper slice updated with the value of `$src`. + + When converting from `tt.load` to `triton_gpu.insert_slice`, the `$evict`, `$cache`, and `$isVolatile` fields + might be ignored on certain hardware. For example, on NVIDIA GPUs, the cache policy is determined by the backend, + and `$evict` and `$isVolatile` are ignored because they apply to L1 cache only. + + The insert_slice operation supports the following arguments: + + * src: the tensor that is inserted. + * dst: the tensor into which the `$src` tensor is inserted. + * index: the index of the `$src` tensor at the given `$axis` from which the `$dst` tensor is inserted into + * mask: optional tensor-rank number of boolean masks which specify which + elements of the `$src` tensor are inserted into the `$dst` tensor. + * other: optional tensor-rank number of other tensors which specify what + values are inserted into the `$dst` tensor if the corresponding + element of the `$mask` tensor is false. + + ttgpu.load_tile_async depracate + triton_gpu.insert_slice might be further lowered into triton_gpu_async for different hardware implementations + + like tt.load, ttgpu.insert_slice/insert_slice_async has two modes up to the type of src + mode 1: ptr/src is a tensor of pointers + mode 2: ptr/src is a tensor pointer + + Some typical lowering paths are: + in case the load is pipelined by the pipeline pass( load is inside kBlock loop, which means "pipeline pass): + Load from global + store to shared : tt.load(mode 1) -(tt->ttgpu+Coalesce)-> tt.load(mode 1) -(Pipeline)-> ttgpu.insert_slice(mode 1) + Non-bulk cp.async : tt.load(mode 1) -(tt->ttgpu+Coalesce)-> tt.load(mode 1) -(Pipeline)-> ttgpu.insert_slice(mode 1) -(MaterializeLoad)> ttgpu.insert_slice_async(mode 1) + ttgpu.await-> llvm + TMA load : tt.load(mode 2) -(tt->ttgpu+Coalesce)-> tt.load(mode 2) -(Pipeline)-> ttgpu.insert_slice(mode 2) -(MaterializeLoad)> ttgpu.insert_slice_async_v2(mode 2) + ttgpu.await-> llvm + + otherwise: + Load from global + store to shared : tt.load(mode 1) -(tt->ttgpu+Coalesce)-> tt.load(mode 1) + Non-bulk cp.async : tt.load(mode 1) -(tt->ttgpu+Coalesce)-> tt.load(mode 1) -> ... -(MaterializeLoad)-> ttgpu.insert_slice_async(mode 1) + ttgpu.await -> llvm + TMA load : tt.load(mode 2) -(tt->ttgpu+Coalesce)-> tt.load(mode 2) -> ... -(MaterializeLoad)-> ttgpu.insert_slice_async(mode 2) + ttgpu.await -> llvm + + Example: + + ``` + %1 = triton_gpu.alloc_tensor : tensor<2x32xf32> + %2 = triton_gpu.insert_slice %0, %1, %index { axis = 0 } : tensor<32x!tt.ptr, #AL> -> tensor<2x32xf32, #A> + ``` + }]; + + let arguments = (ins TT_PtrLike:$src, TT_Tensor:$dst, I32:$index, + Optional:$mask, Optional:$other, + TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict, + BoolAttr:$isVolatile, I32Attr:$axis); + + let builders = [ + OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$index, + "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>, + OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$index, "Value":$mask, + "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>, + OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$index, + "Value":$mask, "Value":$other, + "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>, + ]; + + let results = (outs TT_Tensor:$result); + + let extraClassDeclaration = [{ + static DenseSet getEligibleLoadByteWidth(int computeCapability) { + DenseSet validLoadBytes; + if (computeCapability >= 80) { + validLoadBytes = {4, 8, 16}; + } + return validLoadBytes; + } + }]; + + let hasCustomAssemblyFormat = 1; +} def TTG_ExtractSliceOp : TTG_Op<"extract_slice", @@ -173,7 +292,8 @@ def TTG_ExtractSliceOp : TTG_Op<"extract_slice", def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async", [AttrSizedOperandSegments, ResultsAreSharedEncoding, - MemoryEffects<[MemRead]>, + // TODO: Check if MemWrite will degrade performance of non-warp-specialized kernel + MemoryEffects<[MemRead, MemWrite]>, TypesMatchWith<"infer mask type from src type", "src", "mask", "getI1SameShape($_self)", "($_op.getOperands().size() <= 3) || std::equal_to<>()">, @@ -219,7 +339,7 @@ def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async", ``` }]; - let arguments = (ins TT_PtrTensor:$src, TT_Tensor:$dst, I32:$index, + let arguments = (ins TT_PtrLike:$src, TT_Tensor:$dst, I32:$index, Optional:$mask, Optional:$other, TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict, BoolAttr:$isVolatile, I32Attr:$axis); diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td new file mode 100644 index 000000000000..aa831e7c4c1e --- /dev/null +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td @@ -0,0 +1,26 @@ +#ifndef TRITONGPU_TYPES +#define TRITONGPU_TYPES + +include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" +include "mlir/IR/AttrTypeBase.td" + +class TTG_TypeDef traits = []> + : TypeDef { + let mnemonic = _mnemonic; +} + +def TTG_TokenType : TTG_TypeDef<"Token", "token"> { + let parameters = (ins "int32_t":$type); + + let builders = [ + TypeBuilder<(ins "unsigned":$type), [{ + return $_get($_ctxt, type); + }]> + ]; + + let hasCustomAssemblyFormat = 1; + + let skipDefaultBuilders = 1; +} + +#endif diff --git a/include/triton/Dialect/TritonGPU/IR/Types.h b/include/triton/Dialect/TritonGPU/IR/Types.h new file mode 100644 index 000000000000..edf37fef606d --- /dev/null +++ b/include/triton/Dialect/TritonGPU/IR/Types.h @@ -0,0 +1,10 @@ +#ifndef TRITONGPU_IR_TYPES_H_ +#define TRITONGPU_IR_TYPES_H_ + +#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/Types.h" + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/TritonGPU/IR/Types.h.inc" + +#endif // TRITON_IR_TYPES_H_ diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.h b/include/triton/Dialect/TritonGPU/Transforms/Passes.h index 20d5bd3f14af..89b3d818c072 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.h @@ -2,9 +2,14 @@ #define TRITON_DIALECT_TRITONGPU_TRANSFORMS_PASSES_H_ #include "mlir/Pass/Pass.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" namespace mlir { -std::unique_ptr createTritonGPUPipelinePass(int numStages = 2); + +std::unique_ptr createTritonGPUPipelinePass(int numStages = 3, + int numWarps = 4, + int numCTAs = 1, + int computeCapability = 80); std::unique_ptr createTritonGPUStreamPipelinePass(); @@ -27,6 +32,8 @@ std::unique_ptr createTritonGPUVerifier(); std::unique_ptr createTritonGPUOptimizeDotOperandsPass(); +std::unique_ptr createTritonGPUOptimizeEpiloguePass(); + /// Generate the code for registering passes. #define GEN_PASS_REGISTRATION #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td index 3a3e6775fe50..81d1b1de7ba9 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td @@ -14,13 +14,23 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> { let constructor = "mlir::createTritonGPUPipelinePass()"; let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", "mlir::scf::SCFDialect", "mlir::arith::ArithDialect"]; let options = [ Option<"numStages", "num-stages", - "int32_t", /*default*/"2", - "number of pipeline stages"> + "int32_t", /*default*/"3", + "number of pipeline stages">, + Option<"numWarps", "num-warps", + "int32_t", /*default*/"4", + "number of warps per block">, + Option<"numCTAs", "num-ctas", + "int32_t", /*default*/"1", + "number of CTAs per CGA">, + Option<"computeCapability", "compute-capability", + "int32_t", /*default*/"80", + "device compute capability"> ]; } @@ -65,6 +75,7 @@ def TritonGPUAccelerateMatmul : Pass<"tritongpu-accelerate-matmul", "mlir::Modul let constructor = "mlir::createTritonGPUAccelerateMatmulPass()"; let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", "mlir::triton::TritonDialect"]; let options = [ @@ -85,6 +96,7 @@ def TritonGPUOptimizeDotOperands : Pass<"tritongpu-optimize-dot-operands", "mlir let constructor = "mlir::createTritonGPUOptimizeDotOperandsPass()"; let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", "mlir::triton::TritonDialect"]; } @@ -111,6 +123,20 @@ def TritonGPURemoveLayoutConversions : Pass<"tritongpu-remove-layout-conversions let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", "mlir::triton::TritonDialect"]; + +} + +def TritonGPUOptimizeEpilogue : Pass<"tritongpu-optimize-epilogue", "mlir::ModuleOp"> { + let summary = "Optimize epilogue: (1) Store accumulators directly without going thorough SMEM in epilogue."; + + let description = [{ + }]; + + let constructor = "mlir::createTritonGPUOptimizeEpiloguePass()"; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; + } def TritonGPUReorderInstructions: Pass<"tritongpu-reorder-instructions", "mlir::ModuleOp"> { diff --git a/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h b/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h index 8730c8b9a2cc..fbfa235fc6bb 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h +++ b/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h @@ -13,15 +13,17 @@ namespace mlir { class TritonGPUTypeConverter : public TypeConverter { public: - TritonGPUTypeConverter(MLIRContext *context, int numWarps, - int threadsPerWarp); + TritonGPUTypeConverter(MLIRContext *context, int numWarps, int threadsPerWarp, + int numCTAs); int getNumWarps() const { return numWarps; } int getThreadsPerWarp() const { return threadsPerWarp; } + int getNumCTAs() const { return numCTAs; } private: MLIRContext *context; int numWarps; int threadsPerWarp; + int numCTAs; }; class TritonGPUConversionTarget : public ConversionTarget { diff --git a/include/triton/Dialect/TritonGPU/Transforms/Utility.h b/include/triton/Dialect/TritonGPU/Transforms/Utility.h index 92732d5797f8..6c0193182336 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Utility.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Utility.h @@ -10,33 +10,143 @@ namespace mlir { -LogicalResult fixupLoops(ModuleOp mod); +namespace triton { +class LoadOp; +class StoreOp; +class FuncOp; +namespace gpu { +class SharedEncodingAttr; +} +} // namespace triton -// TODO: Interface -LogicalResult invertEncoding(Attribute targetEncoding, Operation *op, - Attribute &ret); +SmallVector mmaVersionToInstrShape(int version, + const ArrayRef &shape, + RankedTensorType type); -bool isExpensiveLoadOrStore(Operation *op, Attribute &targetEncoding); +/// Returns true if the Load is for TMA +bool isLoadFromTensorPtr(triton::LoadOp op); -bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding); +/// Returns true if the store is for TMA +bool isStoreToTensorPtr(triton::StoreOp op); -// skipInit is True when we only consider the operands of the initOp but -// not the initOp itself. -int simulateBackwardRematerialization( - Operation *initOp, SetVector &processed, - SetVector &layout, llvm::MapVector &toConvert, - Attribute targetEncoding); +/// Return the first consumer of v +Operation *getFirstUser(Value v); + +/// Return the proper SharedEncodingAttr according to shape/order +triton::gpu::SharedEncodingAttr getSharedEncoding(RankedTensorType tensorTy); + +/* Dump Triton IR in graphviz dot format. + * + * You can override `onValue` and `onOperation` in a subclass to mark + * specific Values and Operations. The below subclass + * GraphLayoutMarker is an example. + * + * Default NodeInfo for Value nodes: + * {{"shape": "box"}, + * {"style", "filled"}, + * {"fillcolor", "white"}, + * {"label", shapeStr}} + * + * Default NodeInfo for Operation nodes: + * {{"shape": "ellipse"}, + * {"style", "filled"}, + * {"fillcolor", "white"}, + * {"label", operationName}} + * + * If the key "label" is not set by `onValue` or `onOperation`, default labels + * will be generated. For Value node, the default label is the shape string and + * for Operation node, it is the operation name. + * + * Reference: + * https://graphviz.org/doc/info/shapes.html + * https://graphviz.org/doc/info/colors.html + * + * Usage: + * C++: GraphDumper().dumpToFile(func, "func.dot"); + * Shell: dot -Tjpg func.dot -o func.jpg + */ +class GraphDumper { +public: + using NodeInfo = std::map; + + // Override this function to mark specific Values + virtual NodeInfo onValue(Value value) const; + // Override this function to mark specific Operations + virtual NodeInfo onOperation(Operation *op) const; + + std::string dump(triton::FuncOp func) const; + void dumpToFile(triton::FuncOp func, const std::string &filename) const; + +protected: + std::string getShapeStr(const Type &type) const; + + std::string getUniqueId(Value value) const; + std::string getUniqueId(Operation *op) const; + + std::string emitNode(const std::string &id, const NodeInfo style) const; + std::string emitEdge(const std::string &srcId, + const std::string &destId) const; + + std::string emitValueNode(Value value) const; + std::string emitOperationNode(Operation *op) const; +}; + +/* A subclass of GraphDumper that marks different layout kinds in different + * colors.*/ +class GraphLayoutMarker : public GraphDumper { +public: + NodeInfo onValue(Value value) const override; + +protected: + std::string getColor(const Type &type) const; +}; + +// Infers the encoding of the result of op given the source encoding. +std::optional inferDstEncoding(Operation *op, Attribute encoding); + +// Infers the encoding of the source of op given the result encoding. +std::optional inferSrcEncoding(Operation *op, Attribute encoding); + +bool isExpensiveLoadOrStore(Operation *op); + +bool canFoldIntoConversion(Operation *op, Attribute targetEncoding); Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op, IRMapping &mapping); -void rematerializeConversionChain( - const llvm::MapVector &toConvert, - mlir::PatternRewriter &rewriter, SetVector &processed, - IRMapping &mapping); +// Get backward slice of tensor values starting from the root node along with +// encoding propagation. +LogicalResult getConvertBackwardSlice( + Value root, SetVector &slice, Attribute rootEncoding, + DenseMap &layout, + std::function stopPropagation = nullptr); + +// Populate pattern to remove dead cycles in ForOp. +void populateForOpDeadArgumentElimination(RewritePatternSet &patterns); + +// Convert an \param index to a multi-dim coordinate given \param shape and +// \param order. +SmallVector delinearize(OpBuilder &b, Location loc, Value linear, + ArrayRef shape, + ArrayRef order); + +SmallVector delinearize(OpBuilder &b, Location loc, unsigned linear, + ArrayRef shape); + +SmallVector delinearize(OpBuilder &b, Location loc, Value linear, + ArrayRef shape); +Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, + ArrayRef shape, ArrayRef order); + +Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, + ArrayRef shape); -LogicalResult canMoveOutOfLoop(BlockArgument arg, - SmallVector &cvts); +// Returns null if the op is not inside a agent region (warp specialization +// mode). Note that there should be at most one agent id attached to the +// operation. +std::optional getWSAgentId(Operation *op); +std::optional getWSRoleId(Operation *op); +void setRoleId(Operation *op, int roleId); } // namespace mlir diff --git a/include/triton/Dialect/TritonNvidiaGPU/CMakeLists.txt b/include/triton/Dialect/TritonNvidiaGPU/CMakeLists.txt new file mode 100644 index 000000000000..9f57627c321f --- /dev/null +++ b/include/triton/Dialect/TritonNvidiaGPU/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt b/include/triton/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt new file mode 100644 index 000000000000..aba08ab137d3 --- /dev/null +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt @@ -0,0 +1,15 @@ +set(LLVM_TARGET_DEFINITIONS TritonNvidiaGPUOps.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=triton_nvidia_gpu) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=triton_nvidia_gpu) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=triton_nvidia_gpu) +mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=triton_nvidia_gpu) +add_public_tablegen_target(TritonNvidiaGPUTableGen) + +set(LLVM_TARGET_DEFINITIONS TritonNvidiaGPUAttrDefs.td) +mlir_tablegen(TritonNvidiaGPUAttrDefs.h.inc -gen-attrdef-decls) +mlir_tablegen(TritonNvidiaGPUAttrDefs.cpp.inc -gen-attrdef-defs) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(TritonNvidiaGPUAttrDefsIncGen) diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h b/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h new file mode 100644 index 000000000000..680af81ac41c --- /dev/null +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#ifndef TRITON_DIALECT_TRITONNVIDIAGPU_IR_DIALECT_H_ +#define TRITON_DIALECT_TRITONNVIDIAGPU_IR_DIALECT_H_ + +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" + +// TritonNvidiaGPU depends on Triton +#include "triton/Dialect/NVGPU/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Traits.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h.inc" +#include "triton/Dialect/TritonNvidiaGPU/IR/Traits.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Types.h" + +#define GET_ATTRDEF_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.h.inc" + +#define GET_OP_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/IR/Ops.h.inc" + +#endif // TRITON_DIALECT_TRITONNVIDIAGPU_IR_DIALECT_H_ diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/Traits.h b/include/triton/Dialect/TritonNvidiaGPU/IR/Traits.h new file mode 100644 index 000000000000..1db22527716b --- /dev/null +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/Traits.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#ifndef TRITON_NVIDIA_GPU_IR_TRAITS_H_ +#define TRITON_NVIDIA_GPU_IR_TRAITS_H_ + +#include "mlir/IR/OpDefinition.h" + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Support/LogicalResult.h" + +namespace mlir { +namespace OpTrait { + +// These functions are out-of-line implementations of the methods in the +// corresponding trait classes. This avoids them being template +// instantiated/duplicated. +namespace impl { +LogicalResult verifySource1IsSharedEncoding(Operation *op); +} // namespace impl + +template +class Source1IsSharedEncoding + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySource1IsSharedEncoding(op); + } +}; +} // namespace OpTrait +} // namespace mlir + +#endif diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td new file mode 100644 index 000000000000..936535bb039a --- /dev/null +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td @@ -0,0 +1,29 @@ +// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files +// (the "Software"), to deal in the Software without restriction, +// including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, +// and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#ifndef TRITONNVIDIAGPU_ATTRDEFS +#define TRITONNVIDIAGPU_ATTRDEFS + +include "mlir/IR/AttrTypeBase.td" +include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" + +#endif diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td new file mode 100644 index 000000000000..08ff21f523f0 --- /dev/null +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td @@ -0,0 +1,82 @@ +// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files +// (the "Software"), to deal in the Software without restriction, +// including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, +// and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#ifndef TRITONNVIDIAGPU_DIALECT +#define TRITONNVIDIAGPU_DIALECT + +include "mlir/IR/OpBase.td" + +def TritonNvidiaGPU_Dialect : Dialect { + let name = "triton_nvidia_gpu"; + + let cppNamespace = "::mlir::triton::nvidia_gpu"; + + let hasOperationAttrVerify = 1; + + let description = [{ + Triton Nvidia GPU Dialect. + }]; + + let dependentDialects = [ + "triton::TritonDialect", + "triton::gpu::TritonGPUDialect", + "mlir::triton::nvgpu::NVGPUDialect", + "mlir::gpu::GPUDialect", + "tensor::TensorDialect", + ]; + + let extraClassDeclaration = [{ + static std::string getNumWarpsAttrName() { return "triton_gpu.num-warps"; } + static int getNumWarps(ModuleOp mod) { + if(!mod->hasAttr("triton_gpu.num-warps")) + llvm::report_fatal_error( + "TritonGPU module should contain a triton_gpu.num-warps attribute"); + return mod->getAttr("triton_gpu.num-warps").cast().getInt(); + } + static int getNumCTAs(ModuleOp mod) { + if(!mod->hasAttr("triton_gpu.num-ctas")) + llvm::report_fatal_error( + "TritonGPU module should contain a triton_gpu.num-ctas attribute"); + return mod->getAttr("triton_gpu.num-ctas").cast().getInt(); + } + static int getComputeCapability(ModuleOp mod) { + if(!mod->hasAttr("triton_gpu.compute-capability")) + llvm::report_fatal_error( + "TritonGPU module should contain a triton_gpu.compute-capability attribute"); + return mod->getAttrOfType("triton_gpu.compute-capability").getInt(); + } + void registerTypes(); + + // Warp specialization related: + static std::string getWSSupportedAttrName() { return "triton_gpu.enable-warp-specialization"; } + static int getWSSupportedAttr(ModuleOp mod) { + auto name = getWSSupportedAttrName(); + if (!mod->hasAttr(name)) return 0; + return mod->getAttrOfType(name).getInt(); + } + }]; + + let useDefaultTypePrinterParser = 1; +} + +include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td" + +#endif diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td new file mode 100644 index 000000000000..cdf1146900c8 --- /dev/null +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td @@ -0,0 +1,385 @@ +// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files +// (the "Software"), to deal in the Software without restriction, +// including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, +// and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#ifndef TRITONNVIDIAGPU_OPS +#define TRITONNVIDIAGPU_OPS + +include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td" +include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td" +include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td" +include "mlir/Dialect/Arith/IR/ArithBase.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" +include "triton/Dialect/Triton/IR/TritonAttrDefs.td" +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure +include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType +include "mlir/Interfaces/DestinationStyleOpInterface.td" +include "mlir/Interfaces/ViewLikeInterface.td" + +def Source1IsSharedEncoding: NativeOpTrait<"Source1IsSharedEncoding">; + +def ResultsAreSharedEncoding: NativeOpTrait<"ResultsAreSharedEncoding">; + +class TTNG_Op traits = []> : + Op; + +// -------------------------------------------------------------------------------------------------- +// MBarrier related Ops: +// 1, These mbarrier commands are currently not needed, and not taken into consideration: +// (1), mbarrier.expect_tx +// (2), mbarrier.arrive_drop +// (3), mbarrier.complete_tx +// (4), mbarrier.inval +// +// 2, The mbarriers is supported to be created in vector, and accessed in seperate via tensor.extract. +// The mbarriers created in vector will have counters initialized in the same configuration. A +// typical example to demonstrate this: +// +// %1 = triton_nvidia_gpu.alloc_mbarrier { count = 1 } : tensor<4x!tt.ptr> +// scf.for %iv = %lb to %ub step %step iter_args() -> () { +// %buffer_id = arith.remi %iv, %c4 : i32 +// %2 = triton_nvidia_gpu.extract_mbarrier %1[%buffer_id] : tensor<4xi64>, i32 -> !tt.ptr +// triton_nvidia_gpu.mbarrier_arrive %2 {expectTx = 2048} : !tt.ptr -> () +// } +// ... +// scf.for %iv = %lb to %ub step %step iter_args() -> () { +// %buffer_id = arith.remi %iv, %c4 : i32 +// %2 = triton_nvidia_gpu.extract_mbarrier %1[%buffer_id] : tensor<4xi64>, i32 -> !tt.ptr +// triton_nvidia_gpu.mbarrier_wait %2, %c0 : !tt.ptr, i1 -> () +// } + +def TTNG_AllocMBarrierOp : TTNG_Op<"alloc_mbarrier", [MemoryEffects<[MemAlloc]>]> { + let summary = "allocate a vector of mbarriers"; + + let description = [{ + Allocate and initialize a vector of mbarriers. The size of the vector is implied in the returned type. + Each mbarrier is initialized as: + 1, the current phase initialized to 0. + 2, the expected arrival count initialized to 'count'. + 3, the pending arrival count initialized to 'count'. + 4, the tx-count initialized to 0. + + Example: + + case a. when created in vector: + %1 = triton_nvidia_gpu.alloc_mbarrier { count = 1 } : tensor<4xi64> + + case b. when created in scalar: + %1 = triton_nvidia_gpu.alloc_mbarrier { count = 1 } : !tt.ptr + + }]; + + let assemblyFormat = [{attr-dict `:` type($result)}]; + + let arguments = (ins I32Attr:$count); + + let results = (outs AnyTypeOf<[TT_Ptr, I64Tensor]>:$result); +} + +def TTNG_ExtractMBarrierOp : TTNG_Op<"extract_mbarrier", [Pure]> { + let summary = "extract a mbarrier from a vector of mbarriers"; + + let description = [{ + Extract a mbarrier from a vector of mbarriers + + Example: + + %1 = triton_nvidia_gpu.extract_mbarrier %mbarriers[%idx] : tensor<4xi64>, index -> !tt.ptr + + }]; + + let assemblyFormat = "$tensor `[` $index `]` attr-dict `:` type($tensor) `,` type($index) `->` type($result)"; + + let arguments = (ins I64Tensor:$tensor, I32:$index); + + let results = (outs TT_Ptr:$result); +} + +def TTNG_MBarrierWaitOp : TTNG_Op<"mbarrier_wait", [MemoryEffects<[MemRead, MemWrite]>]> { + let summary = "mbarrier wait"; + + let description = [{ + This operation defining the waiting action for a mbarrier. + The subsequent operations should not execute until this operation completes waiting. + + Example: + + triton_nvidia_gpu.mbarrier_wait %0, %1 : !tt.ptr + + }]; + + let arguments = (ins TT_Ptr:$mbarrier, I1: $phase); + + let assemblyFormat = "$mbarrier `,` $phase attr-dict `:` type($mbarrier)"; +} + +def TTNG_MBarrierArriveOp : TTNG_Op<"mbarrier_arrive", [AttrSizedOperandSegments, + MemoryEffects<[MemWrite]>]> { + let summary = "mbarrier arrive"; + + let description = [{ + This operation defining the arriving action for a mbarrier. + txCount: + An optional attribute that set tx-count. This Op will be lowered into + mbarrier.arrive.expect_tx if the optional attribute exist. + trackAsyncOp: + If true, this op will be lowered into cp.async.mbarrier.arrive.noinc. + pred: + Only perform arrive action when pred is true. + remoteCtaId: + if set, perform an remote arrive action. + + Example: + + triton_nvidia_gpu.mbarrier_arrive %0 {trackAsyncOp = false} : !tt.ptr + + }]; + + let arguments = (ins TT_Ptr:$mbarrier, + Optional:$pred, + Optional:$remoteCtaId, + I1Attr: $trackAsyncOp, + DefaultValuedAttr: $txCount + ); + + let assemblyFormat = "operands attr-dict `:` type(operands)"; +} + +def TTNG_FenceAsyncSharedOp : TTNG_Op<"fence_async_shared"> { + let arguments = (ins BoolAttr:$bCluster); + + let summary = "fence proxy async"; + + let assemblyFormat = "attr-dict"; + + let extraClassDeclaration = [{ + static bool isSupported(int computeCapability) { + return computeCapability >= 90; + } + }]; +} + +// TODO[goostavz]: ThreadId & ClusterCTAId should not be exposed to +// ttgpu level. Remove them when async dialect is ready. +def TTNG_GetThreadIdOp : TTNG_Op<"get_thread_id", [Pure]> { + let description = [{ + Returns the one dimensional threadId. + }]; + + let results = (outs I32:$result); + let assemblyFormat = "attr-dict `:` type($result)"; +} + +def TTNG_GetClusterCTAIdOp : TTNG_Op<"get_cluster_cta_id", [Pure]> { + let description = [{ + Returns the one dimensional cluster_cta_id. + }]; + + let results = (outs I32:$result); + let assemblyFormat = "attr-dict `:` type($result)"; +} + +def TTNG_NamedBarrierArriveOp : TTNG_Op<"bar_arrive", []> { + let summary = "named barrier arrive"; + + let arguments = (ins I32:$bar, I32: $numThreads); + + let assemblyFormat = "$bar `,` $numThreads attr-dict `:` type(operands)"; +} + +def TTNG_NamedBarrierWaitOp : TTNG_Op<"bar_wait", []> { + let summary = "named barrier wait"; + + let arguments = (ins I32:$bar, I32: $numThreads); + + let assemblyFormat = "$bar `,` $numThreads attr-dict `:` type(operands)"; +} + +def TTNG_InsertSliceAsyncV2Op : TTNG_Op<"insert_slice_async_v2", + [AttrSizedOperandSegments, + ResultsAreSharedEncoding, + // TODO: Check if MemWrite will degrade performance of non-warp-specialized kernel + MemoryEffects<[MemRead, MemWrite]>]> { + + let arguments = (ins AnyTypeOf<[TT_Ptr, TT_PtrTensor]>:$src, TT_Tensor:$dst, + I32:$index, TT_Ptr:$mbar, + Optional>:$mask, Optional:$other, + TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict, + BoolAttr:$isVolatile, I32Attr:$axis); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "operands attr-dict `:` type(operands) `->` type($result)"; +} + +// TODO: the abstraction of barriers in ttgpu level is pending, will revisit later +// def TTNG_AwaitOp : TTNG_Op<"await", []> { +// let arguments = (ins TTNG_TokenType:$token); +// let assemblyFormat = "$token attr-dict `:` type($token)"; +// } + +def TTNG_ClusterArriveOp : TTNG_Op<"cluster_arrive", []> { + let arguments = (ins I1Attr:$relaxed); + let assemblyFormat = "attr-dict"; +} + +def TTNG_ClusterWaitOp : TTNG_Op<"cluster_wait", []> { + let assemblyFormat = "attr-dict"; +} + +// +// DotAsync Op +// +def TTNG_DotAsyncOp : TTNG_Op<"dot_async", [Pure, + DeclareOpInterfaceMethods, + TypesMatchWith<"result's type matches accumulator's type", + "d", "c", "$_self">]> { + let summary = "dot async"; + + let description = [{ + $d = matrix_multiply($a, $b) + $c + }]; + + let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32); + + let results = (outs TT_FpIntTensor:$d); + + let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)"; +} + +def TTNG_DotWaitOp : TTNG_Op<"dot_wait", []> { + let summary = "dot wait"; + + let description = [{ + This operation defining the waiting action for a async dot, MMAv3 .e.g. + The subsequent operations should not execute until this operation completes waiting. + }]; + + let arguments = (ins I32Attr:$pendings); + + let assemblyFormat = "attr-dict"; +} + +def TTNG_StoreAsyncOp : TTNG_Op<"store_async", + [MemoryEffects<[MemWrite]>]> { + let summary = "store asynchronous by a tensor pointer"; + let arguments = (ins TT_TensorPtr:$dst, TT_Tensor:$src, + DefaultValuedAttr:$cache); + let assemblyFormat = "operands attr-dict `:` type(operands)"; +} + +def TTNG_GetAgentIdOp : TTNG_Op<"get_agent_id", [Pure]> { + let results = (outs I32:$result); + + let builders = [OpBuilder<(ins)>]; + + let assemblyFormat = "attr-dict `:` type($result)"; +} + +// +// Token +// + +def TTNG_CreateTokenOp : TTNG_Op<"create_token"> { + let results = (outs TensorOf<[TTNG_TokenType]>:$result); + + let arguments = (ins I32Attr:$num); + + let builders = [OpBuilder<(ins "uint32_t":$num)>]; + + let assemblyFormat = "attr-dict `:` type($result)"; +} + +def TTNG_ProducerAcquireOp : TTNG_Op<"producer_acquire"> { + let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx); + + let assemblyFormat = "$token `,` $idx attr-dict `:` type(operands)"; +} + +def TTNG_ProducerCommitOp : TTNG_Op<"producer_commit"> { + let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx); + + let assemblyFormat = "$token `,` $idx attr-dict `:` type(operands)"; +} + +def TTNG_ConsumerWaitOp : TTNG_Op<"consumer_wait"> { + let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx); + + let assemblyFormat = "$token `,` $idx attr-dict `:` type(operands)"; +} + +def TTNG_ConsumerReleaseOp : TTNG_Op<"consumer_release"> { + let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx); + + let assemblyFormat = "$token `,` $idx attr-dict `:` type(operands)"; +} + +// +// Mutex +// + +def TTNG_GetMutexRoleIdOp : TTNG_Op<"get_mutex_role_id"> { + let results = (outs I32:$result); + + let arguments = (ins I32Attr:$num); + + let builders = [OpBuilder<(ins "uint32_t":$num)>]; + + let assemblyFormat = "attr-dict `:` type($result)"; +} + +def TTNG_CreateMutexOp : TTNG_Op<"create_mutex"> { + let results = (outs TTNG_MutexType:$result); + + let builders = [OpBuilder<(ins)>]; + + let assemblyFormat = "attr-dict `:` type($result)"; +} + +def TTNG_LockOp : TTNG_Op<"lock"> { + let arguments = (ins TTNG_MutexType:$mutex); + + let assemblyFormat = "$mutex attr-dict `:` type(operands)"; +} + +def TTNG_UnlockOp : TTNG_Op<"unlock"> { + let arguments = (ins TTNG_MutexType:$mutex); + + let assemblyFormat = "$mutex attr-dict `:` type(operands)"; +} + +def TTNG_RegAllocOp : TTNG_Op<"reg_alloc", []> { + let summary = "register allocation"; + + let arguments = (ins I32Attr: $regCount); + + let assemblyFormat = "$regCount attr-dict"; +} + +def TTNG_RegDeallocOp : TTNG_Op<"reg_dealloc", []> { + let summary = "register deallocation"; + + let arguments = (ins I32Attr: $regCount); + + let assemblyFormat = "$regCount attr-dict"; +} + +#endif diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td new file mode 100644 index 000000000000..d3126f8a044e --- /dev/null +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td @@ -0,0 +1,37 @@ +// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files +// (the "Software"), to deal in the Software without restriction, +// including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, +// and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#ifndef TRITONNVIDIAGPU_TYPES +#define TRITONNVIDIAGPU_TYPES + +include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td" +include "mlir/IR/AttrTypeBase.td" + +class TTNG_TypeDef + : TypeDef { + let mnemonic = _mnemonic; +} + +def TTNG_TokenType : TTNG_TypeDef<"Token", "token">; + +def TTNG_MutexType : TTNG_TypeDef<"Mutex", "mutex">; + +#endif diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/Types.h b/include/triton/Dialect/TritonNvidiaGPU/IR/Types.h new file mode 100644 index 000000000000..63c7a091afcd --- /dev/null +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/Types.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#ifndef TRITONNVIDIAGPU_IR_TYPES_H_ +#define TRITONNVIDIAGPU_IR_TYPES_H_ + +#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/Types.h" + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/IR/Types.h.inc" + +#endif // TRITON_IR_TYPES_H_ diff --git a/include/triton/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt b/include/triton/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt new file mode 100644 index 000000000000..d4b5c097f4fe --- /dev/null +++ b/include/triton/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name TritonNvidiaGPU) +add_public_tablegen_target(TritonNvidiaGPUTransformsIncGen) diff --git a/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h b/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h new file mode 100644 index 000000000000..9d3fd70890c7 --- /dev/null +++ b/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#ifndef TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_PASSES_H_ +#define TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_PASSES_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace triton { +namespace nvidia_gpu { + +// Used by Triton runtime +struct ClusterInfo { + ClusterInfo() : clusterDimX(1), clusterDimY(1), clusterDimZ(1) {} + int clusterDimX; + int clusterDimY; + int clusterDimZ; +}; + +} // namespace nvidia_gpu +} // namespace triton +} // namespace mlir + +namespace mlir { + +std::unique_ptr +createTritonNvidiaGPUMaterializeLoadStorePass(int numWarps = 4, + int computeCapability = 80); + +std::unique_ptr createTritonNvidiaGPUPlanCTAPass( + mlir::triton::nvidia_gpu::ClusterInfo *clusterInfo = nullptr); + +std::unique_ptr +createTritonNvidiaGPUWSFeasibilityCheckingPass(int computeCapability = 90); + +std::unique_ptr +createTritonNvidiaGPUWSDecomposingPass(int computeCapability = 90); + +std::unique_ptr +createTritonNvidiaGPUWSPipelinePass(int numStages = 3, int numWarps = 4, + int computeCapability = 90); + +std::unique_ptr +createTritonNvidiaGPUWSMutexPass(int computeCapability = 90); + +std::unique_ptr +createTritonNvidiaGPUWSMaterializationPass(int computeCapability = 90); + +std::unique_ptr +createTritonNvidiaGPUFenceInsertionPass(int computeCapability = 90); + +std::unique_ptr +createTritonGPURewriteTensorPointerPass(int computeCapability = 80); + +std::unique_ptr createTritonNvidiaGPUWSFixupMissingAttrs(); + +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" + +} // namespace mlir +#endif // TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_PASSES_H_ diff --git a/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td b/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td new file mode 100644 index 000000000000..d038c610f999 --- /dev/null +++ b/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td @@ -0,0 +1,246 @@ +// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files +// (the "Software"), to deal in the Software without restriction, +// including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, +// and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#ifndef TRITONNVIDIAGPU_PASSES +#define TRITONNVIDIAGPU_PASSES + +include "mlir/Pass/PassBase.td" + +def MaterializeLoadStore : Pass<"triton-nvidia-gpu-materialize-load-store", "mlir::ModuleOp"> { + let summary = "materialize load & store"; + + let description = [{ + This pass works after pipeline pass, converting the remaining tt.LoadOp taking + ptr as input into ttg.InsertSliceAsyncOp and emit proper barriers + }]; + + let constructor = "mlir::createTritonNvidiaGPUMaterializeLoadStorePass()"; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", + "mlir::scf::SCFDialect", + "mlir::arith::ArithDialect"]; + + let options = [ + Option<"numWarps", "num-warps", + "int32_t", /*default*/"4", + "number of warps per block">, + Option<"computeCapability", "compute-capability", + "int32_t", /*default*/"80", + "device compute capability"> + ]; +} + +def TritonGPUPlanCTAPass : Pass<"triton-nvidia-gpu-plan-cta", "mlir::ModuleOp"> { + let summary = "plan CTA"; + + let description = [{ + Plan CTAs in CGA + }]; + + let constructor = "mlir::createTritonNvidiaGPUPlanCTAPass()"; + + let dependentDialects = [ + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect" + ]; +} + +def TritonGPUWSFeasibilityChecking : Pass<"triton-nvidia-gpu-ws-feasibility-checking", "mlir::ModuleOp"> { + let summary = "Attach attr named TritonNvidiaGPUDialect::getWSSupportedAttrName() if auto WS supported"; + + let description = [{ + Since not every legal triton kernels can be auto WS, this pass does some (conservative) check + and attaches an attribute named TritonNvidiaGPUDialect::getWSSupportedAttrName() on + the input module op if the kernel is supported. + }]; + + let constructor = "mlir::createTritonNvidiaGPUWSFeasibilityCheckingPass()"; + + let dependentDialects = [ + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect" + ]; + + let options = [ + Option<"computeCapability", "compute-capability", + "int32_t", /*default*/"90", + "device compute capability"> + ]; +} + +def TritonGPUWSDecomposing : Pass<"triton-nvidia-gpu-ws-decomposing", "mlir::ModuleOp"> { + let summary = "Clustering on the ops according to their performance hotspots"; + + let description = [{ + Based on compute capability and heuristics, + this pass will identify some operations to be executed in different agents, + by marking them with async 'label'. E.g., + input: + %1 = tt,load %0 ... + %4 = tt.dot %1, %2, %3 ... + output: + %1 = tt,load %0 {async_agent = 0} ... + %4 = tt.dot %1, %2, %3 {async_agent = 1} : ... + }]; + + let constructor = "mlir::createTritonNvidiaGPUWSDecomposingPass()"; + + let dependentDialects = [ + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect" + ]; + + let options = [ + Option<"computeCapability", "compute-capability", + "int32_t", /*default*/"80", + "device compute capability"> + ]; +} + +def TritonGPUWSPipeline : Pass<"triton-nvidia-gpu-ws-pipeline", "mlir::ModuleOp"> { + let summary = "Warp specialization pipeline"; + + let description = [{ + }]; + + let constructor = "mlir::createTritonNvidiaGPUWSPipelinePass()"; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", + "mlir::scf::SCFDialect", + "mlir::arith::ArithDialect"]; + + let options = [ + Option<"numStages", "num-stages", + "int32_t", /*default*/"3", + "number of pipeline stages">, + Option<"numWarps", "num-warps", + "int32_t", /*default*/"12", + "number of warps per block">, + Option<"computeCapability", "compute-capability", + "int32_t", /*default*/"90", + "device compute capability"> + ]; +} + +def TritonGPUWSMutex : Pass<"triton-nvidia-gpu-ws-mutex", "mlir::ModuleOp"> { + let summary = "Warp specialization mutex syncronization"; + + let description = [{ + create mutex syncronization for persistent kernel. (as "2 Math WG" persistent kernel in cutlass) + For example, the agent containing dot and store will be divided into two sub-agent, + which execute dot and store alternately. i.e.: + sub-agent-0: dot | store | dot | ... | store + sub-agent-1: | dot | store | ... | dot | store + }]; + + let constructor = "mlir::createTritonNvidiaGPUWSMutexPass()"; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", + "mlir::scf::SCFDialect", + "mlir::arith::ArithDialect"]; + + let options = [ + Option<"computeCapability", "compute-capability", + "int32_t", /*default*/"80", + "device compute capability"> + ]; +} + +def TritonGPUWSMaterialization : Pass<"triton-nvidia-gpu-ws-materialization", "mlir::ModuleOp"> { + let summary = "Warp specialization materialization"; + + let description = [{ + }]; + + let constructor = "mlir::createTritonNvidiaGPUWSMaterializationPass()"; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"]; + + let options = [ + Option<"computeCapability", "compute-capability", + "int32_t", /*default*/"90", + "device compute capability"> + ]; +} + +def TritonGPUFenceInsertion : Pass<"triton-nvidia-gpu-fence-insertion", "mlir::ModuleOp"> { + let summary = "Insert fences across generic and async proxy"; + + let description = [{ + }]; + + let constructor = "mlir::createTritonNvidiaGPUFenceInsertionPass()"; + + let dependentDialects = [ + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect" + ]; + + let options = [ + Option<"computeCapability", "compute-capability", + "int32_t", /*default*/"90", + "device compute capability"> + ]; +} + +def TritonGPURewriteTensorPointer : Pass { + let summary = "Rewrite load/stores with tensor pointers into legacy load/stores"; + let description = [{ + This pass rewrites all load/store semantics initiated by a `tt.make_tensor_ptr` and `tt.advance` into legacy + semantics. After this pass, `tt.make_tensor_ptr` and `tt.advance` will disappear, and it generates logics to compute + the pointer/mask/other for each load/store. + }]; + + let constructor = "mlir::createTritonGPURewriteTensorPointerPass()"; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; + + let options = [ + Option<"computeCapability", "compute-capability", + "int32_t", /*default*/"80", + "device compute capability"> + ]; +} + +def TritonGPUWSFixupMissingAttrs : Pass<"triton-nvidia-gpu-ws-fixup-missing-attrs", "mlir::ModuleOp"> { + let summary = "Fixup missing WS related attributes"; + + let description = [{ + WS related attributes are attached to some key operations and are used when lowering to llvm. + However these attributes maybe be dropped in the following IR transform. This pass tries to + fixup the missing attributes. + }]; + + let constructor = "mlir::createTritonNvidiaGPUWSFixupMissingAttrs()"; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", + "mlir::scf::SCFDialect", + "mlir::arith::ArithDialect"]; +} + + +#endif diff --git a/include/triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h b/include/triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h new file mode 100644 index 000000000000..7090e937708f --- /dev/null +++ b/include/triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h @@ -0,0 +1,95 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#ifndef TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_UTILITY_H_ +#define TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_UTILITY_H_ + +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "llvm/ADT/MapVector.h" + +namespace mlir { + +// 0 is reserved for default sync. +// TODO: comprehensive mechanism to globally manage namedbarrier. +static int const nameBarrierIdBegin = 1; +static int nameBarrierIdEnd = 16; + +/// Helper functions for async agent +typedef int AgentId; +SmallVector getAgentIds(Operation *op); +bool hasAgentId(Operation *op, AgentId agentId); +void setAgentIds(Operation *op, ArrayRef agentIds); +SmallVector collectAgentIds(Operation *op); +void addAgentIds(Operation *op, ArrayRef agents); +SmallVector getMutexBarIds(Operation *op); +SmallVector getMutexNumThreads(Operation *op); + +class OpBuilderWithAgentIds : public OpBuilder { +public: + OpBuilderWithAgentIds(MLIRContext *context) : OpBuilder(context) {} + + void setAgentIdsFromArray(ArrayRef newAgentIds) { + agentIds = SmallVector(newAgentIds.begin(), newAgentIds.end()); + } + + void setAgentIdsFromOp(Operation *op) { + setAgentIdsFromArray(getAgentIds(op)); + } + + void setAgentIdsFromValueUsers(Value value) { + SetVector agentIdSet; + for (Operation *user : value.getUsers()) + for (AgentId agentId : getAgentIds(user)) + agentIdSet.insert(agentId); + setAgentIdsFromArray(agentIdSet.getArrayRef()); + } + + template + OpTy createWithAgentIds(Args &&...args) { + OpTy op = create(std::forward(args)...); + if (!agentIds.empty()) + setAgentIds(op, agentIds); + return op; + } + +private: + SmallVector agentIds; +}; + +/// Constant agent ids +constexpr AgentId kLoadAgentId = 0; +constexpr AgentId kDotAgentId = 1; + +bool isWSCandidateLoad(Operation *op); +bool isWSSupported(ModuleOp m, int computeCapability); + +LogicalResult getDependentValues(Value val, DenseSet &depSet, + const DenseSet &stopSet = {}); +LogicalResult getDependentValues(Operation *op, DenseSet &depSet, + const DenseSet &stopSet = {}); +DenseSet getDependentOps(DenseSet &depSet); + +} // namespace mlir + +#endif // TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_UTILITY_H_ diff --git a/include/triton/Target/AMDGCN/AMDGCNTranslation.h b/include/triton/Target/AMDGCN/AMDGCNTranslation.h new file mode 100644 index 000000000000..c20f1924db5c --- /dev/null +++ b/include/triton/Target/AMDGCN/AMDGCNTranslation.h @@ -0,0 +1,19 @@ +#ifndef TRITON_TARGET_AMDGCNTRANSLATION_H +#define TRITON_TARGET_AMDGCNTRANSLATION_H + +#include +#include + +namespace llvm { +class Module; +} // namespace llvm + +namespace triton { + +// Translate LLVM IR to AMDGCN code. +std::tuple +translateLLVMIRToAMDGCN(llvm::Module &module, std::string cc); + +} // namespace triton + +#endif diff --git a/include/triton/Target/LLVMIR/LLVMIRTranslation.h b/include/triton/Target/LLVMIR/LLVMIRTranslation.h index 05faee277b44..84adda2e46fd 100644 --- a/include/triton/Target/LLVMIR/LLVMIRTranslation.h +++ b/include/triton/Target/LLVMIR/LLVMIRTranslation.h @@ -1,5 +1,7 @@ #ifndef TRITON_TARGET_LLVM_IR_LLVM_IR_TRANSLATION_H #define TRITON_TARGET_LLVM_IR_LLVM_IR_TRANSLATION_H +#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h" +#include "triton/Target/PTX/TmaMetadata.h" #include "llvm/ADT/StringRef.h" #include #include @@ -26,12 +28,16 @@ void addExternalLibs(mlir::ModuleOp &module, std::unique_ptr translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module, int computeCapability, - bool isROCM); + mlir::triton::gpu::TMAMetadataTy &tmaInfos, + Target target, int wavesPerEU); // Translate mlir LLVM dialect to LLVMIR, return null if failed. std::unique_ptr translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module, - bool isROCM); + Target target); + +bool linkExternLib(llvm::Module &module, llvm::StringRef name, + llvm::StringRef path, Target target); } // namespace triton } // namespace mlir diff --git a/include/triton/Target/PTX/TmaMetadata.h b/include/triton/Target/PTX/TmaMetadata.h new file mode 100644 index 000000000000..f183f4e5b0fb --- /dev/null +++ b/include/triton/Target/PTX/TmaMetadata.h @@ -0,0 +1,107 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#ifndef TRITON_TARGET_PTX_TMAMETADATA_H +#define TRITON_TARGET_PTX_TMAMETADATA_H + +#include "python/triton/third_party/cuda/include/cuda.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Format.h" +#include "llvm/Support/FormatVariadic.h" +#include +#include +#include + +namespace mlir { +namespace triton { +namespace gpu { + +struct TMAInfo { + // -------------------------------------------- + // informations to be filled into CUtensorMaps + int tensorDataType; + + uint32_t tensorRank; + + // the argument indices for the runtime to get globalAddresses + size_t globalAddressArgIdx; + + // the argument indices for the runtime to get globalDims, -1 stands for this + // dim is padded + std::vector globalDimsArgIdx; + + // the argument indices for the runtime to get globalStrides, -1 stands for + // this dim is padded the runtime need to map the value to internal format + std::vector globalStridesArgIdx; + + std::vector boxDims; + + std::vector elementStrides; + + int interleave; + + int swizzle; + + int l2Promotion; + + int oobFill; + + // -------------------------------------------- + // the argument indices for the runtime to send the address of tma_desc to the + // binary + int TMADescArgIdx; + + template + void dump_vec(const std::vector &vec, llvm::StringRef info) const { + llvm::errs() << info << ": "; + for (const T &e : vec) + llvm::errs() << e << ","; + llvm::errs() << "\n"; + } + + void dump() const { + llvm::errs() << "TMA Info: ----------" + << "\n"; + llvm::errs() << "-- tensorDataType: " << tensorDataType + << ", tensorRank: " << tensorRank << "\n"; + llvm::errs() << "-- globalAddressArgIdx: " << globalAddressArgIdx << "\n"; + llvm::errs() << "-- TMADescArgIdx: " << TMADescArgIdx << "\n"; + dump_vec(globalDimsArgIdx, "-- globalDimsArgIdx"); + dump_vec(globalStridesArgIdx, "-- globalStridesArgIdx"); + dump_vec(boxDims, "-- boxDims"); + dump_vec(elementStrides, "-- elementStrides"); + llvm::errs() << "-- interleave: " << interleave << "\n"; + llvm::errs() << "-- swizzle: " << swizzle << "\n"; + llvm::errs() << "-- l2Promotion: " << l2Promotion << "\n"; + llvm::errs() << "-- oobFill: " << oobFill << "\n"; + }; +}; + +using TMAMetadataTy = std::vector; + +} // namespace gpu +} // namespace triton +} // namespace mlir + +#endif // TRITON_TARGET_PTX_TMAMETADATA_H diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index 7dd960070149..53e421ef218b 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -24,10 +24,16 @@ #include #include +#include #include namespace triton { +const std::set ENV_VARS = { + "ENABLE_MMA_V3", "TRITON_DISABLE_LINE_INFO", "DISABLE_FAST_REDUCTION", + "ENABLE_TMA", "MLIR_ENABLE_DUMP", "LLVM_IR_ENABLE_DUMP", + "AMDGCN_ENABLE_DUMP"}; + namespace tools { inline std::string getenv(const char *name) { @@ -39,6 +45,9 @@ inline std::string getenv(const char *name) { } inline bool getBoolEnv(const std::string &env) { + std::string msg = "Environment variable " + env + " is not recognized"; + assert(triton::ENV_VARS.find(env.c_str()) != triton::ENV_VARS.end() && + msg.c_str()); const char *s = std::getenv(env.c_str()); std::string str(s ? s : ""); std::transform(str.begin(), str.end(), str.begin(), diff --git a/lib/Analysis/Alias.cpp b/lib/Analysis/Alias.cpp index bc9fb637f6b9..db00c16d1600 100644 --- a/lib/Analysis/Alias.cpp +++ b/lib/Analysis/Alias.cpp @@ -2,6 +2,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "triton/Analysis/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" namespace mlir { @@ -27,17 +28,21 @@ void SharedMemoryAliasAnalysis::visitOperation( // These ops may allocate a new shared memory buffer. auto result = op->getResult(0); // XXX(Keren): the following ops are always aliasing for now - if (isa(op)) { + if (isa(op)) { // extract_slice %src // trans %src aliasInfo = AliasInfo(operands[0]->getValue()); pessimistic = false; - } else if (isa( - op)) { + } else if (isa(op)) { // insert_slice_async %src, %dst, %index // insert_slice %src into %dst[%offsets] aliasInfo = AliasInfo(operands[1]->getValue()); pessimistic = false; + } else if (isa(op)) { + aliasInfo = AliasInfo(operands[0]->getValue()); + pessimistic = false; } else if (triton::gpu::isSharedEncoding(result)) { aliasInfo.insert(result); pessimistic = false; diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 279b79cb8932..94d0b633704a 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -16,6 +16,7 @@ using ::mlir::triton::gpu::DotOperandEncodingAttr; using ::mlir::triton::gpu::getContigPerThread; using ::mlir::triton::gpu::getOrder; using ::mlir::triton::gpu::getShapePerCTA; +using ::mlir::triton::gpu::getShapePerCTATile; using ::mlir::triton::gpu::getSizePerThread; using ::mlir::triton::gpu::MfmaEncodingAttr; using ::mlir::triton::gpu::MmaEncodingAttr; @@ -58,11 +59,23 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, Attribute srcLayout = srcTy.getEncoding(); Attribute dstLayout = dstTy.getEncoding(); - // MmaToDotShortcut doesn't use shared mem - if (srcLayout.isa() && - dstLayout.isa()) - if (isMmaToDotShortcut(srcTy, dstTy)) - return {}; + if (shouldUseDistSmem(srcLayout, dstLayout)) { + // TODO: padding to avoid bank conflicts + return convertType(getShapePerCTA(srcTy)); + } + + // MmaToDotShortcut and MmaToMmaShortcut doesn't use shared mem + if (auto srcMmaLayout = srcLayout.dyn_cast()) { + if (dstLayout.isa()) { + if (isMmaToDotShortcut(srcTy, dstTy)) { + return {}; + } + } else if (auto dstMmaLayout = dstLayout.dyn_cast()) { + if (isMmaToMmaShortcut(srcTy, dstTy)) { + return {}; + } + } + } #ifdef USE_ROCM if (srcLayout.isa() && @@ -82,18 +95,18 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread; outVec = outOrd[0] == 0 ? 1 : dstContigPerThread; - auto srcShape = srcTy.getShape(); - auto dstShape = dstTy.getShape(); - auto srcShapePerCTA = getShapePerCTA(srcLayout, srcShape); - auto dstShapePerCTA = getShapePerCTA(dstLayout, dstShape); + auto srcShapePerCTA = getShapePerCTA(srcTy); + auto dstShapePerCTA = getShapePerCTA(dstTy); + auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape()); + auto dstShapePerCTATile = getShapePerCTATile(dstLayout, dstTy.getShape()); unsigned rank = dstTy.getRank(); SmallVector paddedRepShape(rank); unsigned pad = std::max(inVec, outVec); for (unsigned d = 0; d < rank; ++d) { paddedRepShape[d] = - std::max(std::min(srcTy.getShape()[d], srcShapePerCTA[d]), - std::min(dstTy.getShape()[d], dstShapePerCTA[d])); + std::max(std::min(srcShapePerCTA[d], srcShapePerCTATile[d]), + std::min(dstShapePerCTA[d], dstShapePerCTATile[d])); } if (rank == 1) return paddedRepShape; @@ -105,6 +118,12 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, return paddedRepShape; } +SmallVector +getScratchConfigForStoreAsync(triton::nvidia_gpu::StoreAsyncOp op) { + auto srcTy = op.getSrc().getType().cast(); + return convertType(getShapePerCTA(srcTy)); +} + // TODO: extend beyond scalars SmallVector getScratchConfigForAtomicRMW(triton::AtomicRMWOp op) { SmallVector smemShape; @@ -137,11 +156,68 @@ class AllocationAnalysis { using BufferT = Allocation::BufferT; /// Value -> Liveness Range + using IntervalT = Interval; /// Use MapVector to ensure determinism. - using BufferRangeMapT = llvm::MapVector>; + using BufferRangeMapT = llvm::MapVector; /// Nodes -> Nodes using GraphT = DenseMap>; + /// Set of Liveness Intervals + class LivenessR : public SmallVector { + public: + LivenessR() = default; + LivenessR(const LivenessR &) = default; + + /// Disjointness + bool isDisjoint() const { + if (size() < 2) + return false; + // sorted so the first OOB proves disjoint + auto maxId = (*this)[0].end(); + for (auto rng : *this) { + if (rng.start() <= maxId) { + // adjoining + maxId = std::max(maxId, rng.end()); + } else + return true; + } + return false; + } + + void sort() { + llvm::sort(*this, [](const auto &lhs, const auto &rhs) { + return lhs.start() <= rhs.start(); + }); + } + + bool addAdjacent(size_t id) { + bool isAdjacent = false; + for (auto &interval : *this) { + if (interval.adjacent(id)) { + isAdjacent = true; + interval = interval.merge(IntervalT(id)); + } + } + return isAdjacent; + } + + void add(size_t id) { + if (!addAdjacent(id)) + push_back(IntervalT(id)); + } + IntervalT unionize() const { + IntervalT res; + if (size()) { + res = front(); + for (auto &I : *this) + res = res.merge(I); + } + return res; + } + }; + + typedef function_ref LivenessF; + void run() { getValuesAndSizes(); resolveLiveness(); @@ -155,20 +231,45 @@ class AllocationAnalysis { // For example: %a = scf.if -> yield // %a must be allocated elsewhere by other operations. // FIXME(Keren): extract and insert are always alias for now - if (!maybeSharedAllocationOp(op) || maybeAliasOp(op)) { + if (!maybeSharedAllocationOp(op) || maybeAliasOp(op)) return; - } + // XXX(Keren): Why this hard-coded alignment? + size_t kAlignment = 8; for (Value result : op->getResults()) { if (triton::gpu::isSharedEncoding(result)) { // Bytes could be a different value once we support padding or other // allocation policies. auto tensorType = result.getType().dyn_cast(); - auto bytes = tensorType.getNumElements() * + auto shapePerCTA = triton::gpu::getShapePerCTA(tensorType); + auto bytes = product(shapePerCTA) * tensorType.getElementTypeBitWidth() / 8; - allocation->addBuffer(result, bytes); + + // XXX(Keren): magic numbers 256 and 1024 + // benzh@maybe alignment should be passed in. + // Software swizzling calculates phase based on offset, while hardware + // swizzling do that based on physical address. Thus only by setting the + // alignment to 1024 can ensure the correctness.  + if (bytes > 256) + kAlignment = 1024; + allocation->addBuffer(result, bytes, + kAlignment); } } + if (isa(op)) { + Value result = op->getResult(0); + if (!result.getType().isa()) + // In case AllocMBarrierOp is allocating scalar mbarriers + allocation->addBuffer(result, 8, + kAlignment); + } + } + + template + void maybeAddScratchBuffer(Operation *op, unsigned bytes, + unsigned alignment) { + if (bytes > 0) + allocation->addBuffer(op, bytes, alignment); } template @@ -179,14 +280,17 @@ class AllocationAnalysis { /// Initializes temporary shared memory for a given operation. void getScratchValueSize(Operation *op) { + const size_t scratchAlignment = 128; if (auto reduceOp = dyn_cast(op)) { ReduceOpHelper helper(reduceOp); unsigned bytes = helper.getScratchSizeInBytes(); - maybeAddScratchBuffer(op, bytes); + maybeAddScratchBuffer(op, bytes, + scratchAlignment); } else if (auto scanOp = dyn_cast(op)) { ScanLoweringHelper helper(scanOp); unsigned bytes = helper.getScratchSizeInBytes(); - maybeAddScratchBuffer(op, bytes); + maybeAddScratchBuffer(op, bytes, + scratchAlignment); } else if (auto cvtLayout = dyn_cast(op)) { auto srcTy = cvtLayout.getSrc().getType().cast(); auto dstTy = cvtLayout.getResult().getType().cast(); @@ -210,7 +314,20 @@ class AllocationAnalysis { srcTy.getElementType().isa() ? elems * kPtrBitWidth / 8 : elems * std::max(8, srcTy.getElementTypeBitWidth()) / 8; - maybeAddScratchBuffer(op, bytes); + maybeAddScratchBuffer(op, bytes, + scratchAlignment); + } else if (auto storeAsyncOp = + dyn_cast(op)) { + auto srcTy = storeAsyncOp.getSrc().getType().cast(); + auto srcEncoding = srcTy.getEncoding(); + if (!srcEncoding.isa()) { + return; + } + auto smemShape = getScratchConfigForStoreAsync(storeAsyncOp); + unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1, + std::multiplies{}); + auto bytes = elems * std::max(8, srcTy.getElementTypeBitWidth()) / 8; + maybeAddScratchBuffer(op, bytes, 1024); } else if (auto atomicRMWOp = dyn_cast(op)) { auto value = op->getOperand(0); // only scalar requires scratch memory @@ -227,7 +344,8 @@ class AllocationAnalysis { elemTy.isa() ? elems * kPtrBitWidth / 8 : elems * std::max(8, elemTy.getIntOrFloatBitWidth()) / 8; - maybeAddScratchBuffer(op, bytes); + maybeAddScratchBuffer(op, bytes, + scratchAlignment); } } else if (auto atomicCASOp = dyn_cast(op)) { auto value = op->getOperand(0); @@ -239,13 +357,15 @@ class AllocationAnalysis { auto bytes = elemTy.isa() ? elems * kPtrBitWidth / 8 : elems * elemTy.getIntOrFloatBitWidth() / 8; - maybeAddScratchBuffer(op, bytes); + maybeAddScratchBuffer(op, bytes, + scratchAlignment); } else if (auto callOp = dyn_cast(op)) { auto callable = callOp.resolveCallable(); auto funcOp = dyn_cast(callable); auto *funcAlloc = &(*funcAllocMap)[funcOp]; auto bytes = funcAlloc->getSharedMemorySize(); - maybeAddScratchBuffer(op, bytes); + maybeAddScratchBuffer(op, bytes, + scratchAlignment); } } @@ -289,33 +409,55 @@ class AllocationAnalysis { /// Computes the liveness range of the allocated value. /// Each buffer is allocated only once. - void resolveExplicitBufferLiveness( - function_ref(Value value)> getLiveness) { + void resolveExplicitBufferLiveness(LivenessF getLiveness) { for (auto valueBufferIter : allocation->valueBuffer) { auto value = valueBufferIter.first; auto *buffer = valueBufferIter.second; - bufferRange[buffer] = getLiveness(value); + auto ranges = getLiveness(value); + bufferRange[buffer] = ranges.unionize(); } } /// Extends the liveness range by unionizing the liveness range of the aliased /// values because each allocated buffer could be an alias of others, if block /// arguments are involved. - void resolveAliasBufferLiveness( - function_ref(Value value)> getLiveness) { + /// Only unionize adjacent live ranges to account for loop-carried buffers that + /// are mutually exclusive. + /// Example from stream pipeliner: + /// 3 %b0 = convert_layout %g0 -+ + /// 4 %fr = for (.., %arg0 = %b0) { | + /// 5 %gn = load %pc | + /// 6 %bc = convert_layout %arg0 -+ + /// 7 %v = add %bc, ... + /// 8 %bn = convert_layout %gn -+ + /// 9 %pn = addptr %pc, %cst | + /// 10 } | + /// 11 %be = convert_layout %fr#1 -+ + /// 12 %ve = add %be + void resolveAliasBufferLiveness(LivenessF getLiveness) { for (auto aliasBufferIter : allocation->aliasBuffer) { auto value = aliasBufferIter.first; auto buffers = aliasBufferIter.second; - auto range = getLiveness(value); + auto aranges = getLiveness(value); + bool disjoint = aranges.isDisjoint(); for (auto *buffer : buffers) { - auto minId = range.start(); - auto maxId = range.end(); + auto range = aranges[0]; if (bufferRange.count(buffer)) { - // Extend the allocated buffer's range - minId = std::min(minId, bufferRange[buffer].start()); - maxId = std::max(maxId, bufferRange[buffer].end()); + auto brange = bufferRange[buffer]; + if (disjoint) { + // find adjacent/intersecting + for (auto arange : aranges) { + if (arange.adjacent(brange) || + arange.intersects(brange)) + brange = arange.merge(brange); + } + range = brange; + } else { + // Extend the allocated buffer's range + range = range.merge(brange); + } } - bufferRange[buffer] = Interval(minId, maxId); + bufferRange[buffer] = range; } } } @@ -365,19 +507,22 @@ class AllocationAnalysis { // Analyze liveness of explicit buffers Liveness liveness(operation); auto getValueLivenessRange = [&](Value value) { + LivenessR ranges; + // Shared memory allocated by mbarrier cannot be reused + if (value.getDefiningOp() && + isa(value.getDefiningOp())) { + ranges.push_back(Interval(std::numeric_limits::min(), + std::numeric_limits::max())); + return ranges; + } + auto liveOperations = liveness.resolveLiveness(value); - auto minId = std::numeric_limits::max(); - auto maxId = std::numeric_limits::min(); std::for_each(liveOperations.begin(), liveOperations.end(), [&](Operation *liveOp) { - if (operationId[liveOp] < minId) { - minId = operationId[liveOp]; - } - if ((operationId[liveOp] + 1) > maxId) { - maxId = operationId[liveOp] + 1; - } + ranges.add(operationId[liveOp]); }); - return Interval(minId, maxId); + ranges.sort(); + return ranges; }; resolveExplicitBufferLiveness(getValueLivenessRange); @@ -432,9 +577,9 @@ class AllocationAnalysis { // If the available triple's range is less than a given buffer range, // we won't know if there has been an overlap without using graph coloring. // Start -> Liveness Range - using TripleMapT = std::multimap>; + using TripleMapT = std::multimap; TripleMapT tripleMap; - tripleMap.insert(std::make_pair(0, Interval())); + tripleMap.insert(std::make_pair(0, IntervalT())); SmallVector xBuffers = buffers; while (!xBuffers.empty()) { auto tripleIt = tripleMap.begin(); @@ -446,17 +591,22 @@ class AllocationAnalysis { auto xRange = bufferRange[buffer]; bool res = xRange.intersects(range); for (auto val : tripleMap) - res = res && !val.second.intersects(xRange); + res = res && + !val.second.intersects(xRange); // only one buffer intersect return res; }); if (bufferIt != xBuffers.end()) { auto buffer = *bufferIt; auto xSize = buffer->size; auto xRange = bufferRange.lookup(buffer); - bufferStart[buffer] = size; - tripleMap.insert( - {size + xSize, Interval{std::max(range.start(), xRange.start()), - std::min(range.end(), xRange.end())}}); + // TODO(Keren): A buffer's size shouldn't be determined here, have to + // clean it up + size_t alignment = buffer->alignment; + size_t alignSize = ((size + alignment - 1) / alignment) * alignment; + bufferStart[buffer] = alignSize; + tripleMap.insert({alignSize + xSize, + Interval{std::max(range.start(), xRange.start()), + std::min(range.end(), xRange.end())}}); // We could either insert (range.start, xRange.start) or (range.start, // xRange.end), both are correct and determine the potential buffer // offset, and the graph coloring algorithm will solve the interference, @@ -542,6 +692,19 @@ class AllocationAnalysis { } } + void dump() const { + llvm::outs() << "DUMP: " << "\n"; + for (auto bufferIter : bufferRange) { + + llvm::outs() << "ID= " << bufferIter.first->id << "\n"; + // llvm::outs() << " Kind= " << kind << "\n"; + llvm::outs() << " Size= " << bufferIter.first->size << "\n"; + llvm::outs() << " Offs= " << bufferIter.first->offset << "\n"; + llvm::outs() << " -> " << bufferIter.second.start() << "\n"; + llvm::outs() << " -> " << bufferIter.second.end() << "\n"; + } + } + private: Operation *operation; Allocation::FuncAllocMapT *funcAllocMap; diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index 1b6e0fb81656..dd782b1876eb 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -667,14 +667,10 @@ class SelectOpAxisInfoVisitor final : public AxisInfoVisitorImpl { AxisInfo getAxisInfo(OpTy op, ArrayRef *> operands) override { - auto resTy = op.getResult().getType().template dyn_cast(); - if (!resTy) - return AxisInfo(); - auto shape = resTy.getShape(); - auto rank = shape.size(); auto condConstancy = operands[0]->getValue().getConstancy(); auto lhsInfo = operands[1]->getValue(); auto rhsInfo = operands[2]->getValue(); + auto rank = lhsInfo.getRank(); AxisInfo::DimVectorT contiguity, divisibility, constancy; std::optional constantValue; diff --git a/lib/Analysis/CMakeLists.txt b/lib/Analysis/CMakeLists.txt index bb9a21ccf980..df1fe4066188 100644 --- a/lib/Analysis/CMakeLists.txt +++ b/lib/Analysis/CMakeLists.txt @@ -14,4 +14,5 @@ add_mlir_library(TritonAnalysis MLIRLLVMDialect TritonIR TritonGPUIR + TritonNvidiaGPUIR ) diff --git a/lib/Analysis/Membar.cpp b/lib/Analysis/Membar.cpp index c8db99d30490..90d37141349e 100644 --- a/lib/Analysis/Membar.cpp +++ b/lib/Analysis/Membar.cpp @@ -2,7 +2,12 @@ #include "triton/Analysis/Alias.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "../lib/Conversion/TritonGPUToLLVM/Utility.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "triton/Conversion/TritonGPUToLLVM/PTXAsmFormat.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" + #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include @@ -103,20 +108,37 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo, return; } - if (isa(op)) { + // TODO(Keren): Don't expose LLVM Dialect ops here + if (isa(op) || + (isa(op) && + (dyn_cast(op).getAsmString().find("bar.sync") != + std::string::npos))) { // If the current op is a barrier, we sync previous reads and writes blockInfo->sync(); return; } - if (isa(op) && - !isa(op->getNextNode())) { + if (isa(op) && + !isa(op->getNextNode()) && + !(isa(op->getNextNode()) && + (dyn_cast(op->getNextNode()) + .getAsmString() + .find("bar.sync") != std::string::npos))) { // If the current op is an async wait and the next op is not a barrier we // insert a barrier op and sync blockInfo->sync(); OpBuilder::InsertionGuard g(*builder); builder->setInsertionPointAfter(op); - builder->create(op->getLoc()); + if (auto optionalAgentId = getWSAgentId(op)) { + int agentId = *optionalAgentId, roleId = 0; + if (auto optionalRoleId = getWSRoleId(op)) + roleId = *optionalRoleId; + int barId = agentId + roleId + nameBarrierIdBegin; + assert(barId < nameBarrierIdEnd); + barSync(*builder, op, barId, 128); + } else { + builder->create(op->getLoc()); + } blockInfo->sync(); return; } @@ -169,12 +191,23 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo, if (blockInfo->isIntersected(curBlockInfo)) { OpBuilder::InsertionGuard g(*builder); builder->setInsertionPoint(op); - builder->create(op->getLoc()); + // TODO(Keren): Don't expose LLVM Dialect ops here + // TODO[shuhaoj]: Change hard code style of numThreads. Hide async_agent + // attr. Better way to determine barId (number of agents are limited). + if (auto optionalAgentId = getWSAgentId(op)) { + int agentId = *optionalAgentId, roleId = 0; + if (auto optionalRoleId = getWSRoleId(op)) + roleId = *optionalRoleId; + int barId = agentId + roleId + nameBarrierIdBegin; + assert(barId < nameBarrierIdEnd); + barSync(*builder, op, barId, 128); + } else { + builder->create(op->getLoc()); + } blockInfo->sync(); } // Update the region info, even if barrier is inserted, we have to maintain // the current op's read/write buffers. blockInfo->join(curBlockInfo); } - } // namespace mlir diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index acdb2c5582e1..6b758414896b 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -1,10 +1,14 @@ #include "triton/Analysis/Utility.h" #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Matchers.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "triton/Tools/Sys/GetEnv.hpp" #include @@ -37,6 +41,51 @@ bool ReduceOpHelper::isFastReduction() { getParentOrder(getSrcLayout())[0]; } +// Cases where distributed shared memory is not required in ConvertLayout: +// (1) numCTAs == 1 +// (2) numCTAs > 1 but srcCTALayout == dstCTALayout +// TODO: Case with SliceLayout as srcLayout and numCTAs > 1 is to be implemented +// in the future +bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout) { + unsigned numCTAs = triton::gpu::getNumCTAs(srcLayout); + assert(numCTAs == triton::gpu::getNumCTAs(dstLayout) && + "Invalid layout conversion: the numbers of CTAs of src and dst " + "layouts are different"); + + // Case (1): Never use dsmem when numCTAs == 1 + if (numCTAs == 1) + return false; + + // Case where CTAsPerCGA of srcLayout in the sliced dim is not 1 is not + // implemented yet + if (auto sliceLayout = srcLayout.dyn_cast()) { + auto dim = sliceLayout.getDim(); + auto CTAsPerCGA = triton::gpu::getCTAsPerCGA(sliceLayout.getParent()); + if (CTAsPerCGA[dim] != 1) + assert(0 && "Layout conversion to be implemented"); + } + + // Case where CTAsPerCGA of dstLayout in the sliced dim is not 1 is supported + if (auto sliceLayout = dstLayout.dyn_cast()) { + auto dim = sliceLayout.getDim(); + auto CTAsPerCGA = triton::gpu::getCTAsPerCGA(sliceLayout.getParent()); + if (CTAsPerCGA[dim] != 1) + return true; + } + + // The above two branches make sure that it is legal to call getCTALayout of + // srcLayout and dstLayout + + // Case (2): Do not use dsmem when srcCTALayout == dstCTALayout + auto srcCTALayout = triton::gpu::getCTALayout(srcLayout); + auto dstCTALayout = triton::gpu::getCTALayout(dstLayout); + if (srcCTALayout == dstCTALayout) + return false; + + // Dsmem access is required when srcCTALayout != dstCTALayout + return true; +} + unsigned ReduceOpHelper::getInterWarpSize() { auto srcReduceDimSize = static_cast(srcShape[axis]); unsigned sizeIntraWarps = getIntraWarpSize(); @@ -125,7 +174,7 @@ unsigned ReduceOpHelper::getScratchSizeInBytes() { unsigned bytesPerElem = 0; for (const auto &ty : srcElementTypes) { - bytesPerElem += ty.getIntOrFloatBitWidth() / 8; + bytesPerElem += ceil(ty.getIntOrFloatBitWidth(), 8); } return bytesPerElem * elems; } @@ -136,7 +185,7 @@ bool ReduceOpHelper::isSupportedLayout() { return true; } if (auto mmaLayout = srcLayout.dyn_cast()) { - if (mmaLayout.isAmpere()) { + if (mmaLayout.isAmpere() || mmaLayout.isHopper()) { return true; } } @@ -286,6 +335,8 @@ bool maybeSharedAllocationOp(Operation *op) { return dialect && (dialect->getTypeID() == mlir::TypeID::get() || + dialect->getTypeID() == + mlir::TypeID::get() || dialect->getTypeID() == mlir::TypeID::get() || dialect->getTypeID() == mlir::TypeID::get() || dialect->getTypeID() == mlir::TypeID::get()); @@ -294,6 +345,8 @@ bool maybeSharedAllocationOp(Operation *op) { bool maybeAliasOp(Operation *op) { return isa(op) || isa(op) || isa(op) || + isa(op) || + isa(op) || isa(op); } @@ -303,7 +356,21 @@ bool supportMMA(triton::DotOp op, int version) { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16 auto aElemTy = op.getA().getType().cast().getElementType(); auto bElemTy = op.getB().getType().cast().getElementType(); - + if (version == 3) { + if (!::triton::tools::getBoolEnv("ENABLE_MMA_V3")) + return false; + auto retType = op.getResult().getType().cast(); + auto retShapePerCTA = triton::gpu::getShapePerCTA(retType); + auto mod = op->getParentOfType(); + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); + if (!(numWarps % 4 == 0 && retShapePerCTA[0] % 64 == 0 && + retShapePerCTA[1] % 8 == 0 && + (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ() || + aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() || + aElemTy.isF32()))) { + return false; + } + } if (aElemTy.isF32() && bElemTy.isF32()) { return (op.getAllowTF32() && version == 2) || version == 3; } @@ -311,10 +378,10 @@ bool supportMMA(triton::DotOp op, int version) { } #ifdef USE_ROCM -static bool supportMFMAGranularity(int m, int n, int k) { +static bool supportMFMAGranularity(int m, int n, int k, int64_t nonKDim) { // these limitations are dtype dependent, in future we may relax them - const int granularityMN = 32; - const int granularityK = 8; + const int granularityMN = nonKDim; + const int granularityK = nonKDim == 32 ? 8 : 16; if (m % granularityMN != 0 || n % granularityMN != 0) return false; if (k % granularityK != 0) @@ -322,7 +389,7 @@ static bool supportMFMAGranularity(int m, int n, int k) { return true; } -bool supportMFMA(triton::DotOp op) { +bool supportMFMA(triton::DotOp op, int64_t nonKDim) { auto aTy = op.getA().getType().cast(); auto bTy = op.getB().getType().cast(); @@ -336,7 +403,7 @@ bool supportMFMA(triton::DotOp op) { auto bShape = bTy.getShape(); assert(aShape[1] == bShape[0]); - if (!supportMFMAGranularity(aShape[0], bShape[1], aShape[1])) + if (!supportMFMAGranularity(aShape[0], bShape[1], aShape[1], nonKDim)) return false; return aElemTy.isF16() || aElemTy.isBF16() || aElemTy.isF32() || @@ -345,25 +412,22 @@ bool supportMFMA(triton::DotOp op) { #endif bool supportMMA(Value value, int version) { - // Tell whether a DotOp support HMMA by the operand type(either $a or $b). + // Tell whether a DotOp support MMA by the operand type(either $a or $b). // We cannot get both the operand types(in TypeConverter), here we assume the // types of both the operands are identical here. - assert((version == 1 || version == 2) && + assert((version == 1 || version == 2 || version == 3) && "Unexpected MMA layout version found"); auto elemTy = value.getType().cast().getElementType(); - return elemTy.isF16() || elemTy.isBF16() || + // FP8 is not natively supported on all mma versions but it can always be + // promoted to fp16 therefore we can always support it. + bool isFP8 = elemTy.isFloat8E5M2() || elemTy.isFloat8E4M3FN() || + elemTy.isFloat8E5M2FNUZ() || elemTy.isFloat8E4M3FNUZ(); + return isFP8 || elemTy.isF16() || elemTy.isBF16() || (elemTy.isF32() && version >= 2) || (elemTy.isInteger(8) && version >= 2); } -Type getElementType(Value value) { - auto type = value.getType(); - if (auto tensorType = type.dyn_cast()) - return tensorType.getElementType(); - return type; -} - bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { // dot_op = #mma // when #mma = MmaEncoding @@ -389,13 +453,24 @@ bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { // layout when opIdx == 1. return mfmaLayout.getWarpsPerCTA()[1] == 1 && dotOperandLayout.getOpIdx() == 0 && - dotOperandLayout.getKWidth() == 8 && + dotOperandLayout.getKWidth() == 4 && dotOperandLayout.getParent() == mfmaLayout && - mfmaLayout.getIsTransposed() && + mfmaLayout.getNonKDim() == 32 && mfmaLayout.getIsTransposed() && (srcTy.getElementType().isF16() || srcTy.getElementType().isBF16()); } #endif +bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { + auto src = srcTy.getEncoding().cast(); + auto dst = dstTy.getEncoding().cast(); + auto srcElemsPerThread = triton::gpu::getTotalElemsPerThread(srcTy); + auto dstElemsPerThread = triton::gpu::getTotalElemsPerThread(dstTy); + // when #mma = MmaEncoding + return src.getVersionMajor() == 3 && src.getWarpsPerCTA()[1] == 1 && + dst.getVersionMajor() == 3 && dst.getWarpsPerCTA()[1] == 1 && + srcElemsPerThread == dstElemsPerThread; +} + bool isSingleValue(Value value) { // Don't consider load as expensive if it is loading a scalar. if (auto tensorTy = value.getType().dyn_cast()) @@ -455,9 +530,11 @@ struct DFSState { SmallVector topologicalCounts; DenseSet seen; - /// We mark each op as ready if all its operands are seen. If an op is ready, - /// we add it to the queue. Otherwise, we keep adding its operands to the - /// ancestors set. + /// We mark each op as ready if all its operands and parents ops are seen. If + /// an op is ready, we add it to the queue. Otherwise, we keep adding its + /// operands to the ancestors set. + /// We always want an op to be scheduled after all its parents to handle + /// correctly cases with scf operations. void addToReadyQueue(Operation *op, DFSSubgraphState &subGraph, SmallVector &readyQueue) { bool ready = true; @@ -468,6 +545,14 @@ struct DFSState { ready = false; } } + Operation *parent = op->getParentOp(); + while (parent) { + if (!seen.count(parent)) { + subGraph.push_back(parent); + ready = false; + } + parent = parent->getParentOp(); + } if (ready) readyQueue.push_back(op); } @@ -615,4 +700,81 @@ std::unique_ptr createDataFlowSolver() { return solver; } +static triton::MakeTensorPtrOp getMakeTensorPtrOpImpl(Operation *op, Value v) { + + if (auto makeTensorPtrOp = dyn_cast(op)) { + return makeTensorPtrOp; + } + + if (auto advanceOp = dyn_cast(op)) { + return getMakeTensorPtrOp(advanceOp.getPtr()); + } + + if (auto branch = dyn_cast(op)) { + auto idx = v.cast().getResultNumber(); + llvm::SmallVector yieldOps; + op->walk([&](Operation *op) { + if (auto yieldOp = dyn_cast(op)) + yieldOps.push_back(yieldOp); + }); + + // benzh@ if multi yields, all yields operand should come from same arg. + Value newValue = yieldOps[0].getOperands()[idx]; + return getMakeTensorPtrOp(newValue); + } + + llvm_unreachable("Unable to getMakeTensorPtr()"); +} + +triton::MakeTensorPtrOp getMakeTensorPtrOp(Value v) { + using BranchOps = llvm::SetVector>; + llvm::DenseMap blockToCFOps; + auto moduleOp = + v.getParentBlock()->getParentOp()->getParentOfType(); + + moduleOp.walk([&](Operation *op) { + if (auto br = dyn_cast(op)) { + Block *block = br.getDest(); + blockToCFOps[block].insert({op, -1}); + } + if (auto condBr = dyn_cast(op)) { + Block *blockT = condBr.getTrueDest(); + Block *blockF = condBr.getFalseDest(); + blockToCFOps[blockT].insert({condBr, 1}); + blockToCFOps[blockF].insert({condBr, 0}); + } + }); + + if (Operation *definingOp = v.getDefiningOp()) { + return getMakeTensorPtrOpImpl(definingOp, v); + } else if (BlockArgument arg = v.cast()) { + unsigned argNum = arg.getArgNumber(); + Operation *argOwner = arg.getOwner()->getParentOp(); + + if (auto forOp = dyn_cast(argOwner)) { + return getMakeTensorPtrOp( + forOp.getOperand(argNum + forOp.getNumControlOperands() - 1)); + } else if (auto funcOp = dyn_cast(argOwner)) { + Block *block = arg.getOwner(); + Operation *op; + int tOrF; + std::tie(op, tOrF) = blockToCFOps[block][0]; + if (auto br = dyn_cast(op)) { + return getMakeTensorPtrOp(br.getDestOperands()[argNum]); + } + if (auto condBr = dyn_cast(op)) { + if (tOrF) { + return getMakeTensorPtrOp(condBr.getTrueDestOperands()[argNum]); + } else { + return getMakeTensorPtrOp(condBr.getFalseDestOperands()[argNum]); + } + } + } else { + return getMakeTensorPtrOp(argOwner->getOperand(argNum)); + } + } + + llvm_unreachable("Unable to getMakeTensorPtr()"); +} + } // namespace mlir diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index 143a4375a811..c5dcec8c0e86 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(TritonToTritonGPU) add_subdirectory(TritonGPUToLLVM) +add_subdirectory(NVGPUToLLVM) diff --git a/lib/Conversion/NVGPUToLLVM/CMakeLists.txt b/lib/Conversion/NVGPUToLLVM/CMakeLists.txt new file mode 100644 index 000000000000..2f81f67669cd --- /dev/null +++ b/lib/Conversion/NVGPUToLLVM/CMakeLists.txt @@ -0,0 +1,27 @@ +add_mlir_conversion_library(NVGPUToLLVM + NVGPUToLLVMPass.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/triton/Conversion/NVGPUToLLVM + ${PROJECT_BINARY_DIR}/include/triton/Conversion/NVGPUToLLVM + + DEPENDS + NVGPUConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIRGPUOps + MLIRGPUToNVVMTransforms + MLIRGPUToROCDLTransforms + MLIRGPUTransforms + TritonAnalysis + TritonIR + TritonGPUIR + TritonGPUTransforms + TritonNvidiaGPUTransforms + NVGPUIR +) diff --git a/lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp b/lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp new file mode 100644 index 000000000000..f52256b3ade4 --- /dev/null +++ b/lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp @@ -0,0 +1,1202 @@ +#include "triton/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.h" + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Conversion/TritonGPUToLLVM/PTXAsmFormat.h" + +#include "../lib/Conversion/TritonGPUToLLVM/Utility.h" +using namespace mlir; +using namespace mlir::triton; + +#define GEN_PASS_CLASSES +#include "triton/Conversion/NVGPUToLLVM/Passes.h.inc" + +namespace ttn = mlir::triton::nvgpu; +using ::mlir::LLVM::getSRegValue; + +namespace { + +template +class NVGPUOpPatternBase : public mlir::RewritePattern { +public: + explicit NVGPUOpPatternBase(mlir::MLIRContext *context) + : mlir::RewritePattern(SourceOp::getOperationName(), 1, context) {} + + LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + auto ctx = rewriter.getContext(); + auto loc = op->getLoc(); + auto sourceOp = llvm::dyn_cast(op); + if (!sourceOp) + return mlir::failure(); + auto ptxAsm = static_cast(this)->getPtxAsm(sourceOp); + auto hasSideEffects = !isMemoryEffectFree(sourceOp); + PTXBuilder ptxBuilder; + auto &ptxInstr = *ptxBuilder.create(ptxAsm); + ptxInstr({}, /*onlyAttachMLIRArgs=*/true); + auto asmReturnTy = void_ty(ctx); + ptxBuilder.launch(rewriter, loc, asmReturnTy, + /*hasSideEffects*/ hasSideEffects); + rewriter.eraseOp(op); + return mlir::success(); + } +}; + +class CGABarrierSyncOpPattern + : public NVGPUOpPatternBase { +public: + using Base = + NVGPUOpPatternBase; + using Base::Base; + + std::string getPtxAsm(ttn::CGABarrierSyncOp op) const { + return "barrier.cluster.sync.aligned;"; + } +}; + +class FenceAsyncSharedOpPattern + : public NVGPUOpPatternBase { +public: + using Base = + NVGPUOpPatternBase; + using Base::Base; + + std::string getPtxAsm(ttn::FenceAsyncSharedOp op) const { + auto bCluster = op.getBCluster(); + if (bCluster) + return "fence.proxy.async.shared::cluster;"; + else + return "fence.proxy.async.shared::cta;"; + } +}; + +class WGMMAFenceOpPattern + : public NVGPUOpPatternBase { +public: + using Base = NVGPUOpPatternBase; + using Base::Base; + + std::string getPtxAsm(ttn::WGMMAFenceOp op) const { + return "wgmma.fence.sync.aligned;"; + } +}; + +class WGMMACommitGroupOpPattern + : public NVGPUOpPatternBase { +public: + using Base = + NVGPUOpPatternBase; + using Base::Base; + + std::string getPtxAsm(ttn::WGMMACommitGroupOp op) const { + return "wgmma.commit_group.sync.aligned;"; + } +}; + +class WGMMAWaitGroupOpPattern + : public NVGPUOpPatternBase { +public: + using Base = + NVGPUOpPatternBase; + using Base::Base; + + std::string getPtxAsm(ttn::WGMMAWaitGroupOp op) const { + auto pendings = op.getPendings(); + return "wgmma.wait_group.sync.aligned " + std::to_string(pendings) + ";"; + } +}; + +class StoreMatrixOpPattern : public mlir::RewritePattern { +public: + StoreMatrixOpPattern(mlir::MLIRContext *context) + : mlir::RewritePattern(ttn::StoreMatrixOp::getOperationName(), 1, + context) {} + + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + auto ctx = rewriter.getContext(); + auto storeMatrixOp = llvm::dyn_cast(op); + if (!storeMatrixOp) + return mlir::failure(); + auto loc = op->getLoc(); + auto addr = storeMatrixOp.getAddr(); + auto datas = storeMatrixOp.getDatas(); + + assert(datas.size() == 1 || datas.size() == 2 || + datas.size() == 4 && "Invalid size for StoreMatrixOp"); + PTXBuilder ptxBuilder; + auto &ptxInstr = *ptxBuilder.create( + "stmatrix.sync.aligned.m8n8.x" + std::to_string(datas.size()) + + ".shared.b16"); + auto *addrOpr = ptxBuilder.newAddrOperand(ptrtoint(i32_ty, addr), "r"); + + SmallVector> args; + for (unsigned i = 0; i < datas.size(); ++i) { + args.push_back({datas[i], "r"}); + } + auto *operands = ptxBuilder.newListOperand(args); + + ptxInstr(addrOpr, operands); + + auto asmReturnTy = void_ty(ctx); + ptxBuilder.launch(rewriter, loc, asmReturnTy); + rewriter.eraseOp(op); + return mlir::success(); + } +}; + +class MBarrierInitOpPattern : public mlir::RewritePattern { +public: + MBarrierInitOpPattern(mlir::MLIRContext *context) + : mlir::RewritePattern(ttn::MBarrierInitOp::getOperationName(), 1, + context) {} + + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + auto ctx = rewriter.getContext(); + auto mBarrierInitOp = llvm::dyn_cast(op); + if (!mBarrierInitOp) + return mlir::failure(); + auto loc = op->getLoc(); + Value mbarrier = mBarrierInitOp.getMbarrier(); + Value pred = mBarrierInitOp.getPred(); + uint32_t count = mBarrierInitOp.getCount(); + PTXBuilder ptxBuilder; + + auto &ptxInstr = *ptxBuilder.create("mbarrier.init.shared.b64"); + auto *barOpr = + ptxBuilder.newAddrOperand(ptrtoint(i32_ty, mbarrier), "r", 0); + auto *expectedOpr = ptxBuilder.newConstantOperand(count); + + ptxInstr(barOpr, expectedOpr).predicate(pred, "b"); + + auto asmReturnTy = void_ty(ctx); + ptxBuilder.launch(rewriter, loc, asmReturnTy); + rewriter.eraseOp(op); + return mlir::success(); + } +}; + +class MBarrierArriveOpPattern : public mlir::RewritePattern { +public: + MBarrierArriveOpPattern(mlir::MLIRContext *context) + : mlir::RewritePattern(ttn::MBarrierArriveOp::getOperationName(), 1, + context) {} + + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + auto ctx = rewriter.getContext(); + auto mbarrierArriveOp = llvm::dyn_cast(op); + if (!mbarrierArriveOp) + return mlir::failure(); + auto loc = op->getLoc(); + Value mbarrier = mbarrierArriveOp.getMbarrier(); + Value pred = mbarrierArriveOp.getPred(); + Value ctaId = mbarrierArriveOp.getCtaId(); + auto arriveType = mbarrierArriveOp.getArriveType(); + uint32_t txCount = mbarrierArriveOp.getTxCount(); + + PTXBuilder ptxBuilder; + if (arriveType == ttn::MBarriveType::normal) { + auto &ptxInstr = + *ptxBuilder.create("mbarrier.arrive.shared.b64 _,"); + auto *barOpr = + ptxBuilder.newAddrOperand(ptrtoint(i32_ty, mbarrier), "r", 0); + + ptxInstr(barOpr).predicate(pred, "b"); + } else if (arriveType == ttn::MBarriveType::cp_async) { + auto &ptxInstr = *ptxBuilder.create( + "cp.async.mbarrier.arrive.noinc.shared.b64"); + auto *barOpr = + ptxBuilder.newAddrOperand(ptrtoint(i32_ty, mbarrier), "r", 0); + + ptxInstr(barOpr).predicate(pred, "b"); + } else if (arriveType == ttn::MBarriveType::expect_tx) { + assert(txCount > 0 && "txCount should be valid"); + auto &ptxInstr = *ptxBuilder.create( + "mbarrier.arrive.expect_tx.shared.b64 _,"); + auto *barOpr = + ptxBuilder.newAddrOperand(ptrtoint(i32_ty, mbarrier), "r", 0); + auto *expectedOpr = ptxBuilder.newConstantOperand(txCount); + + ptxInstr(barOpr, expectedOpr).predicate(pred, "b"); + } else if (arriveType == ttn::MBarriveType::remote) { + assert(ctaId && "ctaId should have a valid value"); + auto ptxAsm = + " { .reg .b32 remAddr32; \n" + " @$2 mapa.shared::cluster.u32 remAddr32, $0, $1; \n" + " @$2 mbarrier.arrive.shared::cluster.b64 _, [remAddr32]; } \n"; + auto &ptxInstr = *ptxBuilder.create(ptxAsm); + auto *barOpr = + ptxBuilder.newAddrOperand(ptrtoint(i32_ty, mbarrier), "r", 0); + auto *ctaIdOpr = ptxBuilder.newOperand(ctaId, "r"); + auto *predOpr = ptxBuilder.newOperand(pred, "b"); + + ptxInstr({barOpr, ctaIdOpr, predOpr}, /*onlyAttachMLIRArgs=*/true); + } else { + assert(false && + "Unsupported mbarrier arrive type"); // TODO: is this the right way + // to assert in LLVM pass ? + } + auto asmReturnTy = void_ty(ctx); + ptxBuilder.launch(rewriter, loc, asmReturnTy); + rewriter.eraseOp(op); + return mlir::success(); + } +}; +class MBarrierWaitOpPattern : public mlir::RewritePattern { +public: + MBarrierWaitOpPattern(mlir::MLIRContext *context) + : mlir::RewritePattern(ttn::MBarrierWaitOp::getOperationName(), 1, + context) {} + + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + auto ctx = rewriter.getContext(); + auto mBarrierWaitOp = llvm::dyn_cast(op); + if (!mBarrierWaitOp) + return mlir::failure(); + auto loc = op->getLoc(); + Value mbarrier = mBarrierWaitOp.getMbarrier(); + Value phase = mBarrierWaitOp.getPhase(); + PTXBuilder ptxBuilder; + + auto ptxAsm = + "{\n" + ".reg .pred P1; \n" + "LAB_WAIT: \n" + "mbarrier.try_wait.parity.shared.b64 P1, [$0], $1, 0x989680; \n" + "@P1 bra.uni DONE; \n" + "bra.uni LAB_WAIT; \n" + "DONE: \n" + "}"; + auto &ptxInstr = *ptxBuilder.create(ptxAsm); + auto *barOpr = ptxBuilder.newOperand(ptrtoint(i32_ty, mbarrier), "r"); + auto *phaseOpr = ptxBuilder.newOperand(zext(i32_ty, phase), "r"); + + ptxInstr({barOpr, phaseOpr}, + /*onlyAttachMLIRArgs=*/true); + + auto asmReturnTy = void_ty(ctx); + ptxBuilder.launch(rewriter, loc, asmReturnTy); + rewriter.eraseOp(op); + return mlir::success(); + } +}; + +class ClusterArriveOpPattern + : public NVGPUOpPatternBase { +public: + using Base = NVGPUOpPatternBase; + using Base::Base; + + std::string getPtxAsm(ttn::ClusterArriveOp op) const { + auto relaxed = op.getRelaxed(); + if (relaxed) + return "barrier.cluster.arrive.relaxed.aligned;"; + else + return "barrier.cluster.arrive.aligned;"; + } +}; + +class ClusterWaitOpPattern + : public NVGPUOpPatternBase { +public: + using Base = NVGPUOpPatternBase; + using Base::Base; + std::string getPtxAsm(ttn::ClusterWaitOp op) const { + return "barrier.cluster.wait.aligned;"; + } +}; + +class TMALoadTiledOpPattern : public mlir::RewritePattern { +public: + TMALoadTiledOpPattern(mlir::MLIRContext *context) + : mlir::RewritePattern(ttn::TMALoadTiledOp::getOperationName(), 1, + context) {} + + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + auto ctx = rewriter.getContext(); + auto tmaLoadTiledOp = llvm::dyn_cast(op); + if (!tmaLoadTiledOp) + return mlir::failure(); + auto loc = op->getLoc(); + auto dst = tmaLoadTiledOp.getDst(); + auto mbarrier = tmaLoadTiledOp.getMbarrier(); + auto tmaDesc = tmaLoadTiledOp.getTmaDesc(); + auto l2Desc = tmaLoadTiledOp.getL2Desc(); + auto pred = tmaLoadTiledOp.getPred(); + auto coords = tmaLoadTiledOp.getCoords(); + auto mcastMask = tmaLoadTiledOp.getMcastMask(); + + auto dimSize = coords.size(); + + PTXBuilder ptxBuilder; + if (dimSize == 2) { + if (mcastMask == nullptr) { + auto ptxAsm = + "@$6 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier:" + ":complete_tx" + "::bytes.L2::cache_hint [$0], [$1, {$2, $3}], [$4], $5;"; + auto &ptxInstr = *ptxBuilder.create(ptxAsm); + auto *dstOpr = ptxBuilder.newOperand(ptrtoint(i32_ty, dst), "r"); + auto *descOpr = ptxBuilder.newOperand(ptrtoint(i64_ty, tmaDesc), "l"); + auto *c0Opr = ptxBuilder.newOperand(coords[0], "r"); + auto *c1Opr = ptxBuilder.newOperand(coords[1], "r"); + auto *barOpr = ptxBuilder.newOperand(ptrtoint(i64_ty, mbarrier), "r"); + auto *l2DescOpr = ptxBuilder.newOperand(l2Desc, "l"); + auto *predOpr = ptxBuilder.newOperand(pred, "b"); + + ptxInstr({dstOpr, descOpr, c0Opr, c1Opr, barOpr, l2DescOpr, predOpr}, + /*onlyAttachMLIRArgs=*/true); + } else { + auto ptxAsm = + "@$7 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::" + "complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [$0], [$1, {$2, $3}], [$4], $5, $6;"; + auto &ptxInstr = *ptxBuilder.create(ptxAsm); + auto *dstOpr = ptxBuilder.newOperand(ptrtoint(i32_ty, dst), "r"); + auto *descOpr = ptxBuilder.newOperand(ptrtoint(i64_ty, tmaDesc), "l"); + auto *c0Opr = ptxBuilder.newOperand(coords[0], "r"); + auto *c1Opr = ptxBuilder.newOperand(coords[1], "r"); + auto *barOpr = ptxBuilder.newOperand(ptrtoint(i64_ty, mbarrier), "r"); + auto *maskOpr = ptxBuilder.newOperand(mcastMask, "h"); + auto *l2DescOpr = ptxBuilder.newOperand(l2Desc, "l"); + auto *predOpr = ptxBuilder.newOperand(pred, "b"); + ptxInstr({dstOpr, descOpr, c0Opr, c1Opr, barOpr, maskOpr, l2DescOpr, + predOpr}, + /*onlyAttachMLIRArgs=*/true); + } + } else if (dimSize == 4) { + assert(mcastMask == nullptr && "Does not support multicast"); + auto ptxAsm = "@$8 " + "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier:" + ":complete_tx" + "::bytes.L2::cache_hint [$0], [$1, {$2, $3, $4, $5}], " + "[$6], $7;"; + auto &ptxInstr = *ptxBuilder.create(ptxAsm); + auto *dstOpr = ptxBuilder.newOperand(ptrtoint(i32_ty, dst), "r"); + auto *descOpr = ptxBuilder.newOperand(ptrtoint(i64_ty, tmaDesc), "l"); + auto *c0Opr = ptxBuilder.newOperand(coords[0], "r"); + auto *c1Opr = ptxBuilder.newOperand(coords[1], "r"); + auto *c2Opr = ptxBuilder.newOperand(coords[2], "r"); + auto *c3Opr = ptxBuilder.newOperand(coords[3], "r"); + auto *barOpr = ptxBuilder.newOperand(ptrtoint(i64_ty, mbarrier), "r"); + auto *l2DescOpr = ptxBuilder.newOperand(l2Desc, "l"); + auto *predOpr = ptxBuilder.newOperand(pred, "b"); + ptxInstr({dstOpr, descOpr, c0Opr, c1Opr, c2Opr, c3Opr, barOpr, l2DescOpr, + predOpr}, + /*onlyAttachMLIRArgs=*/true); + } else { + assert(false && "invalid dim size"); + } + + auto asmReturnTy = void_ty(ctx); + ptxBuilder.launch(rewriter, loc, asmReturnTy, /*hasSideEffect*/ true); + rewriter.eraseOp(op); + return mlir::success(); + } +}; + +class TMAStoreTiledOpPattern : public mlir::RewritePattern { +public: + TMAStoreTiledOpPattern(mlir::MLIRContext *context) + : mlir::RewritePattern(ttn::TMAStoreTiledOp::getOperationName(), 1, + context) {} + + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + auto ctx = rewriter.getContext(); + auto tmaStoreTiledOp = llvm::dyn_cast(op); + if (!tmaStoreTiledOp) + return mlir::failure(); + auto loc = op->getLoc(); + auto src = tmaStoreTiledOp.getSrc(); + auto tmaDesc = tmaStoreTiledOp.getTmaDesc(); + auto pred = tmaStoreTiledOp.getPred(); + auto coords = tmaStoreTiledOp.getCoords(); + + auto dimSize = coords.size(); + + PTXBuilder ptxBuilder; + if (dimSize == 2) { + auto ptxAsm = "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group" + "[$0, {$2, $3}], [$1];"; + auto &ptxInstr = *ptxBuilder.create(ptxAsm); + + auto *descOpr = ptxBuilder.newOperand(ptrtoint(i64_ty, tmaDesc), "l"); + auto *srcOpr = ptxBuilder.newOperand(ptrtoint(i32_ty, src), "r"); + auto *c0Opr = ptxBuilder.newOperand(coords[0], "r"); + auto *c1Opr = ptxBuilder.newOperand(coords[1], "r"); + auto *predOpr = ptxBuilder.newOperand(pred, "b"); + ptxInstr({descOpr, srcOpr, c0Opr, c1Opr, predOpr}, + /*onlyAttachMLIRArgs=*/true); + } else if (dimSize == 3) { + auto ptxAsm = "@$5 cp.async.bulk.tensor.3d.global.shared::cta.bulk_group" + "[$0, {$2, $3, $4}], [$1];"; + auto &ptxInstr = *ptxBuilder.create(ptxAsm); + + auto *descOpr = ptxBuilder.newOperand(ptrtoint(i64_ty, tmaDesc), "l"); + auto *srcOpr = ptxBuilder.newOperand(ptrtoint(i32_ty, src), "r"); + auto *c0Opr = ptxBuilder.newOperand(coords[0], "r"); + auto *c1Opr = ptxBuilder.newOperand(coords[1], "r"); + auto *c2Opr = ptxBuilder.newOperand(coords[2], "r"); + auto *predOpr = ptxBuilder.newOperand(pred, "b"); + ptxInstr({descOpr, srcOpr, c0Opr, c1Opr, c2Opr, predOpr}, + /*onlyAttachMLIRArgs=*/true); + } else if (dimSize == 4) { + auto ptxAsm = "@$6 cp.async.bulk.tensor.4d.global.shared::cta.bulk_group" + "[$0, {$2, $3, $4, $5}], [$1];"; + auto &ptxInstr = *ptxBuilder.create(ptxAsm); + auto *descOpr = ptxBuilder.newOperand(ptrtoint(i64_ty, tmaDesc), "l"); + auto *srcOpr = ptxBuilder.newOperand(ptrtoint(i32_ty, src), "r"); + auto *c0Opr = ptxBuilder.newOperand(coords[0], "r"); + auto *c1Opr = ptxBuilder.newOperand(coords[1], "r"); + auto *c2Opr = ptxBuilder.newOperand(coords[2], "r"); + auto *c3Opr = ptxBuilder.newOperand(coords[3], "r"); + auto *predOpr = ptxBuilder.newOperand(pred, "b"); + ptxInstr({descOpr, srcOpr, c0Opr, c1Opr, c2Opr, c3Opr, predOpr}, + /*onlyAttachMLIRArgs=*/true); + } else { + assert(false && "invalid dim size"); + } + + auto asmReturnTy = void_ty(ctx); + ptxBuilder.launch(rewriter, loc, asmReturnTy, /*hasSideEffect*/ true); + rewriter.eraseOp(op); + return mlir::success(); + } +}; + +class LoadDSmemOpPattern : public mlir::RewritePattern { +public: + LoadDSmemOpPattern(mlir::MLIRContext *context) + : mlir::RewritePattern(ttn::LoadDSmemOp::getOperationName(), 1, context) { + } + + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + auto ctx = rewriter.getContext(); + auto loadDSmemOp = llvm::dyn_cast(op); + if (!loadDSmemOp) + return mlir::failure(); + auto loc = op->getLoc(); + auto addr = loadDSmemOp.getAddr(); + auto ctaId = loadDSmemOp.getCtaId(); + auto bitwidth = loadDSmemOp.getBitwidth(); + auto vec = loadDSmemOp.getVec(); + + assert( + (bitwidth == 8 || bitwidth == 16 || bitwidth == 32 || bitwidth == 64) && + "invalid bitwidth"); + assert((vec == 1 || vec == 2 || vec == 4) && "invalid vec size"); + PTXBuilder ptxBuilder; + + std::string o1 = vec > 1 ? ".v.u" : ".u"; + std::string vecStr = vec == 1 ? "$0" + : vec == 2 ? "{$0, $1}" + : "{$0, $1, $2, $3}"; + unsigned argNum = vec == 1 ? 1 : vec == 2 ? 2 : 4; + auto ptxAsm = "{\n" + ".reg .u32 remoteAddr;\n" + "mapa.shared::cluster.u32 remoteAddr, $" + + std::to_string(argNum) + " , $" + std::to_string(argNum + 1) + + " ; \n" + "ld.shared::cluster" + + o1 + std::to_string(bitwidth) + " " + vecStr + + ", [remoteAddr];\n" + "}\n"; + + auto &ptxInstr = *ptxBuilder.create(ptxAsm); + std::string c = bitwidth == 16 ? "=h" : (bitwidth == 32 ? "=r" : "=l"); + SmallVector oprs; + for (unsigned i = 0; i < vec; ++i) { + auto *ret = ptxBuilder.newOperand(c); + oprs.push_back(ret); + } + auto *addrOpr = ptxBuilder.newOperand(addr, "r"); + auto *ctaIdOpr = ptxBuilder.newOperand(ctaId, "r"); + oprs.push_back(addrOpr); + oprs.push_back(ctaIdOpr); + + Type retTy = IntegerType::get(rewriter.getContext(), bitwidth); + SmallVector retTys(vec, retTy); + if (vec > 1) + retTy = struct_ty(retTys); + + ptxInstr(oprs, + /*onlyAttachMLIRArgs=*/true); + + auto res = ptxBuilder.launch(rewriter, loc, retTy); + rewriter.replaceOp(op, {res}); + return mlir::success(); + } +}; + +class WGMMAOpPattern : public mlir::RewritePattern { +public: + WGMMAOpPattern(mlir::MLIRContext *context) + : mlir::RewritePattern(ttn::WGMMAOp::getOperationName(), 1, context) {} + + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + using namespace ttn; + auto ctx = rewriter.getContext(); + auto wgmmaOp = llvm::dyn_cast(op); + if (!wgmmaOp) + return mlir::failure(); + auto loc = op->getLoc(); + auto opA = wgmmaOp.getOpA(); + auto opB = wgmmaOp.getOpB(); + auto opC = wgmmaOp.getOpC(); + auto m = wgmmaOp.getM(); + auto n = wgmmaOp.getN(); + auto k = wgmmaOp.getK(); + auto eltTypeC = wgmmaOp.getEltTypeC(); + auto eltTypeA = wgmmaOp.getEltTypeA(); + auto eltTypeB = wgmmaOp.getEltTypeB(); + auto layoutA = wgmmaOp.getLayoutA(); + auto layoutB = wgmmaOp.getLayoutB(); + + // Register checks + auto typeA = opA.getType(); + auto typeB = opB.getType(); + auto typeC = opC.getType(); + auto structTypeA = typeA.dyn_cast(); + auto structTypeB = typeB.dyn_cast(); + auto structTypeC = typeC.dyn_cast(); + assert(!structTypeB && "Operand B can not be registers"); + assert(structTypeC && "Operand C must be registers"); + + // Element type, MNK shape and transposing support check + // Reference: + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-mma + bool transA = layoutA == WGMMALayout::col; + bool transB = layoutB == WGMMALayout::row; + bool supported = false, needTransArgs = false, floatTypeWGMMA = false; + assert(m % 8 == 0 && n % 8 == 0 && k % 8 == 0); + // Below instructions do support transposing, must pass `trans` arguments + supported |= + (eltTypeA == WGMMAEltType::f16) && (eltTypeB == WGMMAEltType::f16) && + (eltTypeC == WGMMAEltType::f16 || eltTypeC == WGMMAEltType::f32) && + (m == 64 && 8 <= n && n <= 256 && k == 16); + supported |= (eltTypeA == WGMMAEltType::bf16) && + (eltTypeB == WGMMAEltType::bf16) && + (eltTypeC == WGMMAEltType::f32) && + (m == 64 && 8 <= n && n <= 256 && k == 16); + needTransArgs = supported; + floatTypeWGMMA = supported; + // Below instructions do not support transposing + if (!supported && !transA && !transB) { + supported |= (eltTypeA == WGMMAEltType::tf32) && + (eltTypeB == WGMMAEltType::tf32) && + (eltTypeC == WGMMAEltType::f32) && + (m == 64 && 8 <= n && n <= 256 && k == 8); + supported |= + (eltTypeA == WGMMAEltType::e4m3 || eltTypeA == WGMMAEltType::e5m2) && + (eltTypeB == WGMMAEltType::e4m3 || eltTypeB == WGMMAEltType::e5m2) && + (eltTypeC == WGMMAEltType::f16 || eltTypeC == WGMMAEltType::f32) && + (m == 64 && 8 <= n && n <= 256 && k == 32); + floatTypeWGMMA = supported; + // Below instructions are integer-based + supported |= (eltTypeA == WGMMAEltType::s8) && + (eltTypeB == WGMMAEltType::s8) && + (eltTypeC == WGMMAEltType::s32) && + (m == 64 && 8 <= n && n <= 224 && k == 32); + } + assert(supported && "WGMMA type or shape is not supported"); + PTXBuilder ptxBuilder; + SmallVector oprs; + + // Operands + uint32_t asmOpIdx = 0; + + // Operand C + uint32_t numCRegs = structTypeC.getBody().size(); + + std::string args = ""; + args += "{"; + for (uint32_t i = 0; i < numCRegs; ++i) { + args += "$" + std::to_string(asmOpIdx++) + (i == numCRegs - 1 ? "" : ","); + // LLVM does not support `+` semantic, we must repeat the arguments for + // both input and outputs + PTXBuilder::Operand *opr; + if (structTypeC.getBody().front().isF32()) + opr = ptxBuilder.newOperand( + extract_val(structTypeC.getBody()[i], opC, i), "=f"); + else + opr = ptxBuilder.newOperand( + extract_val(structTypeC.getBody()[i], opC, i), "=r"); + oprs.push_back(opr); + } + args += "}, "; + + for (uint32_t i = asmOpIdx - numCRegs; i < asmOpIdx; ++i) { + auto *opr = ptxBuilder.newOperand(i); + oprs.push_back(opr); + } + + // Note that LLVM will not skip the indexed repeating placeholders + asmOpIdx += numCRegs; + // Operand A + if (structTypeA) { + uint32_t numARegs = m * k / 128; + assert(numARegs == structTypeA.getBody().size()); + args += "{"; + for (uint32_t i = 0; i < numARegs; ++i) { + args += + "$" + std::to_string(asmOpIdx++) + (i == numARegs - 1 ? "" : ","); + auto *opr = ptxBuilder.newOperand( + extract_val(structTypeA.getBody()[i], opA, i), "f"); + oprs.push_back(opr); + } + args += "}, "; + } else { + args += "$" + std::to_string(asmOpIdx++) + ", "; + auto *opr = ptxBuilder.newOperand(opA, "l"); + oprs.push_back(opr); + } + + // Operand B (must be `desc`) + args += "$" + std::to_string(asmOpIdx++) + ", "; + auto *opr = ptxBuilder.newOperand(opB, "l"); + oprs.push_back(opr); + + // `scale-d` is 1 by default + args += "1"; + + // `imm-scale-a`, and `imm-scale-b` are 1 by default only for float-based + // WGMMA + if (floatTypeWGMMA) + args += ", 1, 1"; + + // Push `trans-a` and `trans-b` args if needed (determined as constant) + if (needTransArgs) + args += ", " + std::to_string(transA) + ", " + std::to_string(transB); + + auto ptxAsm = "wgmma.mma_async.sync.aligned" + ".m" + + std::to_string(m) + "n" + std::to_string(n) + "k" + + std::to_string(k) + "." + stringifyEnum(eltTypeC).str() + + "." + stringifyEnum(eltTypeA).str() + "." + + stringifyEnum(eltTypeB).str() + " " + args + ";"; + + auto &ptxInstr = *ptxBuilder.create(ptxAsm); + ptxInstr(oprs, + /*onlyAttachMLIRArgs=*/true); + + auto res = + ptxBuilder.launch(rewriter, loc, structTypeC, /*hasSideEffect*/ true); + rewriter.replaceOp(op, {res}); + return mlir::success(); + } +}; + +class FenceMBarrierInitOpPattern + : public NVGPUOpPatternBase { +public: + using Base = + NVGPUOpPatternBase; + using Base::Base; + + std::string getPtxAsm(ttn::FenceMBarrierInitOp op) const { + return "fence.mbarrier_init.release.cluster;"; + } +}; + +class NamedBarrierArriveOpPattern : public mlir::RewritePattern { +public: + NamedBarrierArriveOpPattern(mlir::MLIRContext *context) + : mlir::RewritePattern(ttn::NamedBarrierArriveOp::getOperationName(), 1, + context) {} + + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + auto ctx = rewriter.getContext(); + auto namedBarrierArriveOp = llvm::dyn_cast(op); + if (!namedBarrierArriveOp) + return mlir::failure(); + auto loc = op->getLoc(); + auto bar = namedBarrierArriveOp.getBar(); + auto numThreads = namedBarrierArriveOp.getNumThreads(); + PTXBuilder ptxBuilder; + + auto &ptxInstr = *ptxBuilder.create("bar.arrive $0, $1;"); + auto *barOpr = ptxBuilder.newOperand(bar, "r"); + auto *numThreadsOpr = ptxBuilder.newOperand(numThreads, "r"); + ptxInstr({barOpr, numThreadsOpr}, /*onlyAttachMLIRArgs=*/true); + + auto asmReturnTy = void_ty(ctx); + ptxBuilder.launch(rewriter, loc, asmReturnTy); + rewriter.eraseOp(op); + return mlir::success(); + } +}; + +class NamedBarrierWaitOpPattern : public mlir::RewritePattern { +public: + NamedBarrierWaitOpPattern(mlir::MLIRContext *context) + : mlir::RewritePattern(ttn::NamedBarrierWaitOp::getOperationName(), 1, + context) {} + + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + auto ctx = rewriter.getContext(); + auto namedBarrierWaitOp = llvm::dyn_cast(op); + if (!namedBarrierWaitOp) + return mlir::failure(); + auto loc = op->getLoc(); + auto bar = namedBarrierWaitOp.getBar(); + auto numThreads = namedBarrierWaitOp.getNumThreads(); + PTXBuilder ptxBuilder; + + auto &ptxInstr = *ptxBuilder.create("bar.sync $0, $1;"); + auto *barOpr = ptxBuilder.newOperand(bar, "r"); + auto *numThreadsOpr = ptxBuilder.newOperand(numThreads, "r"); + ptxInstr({barOpr, numThreadsOpr}, /*onlyAttachMLIRArgs=*/true); + + auto asmReturnTy = void_ty(ctx); + ptxBuilder.launch(rewriter, loc, asmReturnTy); + rewriter.eraseOp(op); + return mlir::success(); + } +}; + +class CGABarrierArriveOpPattern + : public NVGPUOpPatternBase { +public: + using Base = + NVGPUOpPatternBase; + using Base::Base; + std::string getPtxAsm(ttn::CGABarrierArriveOp op) const { + return "barrier.cluster.arrive;"; + } +}; + +class CGABarrierWaitOpPattern + : public NVGPUOpPatternBase { +public: + using Base = + NVGPUOpPatternBase; + using Base::Base; + std::string getPtxAsm(ttn::CGABarrierWaitOp op) const { + return "barrier.cluster.wait;"; + } +}; + +class StoreDSmemOpPattern : public mlir::RewritePattern { +public: + StoreDSmemOpPattern(mlir::MLIRContext *context) + : mlir::RewritePattern(ttn::StoreDSmemOp::getOperationName(), 1, + context) {} + + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + auto ctx = rewriter.getContext(); + auto storeDSmemOp = llvm::dyn_cast(op); + if (!storeDSmemOp) + return mlir::failure(); + auto loc = op->getLoc(); + auto addr = storeDSmemOp.getAddr(); + auto ctaId = storeDSmemOp.getCtaId(); + auto values = storeDSmemOp.getValues(); + auto pred = storeDSmemOp.getPred(); + + auto bitwidth = storeDSmemOp.getBitwidth(); + auto vec = storeDSmemOp.getVec(); + assert( + (bitwidth == 8 || bitwidth == 16 || bitwidth == 32 || bitwidth == 64) && + "invalid bitwidth"); + assert((vec == 1 || vec == 2 || vec == 4) && vec == values.size() && + "invalid vec size"); + + PTXBuilder ptxBuilder; + + std::string ptxAsm = "{\n\t" + ".reg .u32 remoteAddr;\n\t" + "mapa.shared::cluster.u32 remoteAddr, $0, $1;\n\t" + ".reg .pred p;\n\t" + "mov.pred p, $2;\n\t" + "@p st.shared::cluster"; + if (vec > 1) + ptxAsm += ".v" + std::to_string(vec); + ptxAsm += ".u" + std::to_string(bitwidth) + " [remoteAddr], "; + if (vec == 1) + ptxAsm += "$3"; + else if (vec == 2) + ptxAsm += "{$3, $4}"; + else if (vec == 4) + ptxAsm += "{$3, $4, $5, $6}"; + ptxAsm += ";\n\t"; + ptxAsm += "}\n"; + auto &ptxInstr = *ptxBuilder.create(ptxAsm); + + std::string c = bitwidth == 16 ? "h" : (bitwidth == 32 ? "r" : "l"); + SmallVector oprs; + auto *addrOpr = ptxBuilder.newOperand(addr, "r"); + oprs.push_back(addrOpr); + auto *ctaIdOpr = ptxBuilder.newOperand(ctaId, "r"); + oprs.push_back(ctaIdOpr); + auto *predOpr = ptxBuilder.newOperand(pred, "b"); + oprs.push_back(predOpr); + for (unsigned i = 0; i < values.size(); i++) { + auto *valueOpr = ptxBuilder.newOperand(values[i], c); + oprs.push_back(valueOpr); + } + ptxInstr(oprs, + /*onlyAttachMLIRArgs=*/true); + + auto asmReturnTy = void_ty(ctx); + ptxBuilder.launch(rewriter, loc, asmReturnTy, /*hasSideEffect*/ true); + rewriter.eraseOp(op); + return mlir::success(); + } +}; + +class Sts64OpPattern : public mlir::RewritePattern { +public: + Sts64OpPattern(mlir::MLIRContext *context) + : mlir::RewritePattern(ttn::Sts64Op::getOperationName(), 1, context) {} + + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + auto ctx = rewriter.getContext(); + auto sts64Op = llvm::dyn_cast(op); + if (!sts64Op) + return mlir::failure(); + auto loc = op->getLoc(); + auto offset = sts64Op.getOffset(); + auto d0 = sts64Op.getD0(); + auto d1 = sts64Op.getD1(); + + PTXBuilder ptxBuilder; + + std::string ptxAsm = "st.shared.v2.b32 [$0], {$1, $2}"; + auto &ptxInstr = *ptxBuilder.create(ptxAsm); + + SmallVector oprs; + auto *addrOpr = ptxBuilder.newOperand(offset, "r"); + auto *d0Opr = ptxBuilder.newOperand(d0, "r"); + auto *d1Opr = ptxBuilder.newOperand(d1, "r"); + + ptxInstr({addrOpr, d0Opr, d1Opr}, + /*onlyAttachMLIRArgs=*/true); + + auto asmReturnTy = void_ty(ctx); + ptxBuilder.launch(rewriter, loc, asmReturnTy); + rewriter.eraseOp(op); + return mlir::success(); + } +}; + +class RegAllocOpPattern + : public NVGPUOpPatternBase { +public: + using Base = NVGPUOpPatternBase; + using Base::Base; + + std::string getPtxAsm(ttn::RegAllocOp op) const { + auto regCount = op.getRegCount(); + return "setmaxnreg.inc.sync.aligned.u32 " + std::to_string(regCount) + ";"; + } +}; + +class RegDeallocOpPattern + : public NVGPUOpPatternBase { +public: + using Base = NVGPUOpPatternBase; + using Base::Base; + + std::string getPtxAsm(ttn::RegDeallocOp op) const { + auto regCount = op.getRegCount(); + return "setmaxnreg.dec.sync.aligned.u32 " + std::to_string(regCount) + ";"; + } +}; + +class ClusterCTAIdOpPattern : public mlir::RewritePattern { +public: + ClusterCTAIdOpPattern(mlir::MLIRContext *context) + : mlir::RewritePattern(ttn::ClusterCTAIdOp::getOperationName(), 1, + context) {} + + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + auto ctx = rewriter.getContext(); + auto clusterCTAIdOp = llvm::dyn_cast(op); + if (!clusterCTAIdOp) + return mlir::failure(); + auto loc = op->getLoc(); + + auto x = getSRegValue(rewriter, loc, "%cluster_ctaid.x"); + auto y = getSRegValue(rewriter, loc, "%cluster_ctaid.y"); + auto z = getSRegValue(rewriter, loc, "%cluster_ctaid.z"); + auto nx = getSRegValue(rewriter, loc, "%cluster_nctaid.x"); + auto ny = getSRegValue(rewriter, loc, "%cluster_nctaid.y"); + auto res = add(x, mul(add(y, mul(z, ny)), nx)); + rewriter.replaceOp(op, {res}); + return mlir::success(); + } +}; + +class WGMMADescCreateOpPattern : public mlir::RewritePattern { +public: + WGMMADescCreateOpPattern(mlir::MLIRContext *context) + : mlir::RewritePattern(ttn::WGMMADescCreateOp::getOperationName(), 1, + context) {} + + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + auto ctx = rewriter.getContext(); + auto wgmmaDescCreateOp = llvm::dyn_cast(op); + if (!wgmmaDescCreateOp) + return mlir::failure(); + auto loc = op->getLoc(); + auto buffer = wgmmaDescCreateOp.getBuffer(); + auto height = wgmmaDescCreateOp.getHeight(); + uint32_t mode = static_cast(wgmmaDescCreateOp.getMode()); + + auto smem_nvvm_pointer = ptrtoint(i64_ty, buffer); + + Value desc = int_val(64, 0); + uint64_t swizzling = (mode == 1 ? 128 : mode == 2 ? 64 : 32); + Value swizzling_ = int_val(64, swizzling); + Value smem_address_bit = smem_nvvm_pointer; + + Value strideDimension = + lshr(shl(swizzling_, int_val(64, 3)), int_val(64, 4)); + Value height64 = zext(i64_ty, height); + Value leadingDimension = lshr(mul(height64, swizzling_), int_val(64, 4)); + + // Value baseOffset = int_val(64, 0); + Value startAddr = + lshr(shl(smem_address_bit, int_val(64, 46)), int_val(64, 50)); + + Value mode_ = int_val(64, mode); + desc = or_(desc, shl(mode_, int_val(64, 62))); + desc = or_(desc, shl(strideDimension, int_val(64, 32))); + desc = or_(desc, shl(leadingDimension, int_val(64, 16))); + // desc = or_(desc, shl(baseOffset, int_val(64, 49))); + desc = or_(desc, startAddr); + + rewriter.replaceOp(op, {desc}); + return mlir::success(); + } +}; + +class OffsetOfSts64OpPattern : public mlir::RewritePattern { +public: + OffsetOfSts64OpPattern(mlir::MLIRContext *context) + : mlir::RewritePattern(ttn::OffsetOfSts64Op::getOperationName(), 1, + context) {} + + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + auto ctx = rewriter.getContext(); + auto offsetOfSts64Op = llvm::dyn_cast(op); + if (!offsetOfSts64Op) + return mlir::failure(); + auto loc = op->getLoc(); + auto threadId = offsetOfSts64Op.getThreadId(); + auto rowOfWarp = offsetOfSts64Op.getRowOfWarp(); + auto elemIdx = offsetOfSts64Op.getElemIdx(); + auto leadingDimOffset = offsetOfSts64Op.getLeadingDimOffset(); + auto rowStride = offsetOfSts64Op.getRowStride(); + auto swizzleEnabled = offsetOfSts64Op.getSwizzleEnabled(); + + if (swizzleEnabled) { + assert((rowStride == 32 || rowStride == 64 || rowStride == 128) && + "wrong rowString for swizzleEnabled"); + } + + uint32_t perPhase = 0; + uint32_t maxPhase = 0; + if (rowStride == 128) { + perPhase = 1; + maxPhase = 8; + } else if (rowStride == 64) { + perPhase = 2; + maxPhase = 4; + } else if (rowStride == 32) { + perPhase = 4; + maxPhase = 2; + } + + auto laneId = and_(threadId, i32_val(0x1f)); + auto myRow = + add(mul(and_(lshr(elemIdx, i32_val(1)), i32_val(0x1)), i32_val(8)), + udiv(laneId, i32_val(4))); + auto myCol = add(mul(udiv(elemIdx, i32_val(4)), i32_val(8)), + mul(urem(laneId, i32_val(4)), i32_val(2))); + myRow = add(myRow, rowOfWarp); + auto phase = urem(udiv(myRow, i32_val(perPhase)), i32_val(maxPhase)); + auto lineOffset = + add(mul(urem(myRow, i32_val(perPhase)), i32_val(rowStride)), + mul(myCol, i32_val(4))); + auto colOffset = + add(mul(xor_(udiv(lineOffset, i32_val(16)), phase), i32_val(16)), + urem(lineOffset, i32_val(16))); + auto offset = + add(mul(udiv(myRow, i32_val(perPhase)), i32_val(128)), colOffset); + + rewriter.replaceOp(op, {offset}); + return mlir::success(); + } +}; + +class OffsetOfStmatrixV4OpPattern : public mlir::RewritePattern { +public: + OffsetOfStmatrixV4OpPattern(mlir::MLIRContext *context) + : mlir::RewritePattern(ttn::OffsetOfStmatrixV4Op::getOperationName(), 1, + context) {} + + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + auto ctx = rewriter.getContext(); + auto offsetOfStmatrixV4Op = llvm::dyn_cast(op); + if (!offsetOfStmatrixV4Op) + return mlir::failure(); + auto loc = op->getLoc(); + auto threadId = offsetOfStmatrixV4Op.getThreadId(); + auto rowOfWarp = offsetOfStmatrixV4Op.getRowOfWarp(); + auto elemIdx = offsetOfStmatrixV4Op.getElemIdx(); + auto leadingDimOffset = offsetOfStmatrixV4Op.getLeadingDimOffset(); + auto rowStride = offsetOfStmatrixV4Op.getRowStride(); + auto swizzleEnabled = offsetOfStmatrixV4Op.getSwizzleEnabled(); + + if (swizzleEnabled) { + uint32_t perPhase = 0; + uint32_t maxPhase = 0; + if (rowStride == 64) { + perPhase = 1; + maxPhase = 8; + } else if (rowStride == 32) { + perPhase = 2; + maxPhase = 4; + } else if (rowStride == 16) { + perPhase = 4; + maxPhase = 2; + } + + Value iterOfCol = udiv(elemIdx, i32_val(8)); + Value myRow = add(rowOfWarp, and_(threadId, i32_val(0xf))); + Value myCol = + mul(and_(lshr(threadId, i32_val(4)), i32_val(0x1)), i32_val(8)); + myCol = add(myCol, mul(iterOfCol, i32_val(16))); + + Value offset0 = + mul(udiv(myCol, i32_val(rowStride)), i32_val(leadingDimOffset)); + myCol = urem(myCol, i32_val(rowStride)); + + Value phase = urem(udiv(myRow, i32_val(perPhase)), i32_val(maxPhase)); + + Value lineOffset = + add(mul(urem(myRow, i32_val(perPhase)), i32_val(rowStride)), myCol); + Value colOffset = + add(mul(xor_(udiv(lineOffset, i32_val(8)), phase), i32_val(8)), + urem(lineOffset, i32_val(8))); + Value offset1 = + add(mul(udiv(myRow, i32_val(perPhase)), i32_val(64)), colOffset); + + Value res = add(offset1, offset0); + + rewriter.replaceOp(op, {res}); + } else { + Value iterOfCol = udiv(elemIdx, i32_val(4)); + Value myRow = add(rowOfWarp, and_(threadId, i32_val(0xf))); + Value myCol = + mul(and_(lshr(threadId, i32_val(4)), i32_val(0x1)), i32_val(8)); + myCol = add(myCol, mul(iterOfCol, i32_val(16))); + + Value offset = + add(mul(myRow, i32_val(rowStride)), mul(myCol, i32_val(2))); + rewriter.replaceOp(op, {offset}); + } + return mlir::success(); + } +}; + +class ConvertNVGPUToLLVM : public ConvertNVGPUToLLVMBase { + +public: + explicit ConvertNVGPUToLLVM() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + RewritePatternSet patterns(context); + + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + if (applyPatternsAndFoldGreedily(mod, std::move(patterns)).failed()) + signalPassFailure(); + } +}; + +} // anonymous namespace + +namespace mlir { +namespace triton { + +std::unique_ptr> createConvertNVGPUToLLVMPass() { + return std::make_unique<::ConvertNVGPUToLLVM>(); +} + +} // namespace triton +} // namespace mlir diff --git a/lib/Conversion/TritonGPUToLLVM/BarrierOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/BarrierOpToLLVM.cpp new file mode 100644 index 000000000000..ca0970a7c9e4 --- /dev/null +++ b/lib/Conversion/TritonGPUToLLVM/BarrierOpToLLVM.cpp @@ -0,0 +1,217 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "BarrierOpToLLVM.h" + +using namespace mlir; +using namespace mlir::triton; + +// -------------------------------------------------------------------------- +// -- MBarrier related Ops lowering, to be moved to a seperate file --------- +// -------------------------------------------------------------------------- +struct AllocMBarrierOpConversion : public ConvertTritonGPUOpToLLVMPattern< + triton::nvidia_gpu::AllocMBarrierOp> { + using ConvertTritonGPUOpToLLVMPattern< + triton::nvidia_gpu::AllocMBarrierOp>::ConvertTritonGPUOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::AllocMBarrierOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + Value smemBase = getSharedMemoryBase(loc, rewriter, op.getResult()); + auto resultTy = op.getType(); + auto resultTensorTy = resultTy.dyn_cast(); + Type elemPtrTy; + if (resultTensorTy) { + auto llvmElemTy = + getTypeConverter()->convertType(resultTensorTy.getElementType()); + elemPtrTy = ptr_ty(llvmElemTy, 3); + } else { + elemPtrTy = getTypeConverter()->convertType(resultTy); + } + smemBase = bitcast(smemBase, elemPtrTy); + auto threadId = getThreadId(rewriter, loc); + auto pred = icmp_eq(threadId, i32_val(0)); + int numMBarriers = 1; + if (resultTensorTy) { + assert(resultTensorTy.getRank() == 1 && + "unexpected rank for AllocMBarrierOp"); + numMBarriers = resultTensorTy.getShape()[0]; + } + for (int i = 0; i < numMBarriers; ++i) { + Value smem = smemBase; + if (i > 0) { + smem = gep(elemPtrTy, smem, i32_val(i)); + } + rewriter.create(loc, smem, pred, + op.getCount()); + } + if (resultTensorTy) { + auto smemObj = SharedMemoryObject(smemBase, resultTensorTy.getShape(), + {0}, loc, rewriter); + auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); + rewriter.replaceOp(op, retVal); + } else { + rewriter.replaceOp(op, smemBase); + } + return success(); + } +}; + +struct MBarrierArriveOpConversion : public ConvertTritonGPUOpToLLVMPattern< + triton::nvidia_gpu::MBarrierArriveOp> { + using ConvertTritonGPUOpToLLVMPattern< + triton::nvidia_gpu::MBarrierArriveOp>::ConvertTritonGPUOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::MBarrierArriveOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto mbarrier = adaptor.getMbarrier(); + bool trackAsyncOp = op.getTrackAsyncOp(); + triton::nvgpu::MBarriveType type = triton::nvgpu::MBarriveType::normal; + uint32_t txCount = op.getTxCount(); + auto remoteCtaId = adaptor.getRemoteCtaId(); + if (trackAsyncOp) { + type = triton::nvgpu::MBarriveType::cp_async; + } else if (remoteCtaId) { + assert(txCount == 0 && + "remote arrive of transaction mbarrier is not implemented yet"); + type = triton::nvgpu::MBarriveType::remote; + } else if (txCount > 0) { + type = triton::nvgpu::MBarriveType::expect_tx; + } + Value pred = adaptor.getPred(); + if (pred == nullptr) { + pred = int_val(/*width*/ 1, 1); + } + rewriter.replaceOpWithNewOp( + op, mbarrier, pred, remoteCtaId, type, txCount); + return success(); + } +}; + +struct MBarrierWaitOpConversion : public ConvertTritonGPUOpToLLVMPattern< + triton::nvidia_gpu::MBarrierWaitOp> { + using ConvertTritonGPUOpToLLVMPattern< + triton::nvidia_gpu::MBarrierWaitOp>::ConvertTritonGPUOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::MBarrierWaitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + rewriter.replaceOpWithNewOp( + op, adaptor.getMbarrier(), adaptor.getPhase()); + return success(); + } +}; + +struct ExtractMBarrierOpConversion + : public ConvertTritonGPUOpToLLVMPattern< + triton::nvidia_gpu::ExtractMBarrierOp> { + using ConvertTritonGPUOpToLLVMPattern< + triton::nvidia_gpu::ExtractMBarrierOp>::ConvertTritonGPUOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::ExtractMBarrierOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto elemTy = + op.getTensor().getType().cast().getElementType(); + auto tensorStruct = adaptor.getTensor(); + auto index = adaptor.getIndex(); + auto ptrTy = + LLVM::LLVMPointerType::get(getTypeConverter()->convertType(elemTy), 3); + auto basePtr = + extract_val(ptrTy, tensorStruct, rewriter.getDenseI64ArrayAttr(0)); + Value result = gep(ptrTy, basePtr, index); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct NamedBarrierArriveOpConversion + : public ConvertTritonGPUOpToLLVMPattern< + triton::nvidia_gpu::NamedBarrierArriveOp> { + using ConvertTritonGPUOpToLLVMPattern< + triton::nvidia_gpu::NamedBarrierArriveOp>:: + ConvertTritonGPUOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::NamedBarrierArriveOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + rewriter.replaceOpWithNewOp( + op, adaptor.getBar(), adaptor.getNumThreads()); + return success(); + } +}; + +struct NamedBarrierWaitOpConversion + : public ConvertTritonGPUOpToLLVMPattern< + triton::nvidia_gpu::NamedBarrierWaitOp> { + using ConvertTritonGPUOpToLLVMPattern< + triton::nvidia_gpu::NamedBarrierWaitOp>::ConvertTritonGPUOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::NamedBarrierWaitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + rewriter.replaceOpWithNewOp( + op, adaptor.getBar(), adaptor.getNumThreads()); + return success(); + } +}; + +struct FenceAsyncSharedOpConversion + : public ConvertTritonGPUOpToLLVMPattern< + triton::nvidia_gpu::FenceAsyncSharedOp> { + using ConvertTritonGPUOpToLLVMPattern< + triton::nvidia_gpu::FenceAsyncSharedOp>::ConvertTritonGPUOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::FenceAsyncSharedOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + rewriter.replaceOpWithNewOp( + op, adaptor.getBCluster()); + return success(); + } +}; + +void populateBarrierOpToLLVMPatterns( + TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, + ModuleAllocation &allocation, PatternBenefit benefit) { + patterns.add(typeConverter, allocation, benefit); + patterns.add(typeConverter, allocation, benefit); + patterns.add(typeConverter, allocation, benefit); + patterns.add(typeConverter, allocation, benefit); + patterns.add(typeConverter, allocation, + benefit); + patterns.add(typeConverter, allocation, + benefit); + patterns.add(typeConverter, allocation, + benefit); +} diff --git a/lib/Conversion/TritonGPUToLLVM/BarrierOpToLLVM.h b/lib/Conversion/TritonGPUToLLVM/BarrierOpToLLVM.h new file mode 100644 index 000000000000..1e8f53bc6571 --- /dev/null +++ b/lib/Conversion/TritonGPUToLLVM/BarrierOpToLLVM.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_BARRIER_OP_H +#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_BARRIER_OP_H + +#include "TritonGPUToLLVMBase.h" + +using namespace mlir; +using namespace mlir::triton; + +void populateBarrierOpToLLVMPatterns( + TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, + ModuleAllocation &allocation, PatternBenefit benefit); + +#endif diff --git a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt index cd7452dda639..a1d069fbb248 100644 --- a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt +++ b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -1,8 +1,22 @@ +add_library(rocm_libraries SHARED IMPORTED ) +set_target_properties(rocm_libraries PROPERTIES IMPORTED_LOCATION ${ROCM_LIBRARIES}) + add_mlir_conversion_library(TritonGPUToLLVM + ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp + ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp + ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp + ConvertLayoutOpToLLVM.cpp + DotOpToLLVM/FMA.cpp + DotOpToLLVM/MMAv1.cpp + DotOpToLLVM/MMAv2.cpp + DotOpToLLVM/WGMMA.cpp + DotOpToLLVM.cpp + ElementwiseOpToLLVM.cpp + LoadStoreOpToLLVM.cpp + BarrierOpToLLVM.cpp TritonGPUToLLVM.cpp GCNAsmFormat.cpp PTXAsmFormat.cpp - TritonGPUToLLVMPass.cpp ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp @@ -21,9 +35,12 @@ add_mlir_conversion_library(TritonGPUToLLVM PTXAsmFormat.cpp ReduceOpToLLVM.cpp ScanOpToLLVM.cpp - Utility.cpp TypeConverter.cpp + Utility.cpp ViewOpToLLVM.cpp + TensorPtrOpsToLLVM.cpp + ClusterOpsToLLVM.cpp + RegReallocOpToLLVM.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/triton/Conversion/TritonGPUToLLVM @@ -46,4 +63,7 @@ add_mlir_conversion_library(TritonGPUToLLVM TritonIR TritonGPUIR TritonGPUTransforms + TritonNvidiaGPUTransforms + NVGPUIR + rocm_libraries ) diff --git a/lib/Conversion/TritonGPUToLLVM/ClusterOpsToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ClusterOpsToLLVM.cpp new file mode 100644 index 000000000000..42bbf04839bb --- /dev/null +++ b/lib/Conversion/TritonGPUToLLVM/ClusterOpsToLLVM.cpp @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "ClusterOpsToLLVM.h" +using namespace mlir; +using namespace mlir::triton; + +struct ClusterArriveOpConversion : public ConvertTritonGPUOpToLLVMPattern< + triton::nvidia_gpu::ClusterArriveOp> { + using ConvertTritonGPUOpToLLVMPattern< + triton::nvidia_gpu::ClusterArriveOp>::ConvertTritonGPUOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::ClusterArriveOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + op, op.getRelaxed()); + return success(); + } +}; + +struct ClusterWaitOpConversion : public ConvertTritonGPUOpToLLVMPattern< + triton::nvidia_gpu::ClusterWaitOp> { + using ConvertTritonGPUOpToLLVMPattern< + triton::nvidia_gpu::ClusterWaitOp>::ConvertTritonGPUOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::ClusterWaitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op); + return success(); + } +}; + +void populateClusterOpsToLLVMPatterns( + TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, + ModuleAllocation &allocation, PatternBenefit benefit) { + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + return; +} diff --git a/lib/Conversion/TritonGPUToLLVM/ClusterOpsToLLVM.h b/lib/Conversion/TritonGPUToLLVM/ClusterOpsToLLVM.h new file mode 100644 index 000000000000..693310afa39a --- /dev/null +++ b/lib/Conversion/TritonGPUToLLVM/ClusterOpsToLLVM.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_CLUSTER_OPS_H +#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_CLUSTER_OPS_H + +#include "TritonGPUToLLVMBase.h" + +using namespace mlir; +using namespace mlir::triton; + +void populateClusterOpsToLLVMPatterns( + TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, + ModuleAllocation &allocation, PatternBenefit benefit); + +#endif diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 1cb4488bc973..04a6455a4c6e 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -1,14 +1,20 @@ #include "ConvertLayoutOpToLLVM.h" #include "Utility.h" -using ::mlir::LLVM::delinearize; +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" + using ::mlir::LLVM::getSharedMemoryObjectFromStruct; using ::mlir::LLVM::getStridesFromShapeAndOrder; using ::mlir::LLVM::linearize; + +using ::mlir::LLVM::getSharedMemoryObjectFromStruct; +using ::mlir::LLVM::getStridesFromShapeAndOrder; using ::mlir::triton::gpu::DotOperandEncodingAttr; using ::mlir::triton::gpu::getContigPerThread; using ::mlir::triton::gpu::getOrder; using ::mlir::triton::gpu::getShapePerCTA; +using ::mlir::triton::gpu::getShapePerCTATile; using ::mlir::triton::gpu::getSizePerThread; using ::mlir::triton::gpu::getTotalElemsPerThread; using ::mlir::triton::gpu::isaDistributedLayout; @@ -82,6 +88,13 @@ struct ConvertLayoutOpConversion dstLayout.isa()) { return lowerSharedToDotOperand(op, adaptor, rewriter); } + // forwarding on mma->mma shortcut, lower distributed->distributed otherwise + if (srcLayout.isa() && dstLayout.isa()) { + if (isMmaToMmaShortcut(srcTy, dstTy)) { + rewriter.replaceOp(op, op.getSrc()); + return success(); + } + } if (isaDistributedLayout(srcLayout) && isaDistributedLayout(dstLayout)) { return lowerDistributedToDistributed(op, adaptor, rewriter); } @@ -105,23 +118,25 @@ struct ConvertLayoutOpConversion } private: - SmallVector getMultiDimOffset(Attribute layout, Location loc, - ConversionPatternRewriter &rewriter, - unsigned elemId, RankedTensorType type, - ArrayRef multiDimCTAInRepId, - ArrayRef shapePerCTA) const { + SmallVector + getMultiDimOffset(Attribute layout, Location loc, + ConversionPatternRewriter &rewriter, unsigned elemId, + RankedTensorType type, + ArrayRef multiDimCTAInRepId, + ArrayRef shapePerCTATile) const { auto shape = type.getShape(); unsigned rank = shape.size(); if (auto blockedLayout = layout.dyn_cast()) { auto multiDimOffsetFirstElem = - emitBaseIndexForLayout(loc, rewriter, blockedLayout, type); + emitBaseIndexForLayout(loc, rewriter, blockedLayout, type, false); SmallVector multiDimOffset(rank); SmallVector multiDimElemId = getMultiDimIndex( elemId, getSizePerThread(layout), getOrder(layout)); for (unsigned d = 0; d < rank; ++d) { - multiDimOffset[d] = add(multiDimOffsetFirstElem[d], - i32_val(multiDimCTAInRepId[d] * shapePerCTA[d] + - multiDimElemId[d])); + multiDimOffset[d] = + add(multiDimOffsetFirstElem[d], + i32_val(multiDimCTAInRepId[d] * shapePerCTATile[d] + + multiDimElemId[d])); } return multiDimOffset; } @@ -143,7 +158,7 @@ struct ConvertLayoutOpConversion auto multiDimOffsetParent = getMultiDimOffset( parentEncoding, loc, rewriter, idxs[elemId], parentTy, sliceLayout.paddedShape(multiDimCTAInRepId), - sliceLayout.paddedShape(shapePerCTA)); + sliceLayout.paddedShape(shapePerCTATile)); SmallVector multiDimOffset(rank); for (unsigned d = 0; d < rank + 1; ++d) { if (d == dim) @@ -154,6 +169,8 @@ struct ConvertLayoutOpConversion return multiDimOffset; } if (auto mmaLayout = layout.dyn_cast()) { + auto shapePerCTA = getShapePerCTA(mmaLayout, shape); + auto instrShape = mmaLayout.getInstrShape(); SmallVector mmaColIdx(4); SmallVector mmaRowIdx(2); Value threadId = getThreadId(rewriter, loc); @@ -162,27 +179,35 @@ struct ConvertLayoutOpConversion Value laneId = urem(threadId, warpSize); Value warpId = udiv(threadId, warpSize); // TODO: fix the bug in MMAEncodingAttr document + SmallVector multiDimWarpId(2); auto warpsPerCTA = mmaLayout.getWarpsPerCTA(); - auto order = triton::gpu::getOrder(mmaLayout); - SmallVector multiDimWarpId = - delinearize(rewriter, loc, warpId, warpsPerCTA, order); + if (mmaLayout.isHopper()) { + multiDimWarpId[0] = urem(warpId, i32_val(warpsPerCTA[0])); + multiDimWarpId[1] = udiv(warpId, i32_val(warpsPerCTA[0])); + } else { + auto order = triton::gpu::getOrder(mmaLayout); + multiDimWarpId = delinearize(rewriter, loc, warpId, warpsPerCTA, order); + } Value _1 = i32_val(1); Value _2 = i32_val(2); Value _4 = i32_val(4); Value _8 = i32_val(8); Value _16 = i32_val(16); - if (mmaLayout.isAmpere()) { - multiDimWarpId[0] = urem(multiDimWarpId[0], i32_val(shape[0] / 16)); - multiDimWarpId[1] = urem(multiDimWarpId[1], i32_val(shape[1] / 8)); + if (mmaLayout.isAmpere() || mmaLayout.isHopper()) { + multiDimWarpId[0] = + urem(multiDimWarpId[0], i32_val(shapePerCTA[0] / instrShape[0])); + multiDimWarpId[1] = + urem(multiDimWarpId[1], i32_val(shapePerCTA[1] / instrShape[1])); + Value mmaGrpId = udiv(laneId, _4); Value mmaGrpIdP8 = add(mmaGrpId, _8); Value mmaThreadIdInGrp = urem(laneId, _4); Value mmaThreadIdInGrpM2 = mul(mmaThreadIdInGrp, _2); Value mmaThreadIdInGrpM2P1 = add(mmaThreadIdInGrpM2, _1); - Value rowWarpOffset = mul(multiDimWarpId[0], _16); + Value rowWarpOffset = mul(multiDimWarpId[0], i32_val(instrShape[0])); mmaRowIdx[0] = add(mmaGrpId, rowWarpOffset); mmaRowIdx[1] = add(mmaGrpIdP8, rowWarpOffset); - Value colWarpOffset = mul(multiDimWarpId[1], _8); + Value colWarpOffset = mul(multiDimWarpId[1], i32_val(instrShape[1])); mmaColIdx[0] = add(mmaThreadIdInGrpM2, colWarpOffset); mmaColIdx[1] = add(mmaThreadIdInGrpM2P1, colWarpOffset); } else if (mmaLayout.isVolta()) { @@ -193,13 +218,27 @@ struct ConvertLayoutOpConversion assert(rank == 2); SmallVector multiDimOffset(rank); - if (mmaLayout.isAmpere()) { + if (mmaLayout.isHopper()) { + unsigned elemIdRem4 = elemId % 4; + unsigned nGrpId = elemId / 4; + multiDimOffset[0] = elemIdRem4 < 2 ? mmaRowIdx[0] : mmaRowIdx[1]; + multiDimOffset[1] = elemIdRem4 % 2 == 0 ? mmaColIdx[0] : mmaColIdx[1]; + multiDimOffset[1] = add(multiDimOffset[1], i32_val(8 * nGrpId)); + multiDimOffset[0] = + add(multiDimOffset[0], + i32_val(multiDimCTAInRepId[0] * shapePerCTATile[0])); + multiDimOffset[1] = + add(multiDimOffset[1], + i32_val(multiDimCTAInRepId[1] * shapePerCTATile[1])); + } else if (mmaLayout.isAmpere()) { multiDimOffset[0] = elemId < 2 ? mmaRowIdx[0] : mmaRowIdx[1]; multiDimOffset[1] = elemId % 2 == 0 ? mmaColIdx[0] : mmaColIdx[1]; - multiDimOffset[0] = add( - multiDimOffset[0], i32_val(multiDimCTAInRepId[0] * shapePerCTA[0])); - multiDimOffset[1] = add( - multiDimOffset[1], i32_val(multiDimCTAInRepId[1] * shapePerCTA[1])); + multiDimOffset[0] = + add(multiDimOffset[0], + i32_val(multiDimCTAInRepId[0] * shapePerCTATile[0])); + multiDimOffset[1] = + add(multiDimOffset[1], + i32_val(multiDimCTAInRepId[1] * shapePerCTATile[1])); } else if (mmaLayout.isVolta()) { auto [isARow, isBRow, isAVec4, isBVec4, _] = mmaLayout.decodeVoltaLayoutStates(); @@ -214,11 +253,12 @@ struct ConvertLayoutOpConversion } #ifdef USE_ROCM if (auto mfmaLayout = layout.dyn_cast()) { - auto multiDimBase = emitBaseIndexForLayout(loc, rewriter, layout, type); + auto multiDimBase = emitBaseIndexForLayout(loc, rewriter, layout, type, false); SmallVector> offsets; assert(rank == 2); SmallVector multiDimOffset(rank); - emitMfmaOffsetForCTA(mfmaLayout, offsets, multiDimCTAInRepId[0], multiDimCTAInRepId[1]); + emitMfmaOffsetForCTA(mfmaLayout, offsets, multiDimCTAInRepId[0], + multiDimCTAInRepId[1]); multiDimOffset[0] = add(multiDimBase[0], i32_val(offsets[elemId][0])); multiDimOffset[1] = add(multiDimBase[1], i32_val(offsets[elemId][1])); return multiDimOffset; @@ -240,11 +280,12 @@ struct ConvertLayoutOpConversion auto rank = type.getRank(); auto sizePerThread = getSizePerThread(layout); auto accumSizePerThread = product(sizePerThread); - SmallVector numCTAs(rank); + SmallVector numCTATiles(rank); + auto shapePerCTATile = getShapePerCTATile(layout); auto shapePerCTA = getShapePerCTA(layout, type.getShape()); auto order = getOrder(layout); for (unsigned d = 0; d < rank; ++d) { - numCTAs[d] = ceil(type.getShape()[d], shapePerCTA[d]); + numCTATiles[d] = ceil(shapePerCTA[d], shapePerCTATile[d]); } auto elemTy = type.getElementType(); bool isInt1 = elemTy.isInteger(1); @@ -267,17 +308,16 @@ struct ConvertLayoutOpConversion } auto linearCTAId = - getLinearIndex(multiDimCTAId, numCTAs, order); + getLinearIndex(multiDimCTAId, numCTATiles, order); // TODO: This is actually redundant index calculation, we should // consider of caching the index calculation result in case // of performance issue observed. for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) { SmallVector multiDimOffset = getMultiDimOffset(layout, loc, rewriter, elemId, type, - multiDimCTAInRepId, shapePerCTA); + multiDimCTAInRepId, shapePerCTATile); Value offset = linearize(rewriter, loc, multiDimOffset, paddedRepShape, outOrd); - auto elemPtrTy = ptr_ty(llvmElemTy, 3); Value ptr = gep(elemPtrTy, smemBase, offset); auto vecTy = vec_ty(llvmElemTy, vec); @@ -334,7 +374,8 @@ struct ConvertLayoutOpConversion SmallVector numCTAs(rank, 1); SmallVector numCTAsEachRep(rank, 1); - SmallVector shapePerCTA = getShapePerCTA(layout, shape); + SmallVector shapePerCTATile = getShapePerCTATile(layout, shape); + SmallVector shapePerCTA = getShapePerCTA(layout, shape); auto elemTy = type.getElementType(); int ctaId = 0; @@ -364,7 +405,7 @@ struct ConvertLayoutOpConversion // duplicate in Volta. SmallVector multiDimOffset = getMultiDimOffset(layout, loc, rewriter, elemId, type, - multiDimCTAInRepId, shapePerCTA); + multiDimCTAInRepId, shapePerCTATile); coord2val[elemId] = std::make_pair(multiDimOffset, vals[elemId]); } @@ -372,7 +413,7 @@ struct ConvertLayoutOpConversion // do transpose auto aEncoding = DotOperandEncodingAttr::get(mma.getContext(), 0, mma, 0); - int numM = aEncoding.getMMAv1NumOuter(shape); + int numM = aEncoding.getMMAv1NumOuter(shapePerCTA); int numN = accumSizePerThread / numM; for (int r = 0; r < numM; r++) { @@ -411,6 +452,91 @@ struct ConvertLayoutOpConversion } } + LogicalResult + lowerDistToDistWithDistSmem(triton::gpu::ConvertLayoutOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + + Value src = op.getSrc(); + Value dst = op.getResult(); + auto srcTy = src.getType().cast(); + auto dstTy = dst.getType().cast(); + auto srcLayout = srcTy.getEncoding(); + auto dstLayout = dstTy.getEncoding(); + auto srcShapePerCTA = getShapePerCTA(srcTy); + auto srcCTAsPerCGA = triton::gpu::getCTAsPerCGA(srcLayout); + auto srcCTAOrder = triton::gpu::getCTAOrder(srcLayout); + unsigned rank = srcShapePerCTA.size(); + + auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType()); + auto elemPtrTy = ptr_ty(llvmElemTy, 3); + + Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation()); + smemBase = bitcast(smemBase, elemPtrTy); + auto smemShape = convertType(srcShapePerCTA); + + // Store to local shared memory + { + auto inVals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(), + rewriter, srcTy); + auto inIndices = + emitIndices(loc, rewriter, srcLayout, srcTy, /*withCTAOffset*/ false); + + assert(inIndices.size() == inVals.size() && + "Unexpected number of indices emitted"); + + for (unsigned i = 0; i < inIndices.size(); ++i) { + Value offset = linearize(rewriter, loc, inIndices[i], smemShape); + Value ptr = gep(elemPtrTy, smemBase, offset); + store(inVals[i], ptr); + } + } + + // Cluster barrier + rewriter.create(loc, false); + rewriter.create(loc); + + // Load from remote shared memory + { + SmallVector srcShapePerCTACache; + for (unsigned i = 0; i < rank; ++i) + srcShapePerCTACache.push_back(i32_val(srcShapePerCTA[i])); + + SmallVector outVals; + auto outIndices = + emitIndices(loc, rewriter, dstLayout, dstTy, /*withCTAOffset*/ true); + + for (unsigned i = 0; i < outIndices.size(); ++i) { + auto coord = outIndices[i]; + assert(coord.size() == rank && "Unexpected rank of index emitted"); + + SmallVector multiDimCTAId, localCoord; + for (unsigned d = 0; d < rank; ++d) { + multiDimCTAId.push_back(udiv(coord[d], srcShapePerCTACache[d])); + localCoord.push_back(urem(coord[d], srcShapePerCTACache[d])); + } + + Value remoteCTAId = + linearize(rewriter, loc, multiDimCTAId, srcCTAsPerCGA, srcCTAOrder); + Value localOffset = linearize(rewriter, loc, localCoord, smemShape); + + Value ptr = gep(elemPtrTy, smemBase, localOffset); + outVals.push_back(load_dsmem(ptr, remoteCTAId)); + } + + Value result = + getTypeConverter()->packLLElements(loc, outVals, rewriter, dstTy); + rewriter.replaceOp(op, result); + } + + // Cluster barrier + rewriter.create(loc, false); + rewriter.create(loc); + + return success(); + } + // blocked/mma -> blocked/mma. // Data padding in shared memory to avoid bank conflict. LogicalResult @@ -424,6 +550,10 @@ struct ConvertLayoutOpConversion auto dstTy = dst.getType().cast(); Attribute srcLayout = srcTy.getEncoding(); Attribute dstLayout = dstTy.getEncoding(); + + if (shouldUseDistSmem(srcLayout, dstLayout)) + return lowerDistToDistWithDistSmem(op, adaptor, rewriter); + auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType()); Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation()); auto elemPtrTy = ptr_ty(llvmElemTy, 3); @@ -435,8 +565,9 @@ struct ConvertLayoutOpConversion SmallVector outNumCTAsEachRep(rank); SmallVector inNumCTAs(rank); SmallVector outNumCTAs(rank); - auto srcShapePerCTA = getShapePerCTA(srcLayout, srcTy.getShape()); - auto dstShapePerCTA = getShapePerCTA(dstLayout, shape); + auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape()); + auto dstShapePerCTATile = getShapePerCTATile(dstLayout, shape); + auto shapePerCTA = getShapePerCTA(srcLayout, shape); // For Volta, all the coords for a CTA are calculated. bool isSrcMmaV1{}, isDstMmaV1{}; @@ -456,15 +587,17 @@ struct ConvertLayoutOpConversion } for (unsigned d = 0; d < rank; ++d) { - unsigned inPerCTA = std::min(shape[d], srcShapePerCTA[d]); - unsigned outPerCTA = std::min(shape[d], dstShapePerCTA[d]); + unsigned inPerCTA = + std::min(shapePerCTA[d], srcShapePerCTATile[d]); + unsigned outPerCTA = + std::min(shapePerCTA[d], dstShapePerCTATile[d]); unsigned maxPerCTA = std::max(inPerCTA, outPerCTA); - numReplicates[d] = ceil(shape[d], maxPerCTA); + numReplicates[d] = ceil(shapePerCTA[d], maxPerCTA); inNumCTAsEachRep[d] = maxPerCTA / inPerCTA; outNumCTAsEachRep[d] = maxPerCTA / outPerCTA; assert(maxPerCTA % inPerCTA == 0 && maxPerCTA % outPerCTA == 0); - inNumCTAs[d] = ceil(shape[d], inPerCTA); - outNumCTAs[d] = ceil(shape[d], outPerCTA); + inNumCTAs[d] = ceil(shapePerCTA[d], inPerCTA); + outNumCTAs[d] = ceil(shapePerCTA[d], outPerCTA); } // Potentially we need to store for multiple CTAs in this replication auto accumNumReplicates = product(numReplicates); @@ -473,7 +606,8 @@ struct ConvertLayoutOpConversion unsigned inVec = 0; unsigned outVec = 0; auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec); - if (getElementTypeOrSelf(op.getType()).isa()) { + if (getElementTypeOrSelf(op.getType()) + .isa()) { assert(inVec % 4 == 0 && "conversion not supported for FP8E4M3B15"); assert(outVec % 4 == 0 && "conversion not supported for FP8E4M3B15"); } @@ -485,8 +619,25 @@ struct ConvertLayoutOpConversion for (unsigned repId = 0; repId < accumNumReplicates; ++repId) { auto multiDimRepId = getMultiDimIndex(repId, numReplicates, outOrd); - if (repId != 0) - barrier(); + if (repId != 0) { + // TODO[shuhaoj]: change hard code style of numThreads. Hide async + // attr. Better way to determine barId (number of agents are limited). + if (auto optionalAgentId = getWSAgentId(op)) { + int agentId = *optionalAgentId, roleId = 0; + if (auto optionalRoleId = getWSRoleId(op)) + roleId = *optionalRoleId; + int barId = agentId + roleId + nameBarrierIdBegin; + assert(barId < nameBarrierIdEnd); + auto bar = rewriter.create( + loc, i32_ty, rewriter.getI32IntegerAttr(barId)); + auto kNumThreads = rewriter.create( + loc, i32_ty, rewriter.getI32IntegerAttr(128)); + rewriter.create(loc, bar, + kNumThreads); + } else { + barrier(); + } + } if (srcLayout.isa() || srcLayout.isa() || #ifdef USE_ROCM @@ -506,7 +657,23 @@ struct ConvertLayoutOpConversion return failure(); } - barrier(); + // TODO[shuhaoj]: change hard code style of numThreads. Hide async_agent + // attr. Better way to determine barId (number of agents are limited). + if (auto optionalAgentId = getWSAgentId(op)) { + int agentId = *optionalAgentId, roleId = 0; + if (auto optionalRoleId = getWSRoleId(op)) + roleId = *optionalRoleId; + int barId = agentId + roleId + nameBarrierIdBegin; + assert(barId < nameBarrierIdEnd); + auto bar = rewriter.create( + loc, i32_ty, rewriter.getI32IntegerAttr(barId)); + auto kNumThreads = rewriter.create( + loc, i32_ty, rewriter.getI32IntegerAttr(128)); + rewriter.create(loc, bar, + kNumThreads); + } else { + barrier(); + } if (dstLayout.isa() || dstLayout.isa() || #ifdef USE_ROCM @@ -580,7 +747,7 @@ struct ConvertLayoutOpConversion auto srcTy = src.getType().cast(); auto srcShape = srcTy.getShape(); auto dstTy = dst.getType().cast(); - auto dstShape = dstTy.getShape(); + auto dstShapePerCTA = triton::gpu::getShapePerCTA(dstTy); assert(srcShape.size() == 2 && "Unexpected rank of ConvertLayout(blocked->shared)"); auto srcLayout = srcTy.getEncoding(); @@ -592,13 +759,102 @@ struct ConvertLayoutOpConversion auto elemPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3); smemBase = bitcast(smemBase, elemPtrTy); - auto dstStrides = - getStridesFromShapeAndOrder(dstShape, outOrd, loc, rewriter); - auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy); - storeDistributedToShared(src, adaptor.getSrc(), dstStrides, srcIndices, dst, - smemBase, elemTy, loc, rewriter); + int32_t elemSize = elemTy.getIntOrFloatBitWidth(); + auto mmaLayout = srcLayout.dyn_cast(); + unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy); + if (mmaLayout && mmaLayout.isHopper() && elemSize == 16 && + inOrd == outOrd && numElems >= 16) { + auto inVals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(), + rewriter, srcTy); + + auto srcShapePerCTA = getShapePerCTA(mmaLayout, srcShape); + auto instrShape = mmaLayout.getInstrShape(); + auto warpsPerCTA = mmaLayout.getWarpsPerCTA(); + uint32_t repM = + ceil(srcShapePerCTA[0], instrShape[0] * warpsPerCTA[0]); + uint32_t numElemsPerRep = numElems / repM; + // rowStride in bytes + uint32_t rowStrideInBytes = dstShapePerCTA[outOrd[0]] * 2; + uint32_t swizzlingByteWidth = rowStrideInBytes; + if (swizzlingByteWidth > 128) + swizzlingByteWidth = 128; + + unsigned numElemsPerSwizzlingRow = swizzlingByteWidth * 8 / elemSize; + unsigned leadingDimOffset = + numElemsPerSwizzlingRow * srcShapePerCTA[outOrd[1]]; + + auto ptrI8SharedTy = LLVM::LLVMPointerType::get( + typeConverter->convertType(rewriter.getI8Type()), 3); + + uint32_t rowsPerRep = getShapePerCTATile(mmaLayout)[0]; + + Value threadId = getThreadId(rewriter, loc); + Value warpId = udiv(threadId, i32_val(32)); + Value warpId0 = urem(urem(warpId, i32_val(warpsPerCTA[0])), + i32_val(srcShape[0] / instrShape[0])); + + unsigned inVec = + inOrd == outOrd ? triton::gpu::getContigPerThread(mmaLayout)[inOrd[0]] + : 1; + unsigned outVec = dstSharedLayout.getVec(); + unsigned minVec = std::min(outVec, inVec); + assert(minVec == 2); + auto wordTy = vec_ty(elemTy, minVec); + + for (int rep = 0; rep < repM; ++rep) { + Value rowOfWarp = add(mul(warpId0, i32_val(instrShape[0])), + i32_val(rep * rowsPerRep)); + uint32_t elemIdxOffset = rep * numElemsPerRep; + + for (unsigned idx = 0; idx < numElemsPerRep; idx += 8) { + uint32_t elemIdx = elemIdxOffset + idx; + + Value offset = rewriter.create( + loc, i32_ty, threadId, rowOfWarp, i32_val(idx), leadingDimOffset, + numElemsPerSwizzlingRow, true); + + Value addr = gep(elemPtrTy, smemBase, offset); + + Value words[4]; + for (unsigned i = 0; i < 8; ++i) { + if (i % minVec == 0) + words[i / 2] = undef(wordTy); + words[i / 2] = insert_element( + wordTy, words[i / 2], inVals[elemIdx + i], i32_val(i % minVec)); + } + + rewriter.create( + loc, bitcast(addr, ptrI8SharedTy), + ValueRange{bitcast(words[0], i32_ty), bitcast(words[1], i32_ty), + bitcast(words[2], i32_ty), bitcast(words[3], i32_ty)}); + } + } + // TODO[shuhaoj]: change hard code style of numThreads. Hide async_agent + // attr. Better way to determine barId (number of agents are limited). + if (auto optionalAgentId = getWSAgentId(op)) { + int agentId = *optionalAgentId, roleId = 0; + if (auto optionalRoleId = getWSRoleId(op)) + roleId = *optionalRoleId; + int barId = agentId + roleId + nameBarrierIdBegin; + assert(barId < nameBarrierIdEnd); + auto bar = rewriter.create( + loc, i32_ty, rewriter.getI32IntegerAttr(barId)); + auto kNumThreads = rewriter.create( + loc, i32_ty, rewriter.getI32IntegerAttr(128)); + rewriter.create(loc, bar, + kNumThreads); + } else { + barrier(); + } + } else { + auto dstStrides = + getStridesFromShapeAndOrder(dstShapePerCTA, outOrd, loc, rewriter); + auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy, false); + storeDistributedToShared(src, adaptor.getSrc(), dstStrides, srcIndices, + dst, smemBase, elemTy, loc, rewriter); + } auto smemObj = - SharedMemoryObject(smemBase, dstShape, outOrd, loc, rewriter); + SharedMemoryObject(smemBase, dstShapePerCTA, outOrd, loc, rewriter); auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); rewriter.replaceOp(op, retVal); return success(); @@ -807,19 +1063,16 @@ struct ConvertLayoutOpConversion auto loc = op.getLoc(); Value src = op.getSrc(); Value dst = op.getResult(); + bool isMMA = supportMMA(dst, mmaLayout.getVersionMajor()); auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), rewriter); Value res; - if (!isOuter && mmaLayout.isAmpere()) { // tensor core v2 - res = SharedToDotOperandMMAv2::convertLayout( dotOperandLayout.getOpIdx(), rewriter, loc, src, dotOperandLayout, - smemObj, getTypeConverter(), tid_val()); - - } else if (!isOuter && mmaLayout.isVolta() && - supportMMA(dst, mmaLayout.getVersionMajor())) { // tensor core v1 + smemObj, getTypeConverter(), getThreadId(rewriter, loc)); + } else if (!isOuter && mmaLayout.isVolta() && isMMA) { // tensor core v1 bool isMMAv1Row = dotOperandLayout.getMMAv1IsRow(); auto srcSharedLayout = src.getType() .cast() @@ -841,10 +1094,11 @@ struct ConvertLayoutOpConversion } return res; } -}; // namespace triton::gpu::ConvertLayoutOp +}; // namespace triton::gpu::ConvertLayoutOp> void populateConvertLayoutOpToLLVMPatterns( TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation, ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, PatternBenefit benefit) { diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.h b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.h index c8f3396b98cf..9cbf49573deb 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.h +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.h @@ -10,6 +10,7 @@ using ::mlir::triton::gpu::DotOperandEncodingAttr; void populateConvertLayoutOpToLLVMPatterns( TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation, ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, PatternBenefit benefit); diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp index 1507dd543aef..8ca6dc957c74 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp @@ -2,8 +2,10 @@ #include "../Utility.h" using ValueTable = std::map, Value>; +using ::mlir::LLVM::delinearize; using ::mlir::LLVM::getSharedMemoryObjectFromStruct; using ::mlir::LLVM::getStridesFromShapeAndOrder; +using ::mlir::LLVM::linearize; using ::mlir::triton::gpu::DotOperandEncodingAttr; using ::mlir::triton::gpu::getContigPerThread; using ::mlir::triton::gpu::getOrder; @@ -14,31 +16,32 @@ using ::mlir::triton::gpu::isaDistributedLayout; using ::mlir::triton::gpu::SharedEncodingAttr; SmallVector -getThreadIds(Value threadId, ArrayRef shapePerCTA, +getThreadIds(Value threadId, ArrayRef shapePerCTATile, ArrayRef sizePerThread, ArrayRef order, ConversionPatternRewriter &rewriter, Location loc) { int dim = order.size(); SmallVector threadIds(dim); for (unsigned k = 0; k < dim - 1; k++) { - Value dimK = i32_val(shapePerCTA[order[k]] / sizePerThread[order[k]]); + Value dimK = i32_val(shapePerCTATile[order[k]] / sizePerThread[order[k]]); Value rem = urem(threadId, dimK); threadId = udiv(threadId, dimK); threadIds[order[k]] = rem; } - Value dimK = i32_val(shapePerCTA[order[dim - 1]]); + Value dimK = i32_val(shapePerCTATile[order[dim - 1]]); threadIds[order[dim - 1]] = urem(threadId, dimK); return threadIds; } -int getShapePerCTAForMN(BlockedEncodingAttr layout, bool isM) { +// Get shapePerCTATile for M or N axis. +int getShapePerCTATileForMN(BlockedEncodingAttr layout, bool isM) { auto order = layout.getOrder(); - auto shapePerCTA = getShapePerCTA(layout); + auto shapePerCTATile = getShapePerCTATile(layout); - int mShapePerCTA = - order[0] == 1 ? shapePerCTA[order[1]] : shapePerCTA[order[0]]; - int nShapePerCTA = - order[0] == 0 ? shapePerCTA[order[1]] : shapePerCTA[order[0]]; - return isM ? mShapePerCTA : nShapePerCTA; + int mShapePerCTATile = + order[0] == 1 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; + int nShapePerCTATile = + order[0] == 0 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; + return isM ? mShapePerCTATile : nShapePerCTATile; } // Get sizePerThread for M or N axis. @@ -91,7 +94,7 @@ Value loadAFMA(Value A, Value llA, BlockedEncodingAttr dLayout, Value thread, ConversionPatternRewriter &rewriter) { auto aTensorTy = A.getType().cast(); auto aLayout = aTensorTy.getEncoding().cast(); - auto aShape = aTensorTy.getShape(); + auto aShapePerCTA = getShapePerCTA(aTensorTy); auto aOrder = aLayout.getOrder(); auto order = dLayout.getOrder(); @@ -104,10 +107,10 @@ Value loadAFMA(Value A, Value llA, BlockedEncodingAttr dLayout, Value thread, Value strideA0 = isARow ? strideAK : strideAM; Value strideA1 = isARow ? strideAM : strideAK; int aNumPtr = 8; - int K = aShape[1]; - int M = aShape[0]; + int K = aShapePerCTA[1]; + int M = aShapePerCTA[0]; - auto shapePerCTA = getShapePerCTA(dLayout); + auto shapePerCTATile = getShapePerCTATile(dLayout); auto sizePerThread = getSizePerThread(dLayout); Value _0 = i32_val(0); @@ -115,8 +118,8 @@ Value loadAFMA(Value A, Value llA, BlockedEncodingAttr dLayout, Value thread, Value mContig = i32_val(sizePerThread[order[1]]); // threadId in blocked layout - auto threadIds = - getThreadIds(thread, shapePerCTA, sizePerThread, order, rewriter, loc); + auto threadIds = getThreadIds(thread, shapePerCTATile, sizePerThread, order, + rewriter, loc); Value threadIdM = threadIds[0]; Value offA0 = isARow ? _0 : mul(threadIdM, mContig); @@ -135,11 +138,11 @@ Value loadAFMA(Value A, Value llA, BlockedEncodingAttr dLayout, Value thread, SmallVector vas; - int mShapePerCTA = getShapePerCTAForMN(dLayout, true /*isM*/); + int mShapePerCTATile = getShapePerCTATileForMN(dLayout, true /*isM*/); int mSizePerThread = getSizePerThreadForMN(dLayout, true /*isM*/); for (unsigned k = 0; k < K; ++k) - for (unsigned m = 0; m < M; m += mShapePerCTA) + for (unsigned m = 0; m < M; m += mShapePerCTATile) for (unsigned mm = 0; mm < mSizePerThread; ++mm) { Value offset = add(mul(i32_val(m + mm), strideAM), mul(i32_val(k), strideAK)); @@ -156,7 +159,7 @@ Value loadBFMA(Value B, Value llB, BlockedEncodingAttr dLayout, Value thread, ConversionPatternRewriter &rewriter) { auto bTensorTy = B.getType().cast(); auto bLayout = bTensorTy.getEncoding().cast(); - auto bShape = bTensorTy.getShape(); + auto bShapePerCTA = getShapePerCTA(bTensorTy); auto bOrder = bLayout.getOrder(); auto order = dLayout.getOrder(); @@ -169,10 +172,10 @@ Value loadBFMA(Value B, Value llB, BlockedEncodingAttr dLayout, Value thread, Value strideB0 = isBRow ? strideBN : strideBK; Value strideB1 = isBRow ? strideBK : strideBN; int bNumPtr = 8; - int K = bShape[0]; - int N = bShape[1]; + int K = bShapePerCTA[0]; + int N = bShapePerCTA[1]; - auto shapePerCTA = getShapePerCTA(dLayout); + auto shapePerCTATile = getShapePerCTATile(dLayout); auto sizePerThread = getSizePerThread(dLayout); Value _0 = i32_val(0); @@ -180,8 +183,8 @@ Value loadBFMA(Value B, Value llB, BlockedEncodingAttr dLayout, Value thread, Value nContig = i32_val(sizePerThread[order[0]]); // threadId in blocked layout - auto threadIds = - getThreadIds(thread, shapePerCTA, sizePerThread, order, rewriter, loc); + auto threadIds = getThreadIds(thread, shapePerCTATile, sizePerThread, order, + rewriter, loc); Value threadIdN = threadIds[1]; Value offB0 = isBRow ? mul(threadIdN, nContig) : _0; @@ -200,11 +203,11 @@ Value loadBFMA(Value B, Value llB, BlockedEncodingAttr dLayout, Value thread, SmallVector vbs; - int nShapePerCTA = getShapePerCTAForMN(dLayout, false /*isM*/); + int nShapePerCTATile = getShapePerCTATileForMN(dLayout, false /*isM*/); int nSizePerThread = getSizePerThreadForMN(dLayout, false /*isM*/); for (unsigned k = 0; k < K; ++k) - for (unsigned n = 0; n < N; n += nShapePerCTA) + for (unsigned n = 0; n < N; n += nShapePerCTATile) for (unsigned nn = 0; nn < nSizePerThread; ++nn) { Value offset = add(mul(i32_val(n + nn), strideBN), mul(i32_val(k), strideBK)); diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp index 01fa4c238642..1dad5ab58a10 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp @@ -15,6 +15,7 @@ Type getShemPtrTy(Type elemTy) { auto ctx = elemTy.getContext(); return ptr_ty(type::i16Ty(ctx), 3); } + return ptr_ty(elemTy, 3); } @@ -106,6 +107,7 @@ swizzleIndexes(ConversionPatternRewriter &rewriter, Location loc, Value row, * @param reps number of instructions repretition to fully cover dot operand * @param smemStrides strides in LDS tensor * @param loadVecSize number of elements loaded by one operation + * @param iNonKDim non-K dimension of dot operand * @return vector (i-th element corresponds to i-th load instruction) of * 2-element vectors(tensor row and col). */ @@ -114,7 +116,7 @@ computeTensorElemMapping(ConversionPatternRewriter &rewriter, Location loc, const ArrayRef &elemsPerInstr, Value waveId, Value laneId, int warpsPerGroup, int numOfElems, ArrayRef reps, ArrayRef smemOffsets, - int loadVecSize) { + int loadVecSize, unsigned iNonKDim) { auto numM = reps[0]; auto numK = reps[1]; const int loadsPerThread = numOfElems / loadVecSize; @@ -123,6 +125,7 @@ computeTensorElemMapping(ConversionPatternRewriter &rewriter, Location loc, Value _0 = i32_val(0); Value _32 = i32_val(32); + Value nonKDim = i32_val(iNonKDim); for (int block = 0; block < numM; ++block) { Value blockVOffset = i32_val(block * elemsPerInstr[0] * warpsPerGroup); @@ -133,8 +136,13 @@ computeTensorElemMapping(ConversionPatternRewriter &rewriter, Location loc, Value tileVOffset = _0; Value tileHOffset = i32_val(tile * elemsPerInstr[1]); - Value laneVOffset = urem(laneId, _32); - Value laneHOffset = select(icmp_uge(laneId, _32), i32_val(numOfElems), _0); + Value laneVOffset = urem(laneId, nonKDim); + Value laneHOffset; + if (iNonKDim == 32) + laneHOffset = select(icmp_uge(laneId, _32), i32_val(numOfElems), _0); + else + laneHOffset = mul(udiv(laneId, nonKDim), i32_val(numOfElems)); + for (int loadId = 0; loadId < loadsPerThread; ++loadId) { Value elemVOffset = _0; Value elemHOffset = i32_val(loadId * loadVecSize); @@ -175,7 +183,7 @@ computeOffsetsAType(ConversionPatternRewriter &rewriter, Location loc, const ArrayRef &elemsPerInstr, Value waveId, Value laneId, int warpsPerGroup, int numOfElems, ArrayRef reps, SharedMemoryObject smemObj, - SharedEncodingAttr srcLayout) { + SharedEncodingAttr srcLayout, unsigned nonKDim) { SmallVector strides{smemObj.strides[0], smemObj.strides[1]}; SmallVector offsets{smemObj.offsets[0], smemObj.offsets[1]}; @@ -189,7 +197,7 @@ computeOffsetsAType(ConversionPatternRewriter &rewriter, Location loc, auto mapping = computeTensorElemMapping(rewriter, loc, elemsPerInstr, waveId, laneId, warpsPerGroup, numOfElems, - reps, offsets, vectorSize); + reps, offsets, vectorSize, nonKDim); llvm::SmallVector aOffsets(mapping.size()); for (int i = 0; i < mapping.size(); ++i) { Value row = mapping[i][0]; @@ -204,7 +212,7 @@ computeOffsetsBType(ConversionPatternRewriter &rewriter, Location loc, const ArrayRef &elemsPerInstr, Value waveId, Value laneId, int warpsPerGroup, int numOfElems, ArrayRef reps, SharedMemoryObject smemObj, - SharedEncodingAttr srcLayout) { + SharedEncodingAttr srcLayout, unsigned nonKDim) { // transpose reps and offsets, because operand B has layout equal to // transposed operand A layout SmallVector tElemsPerInstr{elemsPerInstr[1], elemsPerInstr[0]}; @@ -221,7 +229,7 @@ computeOffsetsBType(ConversionPatternRewriter &rewriter, Location loc, auto mapping = computeTensorElemMapping(rewriter, loc, tElemsPerInstr, waveId, laneId, warpsPerGroup, numOfElems, - tReps, toffsets, vectorSize); + tReps, toffsets, vectorSize, nonKDim); llvm::SmallVector bOffsets(mapping.size()); for (int i = 0; i < mapping.size(); ++i) { // swap row and col, because operand B layout is a transposed operand A @@ -410,7 +418,8 @@ Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value thread, TritonGPUToLLVMTypeConverter *typeConverter, Value tensor, const SharedMemoryObject &smemObj) { auto mfmaLayout = encoding.getParent().cast(); - assert(mfmaLayout.getNonKDim() == 32); + auto nonKDim = mfmaLayout.getNonKDim(); + assert(nonKDim == 32 || nonKDim == 16); auto warpsPerCTA = mfmaLayout.getWarpsPerCTA(); auto aTensorTy = tensor.getType().cast(); @@ -429,19 +438,20 @@ Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value thread, auto numRepK = numReps[1]; unsigned iWaveSize = triton::gpu::getWarpSize(mfmaLayout); + assert(iWaveSize == 64); Value waveSize = i32_val(iWaveSize); Value wave = udiv(thread, waveSize); Value lane = urem(thread, waveSize); Value waveM = getWaveM(rewriter, loc, wave, warpsPerCTA, mfmaInstrM, shape[0]); - int numOfElems = - std::max(mfmaInstrM * mfmaInstrK / iWaveSize /*wave size*/, 1); + int numOfElems = mfmaInstrM * mfmaInstrK / iWaveSize; + assert(numOfElems >= 1); unsigned int maxNumWarps = shape[0] / mfmaInstrM; int warpsPerGroupM = std::min(warpsPerCTA[0], maxNumWarps); + aElemTy = typeConverter->convertType(aElemTy); SmallVector ha; - if (fastPathAvailable(smemObj, sharedLayout, mfmaLayout)) { Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]); SmallVector offsets; @@ -459,7 +469,7 @@ Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value thread, Value smemBase = smemObj.getBaseBeforeSlice(order[0], loc, rewriter); Type smemPtrTy = getShemPtrTy(aElemTy); - Type resElemTy = aElemTy.isBF16() ? i16_ty : aElemTy; + Type resElemTy = typeConverter->convertType(aElemTy); int loadsPerThread = offsets.size() / (numRepM * numRepK); const int elemsPerLoad = numOfElems / loadsPerThread; @@ -497,10 +507,10 @@ Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value thread, } else { // normal path SmallVector offsets = computeOffsetsAType( rewriter, loc, aElemsPerInstr, waveM, lane, warpsPerGroupM, numOfElems, - numReps, smemObj, sharedLayout); + numReps, smemObj, sharedLayout, nonKDim); Value smemBase = computeBasePtr(rewriter, loc, smemObj); - Type resElemTy = aElemTy.isBF16() ? i16_ty : aElemTy; + Type resElemTy = typeConverter->convertType(aElemTy); Type smemPtrTy = getShemPtrTy(aElemTy); @@ -550,7 +560,8 @@ Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value thread, TritonGPUToLLVMTypeConverter *typeConverter, Value tensor, const SharedMemoryObject &smemObj) { auto mfmaLayout = encoding.getParent().cast(); - assert(mfmaLayout.getNonKDim() == 32); + auto nonKDim = mfmaLayout.getNonKDim(); + assert(nonKDim == 32 || nonKDim == 16); auto warpsPerCTA = mfmaLayout.getWarpsPerCTA(); auto bTensorTy = tensor.getType().cast(); @@ -568,26 +579,21 @@ Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value thread, auto numRepN = numReps[1]; unsigned iWaveSize = triton::gpu::getWarpSize(mfmaLayout); + assert(iWaveSize == 64); Value waveSize = i32_val(iWaveSize); Value wave = udiv(thread, waveSize); Value lane = urem(thread, waveSize); Value waveN = getWaveN(rewriter, loc, wave, warpsPerCTA, mfmaInstrN, shape[1]); - int numOfElems = - std::max(mfmaInstrK * mfmaInstrN / iWaveSize /*wave size*/, 1); - - int macroTileM = std::max(shape[0] / (warpsPerCTA[0] * 32), 1); - int wptM = std::min(warpsPerCTA[0], macroTileM); - int macroTileN = std::max(shape[1] / (warpsPerCTA[1] * 32), 1); - int wptN = std::min(warpsPerCTA[1], macroTileN); - int wpt = std::max(wptM, wptN); + int numOfElems = mfmaInstrK * mfmaInstrN / iWaveSize; + assert(numOfElems >= 1); unsigned int maxNumWarps = shape[1] / mfmaInstrN; int warpsPerGroupN = std::min(warpsPerCTA[1], maxNumWarps); + bElemTy = typeConverter->convertType(bElemTy); SmallVector hb; - if (fastPathAvailable(smemObj, sharedLayout, mfmaLayout)) { Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]); @@ -608,8 +614,7 @@ Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value thread, Value smemBase = smemObj.getBaseBeforeSlice(order[0], loc, rewriter); - Type resElemTy = bElemTy.isBF16() ? i16_ty : bElemTy; - + Type resElemTy = typeConverter->convertType(bElemTy); Type smemPtrTy = getShemPtrTy(bElemTy); const int loadsPerThread = offsets.size() / (numRepN * numRepK); @@ -648,12 +653,10 @@ Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value thread, } else { // normal path llvm::SmallVector offsets = computeOffsetsBType( rewriter, loc, bElemsPerInstr, waveN, lane, warpsPerGroupN, numOfElems, - numReps, smemObj, sharedLayout); + numReps, smemObj, sharedLayout, nonKDim); Value smemBase = computeBasePtr(rewriter, loc, smemObj); - - Type resElemTy = bElemTy.isBF16() ? i16_ty : bElemTy; - + Type resElemTy = typeConverter->convertType(bElemTy); Type smemPtrTy = getShemPtrTy(bElemTy); int loadsPerThread = offsets.size() / (numReps[0] * numReps[1]); diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp index 8a9efb0e09a1..89165a7b6eab 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp @@ -203,8 +203,8 @@ static Value loadA(Value tensor, const SharedMemoryObject &smemObj, SmallVector elems; elems.reserve(has.size() * 2); for (auto item : has) { // has is a map, the key should be ordered. - elems.push_back(item.second.first); - elems.push_back(item.second.second); + elems.push_back(bitcast(item.second.first, i32_ty)); + elems.push_back(bitcast(item.second.second, i32_ty)); } Value res = typeConverter->packLLElements(loc, elems, rewriter, resultTy); @@ -327,8 +327,8 @@ static Value loadB(Value tensor, const SharedMemoryObject &smemObj, SmallVector elems; for (auto &item : hbs) { // has is a map, the key should be ordered. - elems.push_back(item.second.first); - elems.push_back(item.second.second); + elems.push_back(bitcast(item.second.first, i32_ty)); + elems.push_back(bitcast(item.second.second, i32_ty)); } Value res = typeConverter->packLLElements(loc, elems, rewriter, resultTy); diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp index 05fd4a5bff16..e939e2fe9bbe 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp @@ -19,7 +19,7 @@ using ::mlir::triton::gpu::SharedEncodingAttr; // Data loader for mma.16816 instruction. class MMA16816SmemLoader { public: - MMA16816SmemLoader(int warpsPerTile, ArrayRef order, + MMA16816SmemLoader(int nPerWarp, int warpsPerTile, ArrayRef order, ArrayRef warpsPerCTA, uint32_t kOrder, int kWidth, ArrayRef smemStrides, ArrayRef tileShape, ArrayRef instrShape, @@ -93,6 +93,8 @@ class MMA16816SmemLoader { int inWarpMatOffset; // Offset in number of matrices to increment on non-k dim across warps int warpMatOffset; + + int nPerWarp; }; SmallVector @@ -131,10 +133,18 @@ MMA16816SmemLoader::computeLdmatrixMatOffs(Value warpId, Value lane, // address (s0,s1) annotates. Value matOff[2]; - matOff[kOrder ^ 1] = add( - mul(warpId, i32_val(warpMatOffset)), // warp offset (kOrder=1) - mul(nkMatArr, - i32_val(inWarpMatOffset))); // matrix offset inside a warp (kOrder=1) + // When B's shape(k, n) is (16, 8) and ldmatrix.x4 is used, the shared memory + // access will be out of bound. In the future we should change this case to + // ldmatrix.x2 + if (kOrder == 0 && nPerWarp == 8) { + matOff[kOrder ^ 1] = mul(warpId, i32_val(warpMatOffset)); + } else { + matOff[kOrder ^ 1] = add( + mul(warpId, i32_val(warpMatOffset)), // warp offset (kOrder=1) + mul(nkMatArr, + i32_val( + inWarpMatOffset))); // matrix offset inside a warp (kOrder=1) + } matOff[kOrder] = kMatArr; // Physical offset (before swizzling) @@ -390,13 +400,13 @@ MMA16816SmemLoader::loadX4(int mat0, int mat1, ArrayRef ptrs, Type matTy, } MMA16816SmemLoader::MMA16816SmemLoader( - int warpsPerTile, ArrayRef order, ArrayRef warpsPerCTA, - uint32_t kOrder, int kWidth, ArrayRef smemStrides, - ArrayRef tileShape, ArrayRef instrShape, - ArrayRef matShape, int perPhase, int maxPhase, int elemBytes, - ConversionPatternRewriter &rewriter, + int nPerWarp, int warpsPerTile, ArrayRef order, + ArrayRef warpsPerCTA, uint32_t kOrder, int kWidth, + ArrayRef smemStrides, ArrayRef tileShape, + ArrayRef instrShape, ArrayRef matShape, int perPhase, + int maxPhase, int elemBytes, ConversionPatternRewriter &rewriter, TritonGPUToLLVMTypeConverter *typeConverter, const Location &loc) - : order(order.begin(), order.end()), + : nPerWarp(nPerWarp), order(order.begin(), order.end()), warpsPerCTA(warpsPerCTA.begin(), warpsPerCTA.end()), kOrder(kOrder), kWidth(kWidth), tileShape(tileShape.begin(), tileShape.end()), instrShape(instrShape.begin(), instrShape.end()), @@ -490,6 +500,7 @@ std::function getLoadMatrixFn( bool isA, TritonGPUToLLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter, Location loc) { auto tensorTy = tensor.getType().cast(); + auto shapePerCTA = getShapePerCTA(tensorTy); Type eltTy = tensorTy.getElementType(); // We assumes that the input operand of Dot should be from shared layout. // TODO(Superjomn) Consider other layouts if needed later. @@ -500,24 +511,19 @@ std::function getLoadMatrixFn( const int elemBytes = tensorTy.getElementTypeBitWidth() / 8; auto order = sharedLayout.getOrder(); - if (tensor.getType() - .cast() - .getElementType() - .isa()) { - bool noTrans = (isA ^ order[0] == 0); - assert(noTrans && "float8e4b15 must have row-col layout"); - } - if (kWidth != (4 / elemBytes)) assert(vecPhase == 1 || vecPhase == 4 * kWidth); + int nPerWarp = + std::max(shapePerCTA[1] / mmaLayout.getWarpsPerCTA()[1], 8); + // (a, b) is the coordinate. auto load = [=, &rewriter, &vals](int a, int b) { - MMA16816SmemLoader loader( - warpsPerTile, sharedLayout.getOrder(), mmaLayout.getWarpsPerCTA(), - kOrder, kWidth, smemObj.strides, tensorTy.getShape() /*tileShape*/, - instrShape, matShape, perPhase, maxPhase, elemBytes, rewriter, - typeConverter, loc); + MMA16816SmemLoader loader(nPerWarp, warpsPerTile, sharedLayout.getOrder(), + mmaLayout.getWarpsPerCTA(), kOrder, kWidth, + smemObj.strides, shapePerCTA /*tileShape*/, + instrShape, matShape, perPhase, maxPhase, + elemBytes, rewriter, typeConverter, loc); // Offset of a slice within the original tensor in shared memory Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]); SmallVector offs = @@ -559,17 +565,15 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc, Value tensor, TritonGPUToLLVMTypeConverter *typeConverter, Value thread, bool isA) { auto tensorTy = tensor.getType().cast(); + auto shapePerCTA = getShapePerCTA(tensorTy); int bitwidth = tensorTy.getElementTypeBitWidth(); auto mmaLayout = encoding.getParent().cast(); - SmallVector shape(tensorTy.getShape().begin(), - tensorTy.getShape().end()); - ValueTable vals; int mmaInstrM = 16, mmaInstrN = 8, mmaInstrK = 4 * 64 / bitwidth; int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / bitwidth; - auto numRep = encoding.getMMAv2Rep(tensorTy.getShape(), bitwidth); + auto numRep = encoding.getMMAv2Rep(shapePerCTA, bitwidth); int kWidth = encoding.getKWidth(); auto warpsPerCTA = mmaLayout.getWarpsPerCTA(); @@ -579,14 +583,14 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc, Value tensor, SmallVector multiDimWarpId = delinearize(rewriter, loc, warp, warpsPerCTA, order); - Value warpM = urem(multiDimWarpId[0], i32_val(shape[0] / 16)); - Value warpN = urem(multiDimWarpId[1], i32_val(shape[1] / 8)); + Value warpM = urem(multiDimWarpId[0], i32_val(shapePerCTA[0] / 16)); + Value warpN = urem(multiDimWarpId[1], i32_val(shapePerCTA[1] / 8)); int warpsPerTile; if (isA) - warpsPerTile = std::min(warpsPerCTA[0], shape[0] / 16); + warpsPerTile = std::min(warpsPerCTA[0], shapePerCTA[0] / 16); else - warpsPerTile = std::min(warpsPerCTA[1], shape[1] / 16); + warpsPerTile = std::min(warpsPerCTA[1], shapePerCTA[1] / 16); std::function loadFn; if (isA) diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp index 43126a3d5038..bcda6ccf5d7c 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp @@ -4,7 +4,9 @@ using namespace mlir; using namespace mlir::triton; +using ::mlir::LLVM::getSharedMemoryObjectFromStruct; using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::getShapePerCTA; using ::mlir::triton::gpu::MmaEncodingAttr; LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor, @@ -15,6 +17,10 @@ LogicalResult convertMMA884(triton::DotOp op, triton::DotOp::Adaptor adaptor, TritonGPUToLLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter); +LogicalResult convertMMA1688(triton::DotOp op, triton::DotOp::Adaptor adaptor, + TritonGPUToLLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter); + LogicalResult convertMMA16816(triton::DotOp op, triton::DotOp::Adaptor adaptor, TritonGPUToLLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter); @@ -24,6 +30,15 @@ LogicalResult convertMFMA(triton::DotOp op, triton::DotOp::Adaptor adaptor, TritonGPUToLLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter); #endif +LogicalResult convertWGMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor, + TritonGPUToLLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, Value thread); + +LogicalResult convertAsyncWGMMA(triton::nvidia_gpu::DotAsyncOp op, + triton::nvidia_gpu::DotAsyncOp::Adaptor adaptor, + TritonGPUToLLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + Value thread); struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern { using ConvertTritonGPUOpToLLVMPattern< @@ -32,14 +47,15 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern { LogicalResult matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); // D = A * B + C Value A = op.getA(); Value D = op.getResult(); // Here we assume the DotOp's operands always comes from shared memory. - auto AShape = A.getType().cast().getShape(); + auto AShapePerCTA = getShapePerCTA(A.getType()); size_t reduceAxis = 1; - unsigned K = AShape[reduceAxis]; + unsigned K = AShapePerCTA[reduceAxis]; bool isOuter = K == 1; MmaEncodingAttr mmaLayout = D.getType() @@ -49,8 +65,14 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern { if (!isOuter && mmaLayout && supportMMA(op, mmaLayout.getVersionMajor())) { if (mmaLayout.isVolta()) return convertMMA884(op, adaptor, getTypeConverter(), rewriter); + if (mmaLayout.isTuring()) + return convertMMA1688(op, adaptor, getTypeConverter(), rewriter); if (mmaLayout.isAmpere()) return convertMMA16816(op, adaptor, getTypeConverter(), rewriter); + if (mmaLayout.isHopper()) + return convertWGMMA(op, adaptor, getTypeConverter(), rewriter, + getThreadId(rewriter, loc)); + llvm::report_fatal_error( "Unsupported MMA kind found when converting DotOp to LLVM."); } @@ -60,7 +82,7 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern { .cast() .getEncoding() .dyn_cast(); - if (!isOuter && mfmaLayout && supportMFMA(op)) { + if (!isOuter && mfmaLayout && supportMFMA(op, mfmaLayout.getNonKDim())) { return convertMFMA(op, adaptor, getTypeConverter(), rewriter); } #endif @@ -76,9 +98,68 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern { } }; +struct DotAsyncOpConversion + : public ConvertTritonGPUOpToLLVMPattern { + using ConvertTritonGPUOpToLLVMPattern< + triton::nvidia_gpu::DotAsyncOp>::ConvertTritonGPUOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::DotAsyncOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + // D = A * B + C + Value A = op.getA(); + Value D = op.getResult(); + + // Here we assume the DotOp's operands always comes from shared memory. + auto AShapePerCTA = getShapePerCTA(A.getType()); + size_t reduceAxis = 1; + unsigned K = AShapePerCTA[reduceAxis]; + bool isOuter = K == 1; + + MmaEncodingAttr mmaLayout = D.getType() + .cast() + .getEncoding() + .dyn_cast(); + if (!isOuter && mmaLayout && + supportMMA(op.getOperand(0), mmaLayout.getVersionMajor())) { + if (mmaLayout.isHopper()) { + return convertAsyncWGMMA(op, adaptor, getTypeConverter(), rewriter, + getThreadId(rewriter, loc)); + } + + llvm::report_fatal_error( + "Unsupported MMA kind found when converting DotAsyncOp to LLVM."); + } + + llvm::report_fatal_error( + "Unsupported DotAsyncOp found when converting TritonGPU to LLVM."); + } +}; + +struct DotWaitOpConversion + : public ConvertTritonGPUOpToLLVMPattern { + using ConvertTritonGPUOpToLLVMPattern< + triton::nvidia_gpu::DotWaitOp>::ConvertTritonGPUOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::DotWaitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto pendings = op.getPendings(); + rewriter.create(op.getLoc(), pendings); + + // Safe to remove the op since it doesn't have any return value. + rewriter.eraseOp(op); + return success(); + } +}; + void populateDotOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter, - RewritePatternSet &patterns, + RewritePatternSet &patterns, int numWarps, + ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation, PatternBenefit benefit) { patterns.add(typeConverter, allocation, benefit); + patterns.add(typeConverter, allocation, benefit); + patterns.add(typeConverter, allocation, benefit); } diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.h b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.h index 92ad51f38ab1..6d457cffec19 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.h +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.h @@ -7,7 +7,8 @@ using namespace mlir; using namespace mlir::triton; void populateDotOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter, - RewritePatternSet &patterns, + RewritePatternSet &patterns, int numWarps, + ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation, PatternBenefit benefit); diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp index a7d75e677441..775c2e4a1ec0 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp @@ -5,19 +5,20 @@ using namespace mlir; using namespace mlir::triton; using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::getShapePerCTA; using ::mlir::triton::gpu::MmaEncodingAttr; using ValueTableFMA = std::map, Value>; static ValueTableFMA getValueTableFromStructFMA( - Value val, int K, int n0, int shapePerCTA, int sizePerThread, + Value val, int K, int n0, int shapePerCTATile, int sizePerThread, ConversionPatternRewriter &rewriter, Location loc, TritonGPUToLLVMTypeConverter *typeConverter, Type type) { ValueTableFMA res; auto elems = typeConverter->unpackLLElements(loc, val, rewriter, type); int index = 0; for (unsigned k = 0; k < K; ++k) { - for (unsigned m = 0; m < n0; m += shapePerCTA) + for (unsigned m = 0; m < n0; m += shapePerCTATile) for (unsigned mm = 0; mm < sizePerThread; ++mm) { res[{m + mm, k}] = elems[index++]; } @@ -40,8 +41,8 @@ LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor, auto bTensorTy = B.getType().cast(); auto dTensorTy = D.getType().cast(); - auto aShape = aTensorTy.getShape(); - auto bShape = bTensorTy.getShape(); + auto aShapePerCTA = getShapePerCTA(aTensorTy); + auto bShapePerCTA = getShapePerCTA(bTensorTy); BlockedEncodingAttr dLayout = dTensorTy.getEncoding().cast(); @@ -53,41 +54,42 @@ LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor, Value llB = adaptor.getB(); auto sizePerThread = getSizePerThread(dLayout); - auto shapePerCTA = getShapePerCTA(dLayout); + auto shapePerCTATile = getShapePerCTATile(dLayout); - int K = aShape[1]; - int M = aShape[0]; - int N = bShape[1]; + int K = aShapePerCTA[1]; + int M = aShapePerCTA[0]; + int N = bShapePerCTA[1]; - int mShapePerCTA = - order[0] == 1 ? shapePerCTA[order[1]] : shapePerCTA[order[0]]; + int mShapePerCTATile = + order[0] == 1 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; int mSizePerThread = order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]]; - int nShapePerCTA = - order[0] == 0 ? shapePerCTA[order[1]] : shapePerCTA[order[0]]; + int nShapePerCTATile = + order[0] == 0 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; int nSizePerThread = order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]]; auto has = - getValueTableFromStructFMA(llA, K, M, mShapePerCTA, mSizePerThread, + getValueTableFromStructFMA(llA, K, M, mShapePerCTATile, mSizePerThread, rewriter, loc, typeConverter, aTensorTy); auto hbs = - getValueTableFromStructFMA(llB, K, N, nShapePerCTA, nSizePerThread, + getValueTableFromStructFMA(llB, K, N, nShapePerCTATile, nSizePerThread, rewriter, loc, typeConverter, bTensorTy); SmallVector ret = cc; bool isCRow = order[0] == 1; for (unsigned k = 0; k < K; k++) { - for (unsigned m = 0; m < M; m += mShapePerCTA) - for (unsigned n = 0; n < N; n += nShapePerCTA) + for (unsigned m = 0; m < M; m += mShapePerCTATile) + for (unsigned n = 0; n < N; n += nShapePerCTATile) for (unsigned mm = 0; mm < mSizePerThread; ++mm) for (unsigned nn = 0; nn < nSizePerThread; ++nn) { - int mIdx = m / mShapePerCTA * mSizePerThread + mm; - int nIdx = n / nShapePerCTA * nSizePerThread + nn; + int mIdx = m / mShapePerCTATile * mSizePerThread + mm; + int nIdx = n / nShapePerCTATile * nSizePerThread + nn; - int z = isCRow ? mIdx * N / nShapePerCTA * mSizePerThread + nIdx - : nIdx * M / mShapePerCTA * nSizePerThread + mIdx; + int z = isCRow + ? mIdx * N / nShapePerCTATile * mSizePerThread + nIdx + : nIdx * M / mShapePerCTATile * nSizePerThread + mIdx; ret[z] = rewriter.create(loc, has[{m + mm, k}], hbs[{n + nn, k}], ret[z]); } diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MFMA.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MFMA.cpp index 047e1c6c1392..c0edc4080c9b 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MFMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MFMA.cpp @@ -25,6 +25,11 @@ enum class MatrixCoreType : uint8_t { NOT_APPLICABLE, }; +struct MFMAInstrDescr { + MatrixCoreType coreType; + unsigned size; +}; + using ValueTable = std::map, Value>; struct DotOpMFMAConversionHelper { @@ -48,32 +53,65 @@ struct DotOpMFMAConversionHelper { return rewriter.create(loc, i32_ty, tid); } - Value generateMFMAOp(MatrixCoreType mfmaTy, Value valA, Value valB, + Value generateMFMAOp(MFMAInstrDescr mfmaDescr, Value valA, Value valB, Value valC) const { auto resType = valC.getType(); Value zeroFlag = i32_val(0); - switch (mfmaTy) { + switch (mfmaDescr.coreType) { case MatrixCoreType::FP32_FP16_FP16_FP32: - return rewriter.create( - loc, TypeRange{resType}, - ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag}); + if (mfmaDescr.size == 16) { + return rewriter.create( + loc, TypeRange{resType}, + ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag}); + } else { + return rewriter.create( + loc, TypeRange{resType}, + ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag}); + } case MatrixCoreType::FP32_BF16_BF16_FP32: - return rewriter.create( - loc, TypeRange{resType}, - ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag}); + if (mfmaDescr.size == 16) { + return rewriter.create( + loc, TypeRange{resType}, + ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag}); + } else { + return rewriter.create( + loc, TypeRange{resType}, + ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag}); + } case MatrixCoreType::FP32_BF16_BF16_FP32_1K: - return rewriter.create( - loc, TypeRange{resType}, - ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag}); + if (mfmaDescr.size == 16) { + return rewriter.create( + loc, TypeRange{resType}, + ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag}); + } else { + assert(mfmaDescr.size == 32); + return rewriter.create( + loc, TypeRange{resType}, + ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag}); + } case MatrixCoreType::FP32_FP32_FP32_FP32: - return rewriter.create( - loc, TypeRange{resType}, - ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag}); + if (mfmaDescr.size == 16) { + return rewriter.create( + loc, TypeRange{resType}, + ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag}); + } else { + assert(mfmaDescr.size == 32); + return rewriter.create( + loc, TypeRange{resType}, + ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag}); + } case MatrixCoreType::INT32_INT8_INT8_INT32: - return rewriter.create( - loc, TypeRange{resType}, - ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag}); + if (mfmaDescr.size == 16) { + return rewriter.create( + loc, TypeRange{resType}, + ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag}); + } else { + return rewriter.create( + loc, TypeRange{resType}, + ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag}); + } case MatrixCoreType::FP64_FP64_FP64_FP64: + assert(mfmaDescr.size == 16); return rewriter.create( loc, TypeRange{resType}, ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag}); @@ -86,16 +124,22 @@ struct DotOpMFMAConversionHelper { auto aOperandTy = op.getA().getType(); auto tensorTy = aOperandTy.cast(); auto elemTy = tensorTy.getElementType(); + auto dotOpEncoding = tensorTy.getEncoding().cast(); + auto mfmaEncoding = dotOpEncoding.getParent().cast(); if (elemTy.isF16()) return MatrixCoreType::FP32_FP16_FP16_FP32; if (elemTy.isF32()) return MatrixCoreType::FP32_FP32_FP32_FP32; if (elemTy.isBF16()) { - auto dotOpEncoding = tensorTy.getEncoding().cast(); - if (dotOpEncoding.getKWidth() == 8) + auto nonKDim = mfmaEncoding.getNonKDim(); + auto kWidth = dotOpEncoding.getKWidth(); + if ((nonKDim == 32 && kWidth == 4) || (nonKDim == 16 && kWidth == 4)) { return MatrixCoreType::FP32_BF16_BF16_FP32_1K; - else + } else { + assert((nonKDim == 32 && kWidth == 2) || + (nonKDim == 16 && kWidth == 2)); return MatrixCoreType::FP32_BF16_BF16_FP32; + } } if (elemTy.isInteger(8)) return MatrixCoreType::INT32_INT8_INT8_INT32; @@ -104,11 +148,21 @@ struct DotOpMFMAConversionHelper { return MatrixCoreType::NOT_APPLICABLE; } + static MFMAInstrDescr getMatrixInstrDescr(DotOp op) { + MFMAInstrDescr descr; + auto tensorTy = op.getD().getType().cast(); + auto encoding = tensorTy.getEncoding().cast(); + descr.coreType = getMatrixCoreTypeFromDot(op); + descr.size = encoding.getNonKDim(); + return descr; + } + // Conduct the Dot conversion. LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor) const { auto warpsPerCTA = mfmaLayout.getWarpsPerCTA(); - assert(mfmaLayout.getNonKDim() == 32); - auto mfmaTy = getMatrixCoreTypeFromDot(op); + auto nonKDim = mfmaLayout.getNonKDim(); + assert(nonKDim == 32 || nonKDim == 16); + auto mfmaInstrDescr = getMatrixInstrDescr(op); Value a = op.getA(); Value b = op.getB(); @@ -142,22 +196,29 @@ struct DotOpMFMAConversionHelper { auto fc = typeConverter->unpackLLElements(loc, loadedC, rewriter, dstElemTy); - auto vecTy = vec_ty(dstElemTy, 16); + unsigned warpSize = triton::gpu::getWarpSize(mfmaLayout); + // compute number of output elements that each thread holds for one MFMA + // instruction + auto elemsPerVec = nonKDim * nonKDim / warpSize; + + auto vecTy = vec_ty(dstElemTy, elemsPerVec); for (int m = 0; m < numRepM; ++m) { for (int n = 0; n < numRepN; ++n) { Value acc = undef(vecTy); - for (unsigned v = 0; v < 16; ++v) { - acc = insert_element(vecTy, acc, fc[m * numRepN * 16 + n * 16 + v], - i32_val(v)); + for (unsigned v = 0; v < elemsPerVec; ++v) { + acc = insert_element( + vecTy, acc, fc[m * numRepN * elemsPerVec + n * elemsPerVec + v], + i32_val(v)); } for (size_t k = 0; k < numRepK; k++) { - acc = mfmaLayout.getIsTransposed() - ? generateMFMAOp(mfmaTy, hb[{n, k}], ha[{m, k}], acc) - : generateMFMAOp(mfmaTy, ha[{m, k}], hb[{n, k}], acc); + acc = + mfmaLayout.getIsTransposed() + ? generateMFMAOp(mfmaInstrDescr, hb[{n, k}], ha[{m, k}], acc) + : generateMFMAOp(mfmaInstrDescr, ha[{m, k}], hb[{n, k}], acc); } - for (unsigned v = 0; v < 16; ++v) { - fc[m * numRepN * 16 + n * 16 + v] = + for (unsigned v = 0; v < elemsPerVec; ++v) { + fc[m * numRepN * elemsPerVec + n * elemsPerVec + v] = extract_element(dstElemTy, acc, i32_val(v)); } } diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MMAv2.cpp index 726b87d6e2b2..bd222f9e6a72 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MMAv2.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MMAv2.cpp @@ -141,7 +141,15 @@ TensorCoreType getMmaType(triton::DotOp op) { return TensorCoreType::NOT_APPLICABLE; } -inline static const std::map mmaInstrPtx = { +inline static const std::map mmaInstrPtxTuring = { + {TensorCoreType::FP32_FP16_FP16_FP32, + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"}, + + {TensorCoreType::FP16_FP16_FP16_FP16, + "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16"}, +}; + +inline static const std::map mmaInstrPtxAmpere = { {TensorCoreType::FP32_FP16_FP16_FP32, "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"}, {TensorCoreType::FP32_BF16_BF16_FP32, @@ -164,22 +172,23 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter, Location loc, Value a, Value b, Value c, Value d, Value loadedA, Value loadedB, Value loadedC, DotOp op, - DotOpAdaptor adaptor) { + DotOpAdaptor adaptor, bool isTuring) { MLIRContext *ctx = c.getContext(); auto aTensorTy = a.getType().cast(); auto bTensorTy = b.getType().cast(); auto dTensorTy = d.getType().cast(); - SmallVector aShape(aTensorTy.getShape().begin(), - aTensorTy.getShape().end()); - auto dShape = dTensorTy.getShape(); + auto aShapePerCTA = triton::gpu::getShapePerCTA(aTensorTy); + auto bShapePerCTA = triton::gpu::getShapePerCTA(bTensorTy); + auto dShapePerCTA = triton::gpu::getShapePerCTA(dTensorTy); + int bitwidth = aTensorTy.getElementType().getIntOrFloatBitWidth(); auto repA = aTensorTy.getEncoding().cast().getMMAv2Rep( - aTensorTy.getShape(), bitwidth); + aShapePerCTA, bitwidth); auto repB = bTensorTy.getEncoding().cast().getMMAv2Rep( - bTensorTy.getShape(), bitwidth); + bShapePerCTA, bitwidth); assert(repA[1] == repB[0]); int repM = repA[0], repN = repB[1], repK = repA[1]; @@ -196,23 +205,18 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter, auto mmaType = getMmaType(op); + const auto &mmaInstructions = + isTuring ? mmaInstrPtxTuring : mmaInstrPtxAmpere; + auto callMma = [&](unsigned m, unsigned n, unsigned k) { unsigned colsPerThread = repN * 2; PTXBuilder builder; - auto &mma = *builder.create(mmaInstrPtx.at(mmaType)); + auto &mma = *builder.create(mmaInstructions.at(mmaType)); // using =r for float32 works but leads to less readable ptx. bool isIntMMA = dTensorTy.getElementType().isInteger(32); bool isAccF16 = dTensorTy.getElementType().isF16(); auto retArgs = builder.newListOperand(numMmaRets, isIntMMA || isAccF16 ? "=r" : "=f"); - auto aArgs = builder.newListOperand({ - {ha[{m, k}], "r"}, - {ha[{m + 1, k}], "r"}, - {ha[{m, k + 1}], "r"}, - {ha[{m + 1, k + 1}], "r"}, - }); - auto bArgs = - builder.newListOperand({{hb[{n, k}], "r"}, {hb[{n, k + 1}], "r"}}); auto cArgs = builder.newListOperand(); for (int i = 0; i < numMmaRets; ++i) { cArgs->listAppend(builder.newOperand( @@ -221,7 +225,32 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter, // reuse the output registers } - mma(retArgs, aArgs, bArgs, cArgs); + if (isTuring) { + auto aArgs1 = builder.newListOperand({ + {ha[{m, k}], "r"}, + {ha[{m + 1, k}], "r"}, + }); + auto bArgs1 = builder.newListOperand({ + {hb[{n, k}], "r"}, + }); + auto aArgs2 = builder.newListOperand({ + {ha[{m, k + 1}], "r"}, + {ha[{m + 1, k + 1}], "r"}, + }); + auto bArgs2 = builder.newListOperand({{hb[{n, k + 1}], "r"}}); + mma(retArgs, aArgs1, bArgs1, cArgs); + mma(retArgs, aArgs2, bArgs2, cArgs); + } else { + auto aArgs = builder.newListOperand({ + {ha[{m, k}], "r"}, + {ha[{m + 1, k}], "r"}, + {ha[{m, k + 1}], "r"}, + {ha[{m + 1, k + 1}], "r"}, + }); + auto bArgs = + builder.newListOperand({{hb[{n, k}], "r"}, {hb[{n, k + 1}], "r"}}); + mma(retArgs, aArgs, bArgs, cArgs); + } Value mmaOut = builder.launch(rewriter, loc, getMmaRetType(mmaType, op.getContext())); @@ -258,10 +287,9 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter, return success(); } -// Convert to mma.m16n8k16 -LogicalResult convertMMA16816(triton::DotOp op, triton::DotOp::Adaptor adaptor, - TritonGPUToLLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter) { +LogicalResult convertMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor, + TritonGPUToLLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, bool isTuring) { auto loc = op.getLoc(); auto mmaLayout = op.getResult() .getType() @@ -287,5 +315,19 @@ LogicalResult convertMMA16816(triton::DotOp op, triton::DotOp::Adaptor adaptor, loadC(op.getC(), adaptor.getC(), typeConverter, op.getLoc(), rewriter); return convertDot(typeConverter, rewriter, op.getLoc(), A, B, C, op.getD(), - loadedA, loadedB, loadedC, op, adaptor); + loadedA, loadedB, loadedC, op, adaptor, isTuring); +} + +// Convert to mma.m16n8k8 +LogicalResult convertMMA1688(triton::DotOp op, triton::DotOp::Adaptor adaptor, + TritonGPUToLLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) { + return convertMMA(op, adaptor, typeConverter, rewriter, true /*isTuring*/); +} + +// Convert to mma.m16n8k16 +LogicalResult convertMMA16816(triton::DotOp op, triton::DotOp::Adaptor adaptor, + TritonGPUToLLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) { + return convertMMA(op, adaptor, typeConverter, rewriter, false /*isTuring*/); } diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp new file mode 100644 index 000000000000..fbccdeefcfb0 --- /dev/null +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp @@ -0,0 +1,431 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "DotOpToLLVM.h" +#include "Utility.h" + +using namespace mlir; +using namespace mlir::triton; + +using ::mlir::LLVM::getSharedMemoryObjectFromStruct; +using ::mlir::triton::gpu::getShapePerCTA; +using ::mlir::triton::gpu::getShapePerCTATile; +using ::mlir::triton::gpu::MmaEncodingAttr; +using ::mlir::triton::gpu::SharedEncodingAttr; + +triton::nvgpu::WGMMAEltType getMmaRetType(Value d) { + auto dTy = d.getType().cast().getElementType(); + if (dTy.isF32()) { + return triton::nvgpu::WGMMAEltType::f32; + } else if (dTy.isF16()) { + return triton::nvgpu::WGMMAEltType::f16; + } else if (dTy.isInteger(32)) { + return triton::nvgpu::WGMMAEltType::s32; + } else { + llvm::report_fatal_error("Unsupported mma result type found"); + } +} + +triton::nvgpu::WGMMAEltType getMmaOperandType(Value a, bool allowTF32) { + auto aTy = a.getType().cast().getElementType(); + if (aTy.isF16()) { + return triton::nvgpu::WGMMAEltType::f16; + } else if (aTy.isBF16()) { + return triton::nvgpu::WGMMAEltType::bf16; + } else if (aTy.isF32() && allowTF32) { + return triton::nvgpu::WGMMAEltType::tf32; + } else if (aTy.isInteger(8)) { + return triton::nvgpu::WGMMAEltType::s8; + } else if (aTy.isFloat8E5M2()) { + return triton::nvgpu::WGMMAEltType::e5m2; + } else if (aTy.isFloat8E4M3FNUZ()) { + return triton::nvgpu::WGMMAEltType::e4m3; + } else { + llvm::report_fatal_error("Unsupported mma operand type found"); + } +} + +mlir::triton::nvgpu::WGMMADescMode +getModeFromLayout(const SharedEncodingAttr &layout, uint32_t widthInByte) { + int perPhase = layout.getPerPhase(); + int maxPhase = layout.getMaxPhase(); + uint32_t swizzlingByteWidth = 0; + + mlir::triton::nvgpu::WGMMADescMode mode; + if (perPhase == 4 && maxPhase == 2) { + mode = mlir::triton::nvgpu::WGMMADescMode::swizzle32; + swizzlingByteWidth = 32; + } else if (perPhase == 2 && maxPhase == 4) { + mode = mlir::triton::nvgpu::WGMMADescMode::swizzle64; + swizzlingByteWidth = 64; + } else if (perPhase == 1 && maxPhase == 8) { + mode = mlir::triton::nvgpu::WGMMADescMode::swizzle128; + swizzlingByteWidth = 128; + } else { + llvm::report_fatal_error("Unsupported shared layout."); + } + + // TODO[biaow]: remove it once we support swizzling size larger than matrix + // width, which requires padding the matrix width to the swizzling size when + // allocating shared memory. + assert(swizzlingByteWidth <= widthInByte && + "swizzling size larger than matrix width is not supported."); + return mode; +} + +class DotOpMmaV3SmemLoader { +public: + DotOpMmaV3SmemLoader(Value tensor, const SharedMemoryObject &smemObj, + SmallVector shape, Value warpId, + unsigned int dimWpt, bool trans, + SmallVector instrShape, + ConversionPatternRewriter &rewriter, Location loc) + : base(smemObj.base), shape(shape), warpId(warpId), dimWpt(dimWpt), + trans(trans), instrShape(instrShape), rewriter(rewriter), loc(loc) { + auto tensorTy = tensor.getType().cast(); + auto sharedLayout = tensorTy.getEncoding().cast(); + ord = sharedLayout.getOrder(); + const int perPhase = sharedLayout.getPerPhase(); + const int maxPhase = sharedLayout.getMaxPhase(); + elemBytes = tensorTy.getElementTypeBitWidth() / 8; + elemsPerSwizzlingRow = 128 / perPhase / elemBytes; + elemsPerSwizzlingRowVal = i32_val(elemsPerSwizzlingRow); + + uint32_t widthInByte = shape[ord[0]] * elemBytes; + mode = getModeFromLayout(sharedLayout, widthInByte); + + baseDesc = rewriter.create( + loc, i64_ty, base, i32_val(shape[ord[1]]), mode); + } + + Value smemLoad(int a, int b) { + Value k = i32_val(b * instrShape[1]); + Value m = add(i32_val(a * dimWpt * instrShape[0]), + mul(warpId, i32_val(instrShape[0]))); + if (trans) { + std::swap(k, m); + } + Value leading_offset = mul(udiv(k, elemsPerSwizzlingRowVal), + i32_val(shape[ord[1]] * elemsPerSwizzlingRow)); + Value stride_offset = mul(m, elemsPerSwizzlingRowVal); + Value offset = add(add(leading_offset, stride_offset), + urem(k, elemsPerSwizzlingRowVal)); + Value off1 = mul(i32_val(elemBytes), offset); + Value off_ = zext(i64_ty, udiv(off1, i32_val(16))); + + return add(baseDesc, off_); + } + +private: + Value base; + SmallVector shape; + Value warpId; + int dimWpt; + bool trans; + Value elemsPerSwizzlingRowVal; + mlir::triton::nvgpu::WGMMADescMode mode; + SmallVector instrShape; + ArrayRef ord; + ConversionPatternRewriter &rewriter; + Location loc; + int elemsPerSwizzlingRow; + int elemBytes; + Value baseDesc; +}; + +DotOpMmaV3SmemLoader loadA(TritonGPUToLLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, Location loc, + const MmaEncodingAttr &mmaEncoding, Value tensor, + const SharedMemoryObject &smemObj, Value thread) { + auto aTensorTy = tensor.getType().cast(); + auto aSharedLayout = aTensorTy.getEncoding().dyn_cast(); + assert(aSharedLayout && "only support load dot operand from shared."); + auto instrShape = mmaEncoding.getInstrShape(); + auto wpt = mmaEncoding.getWarpsPerCTA(); + auto aOrd = aSharedLayout.getOrder(); + bool transA = aOrd[0] == 0; + auto shapePerCTA = getShapePerCTA(aTensorTy); + + int numRepM = ceil(shapePerCTA[0], instrShape[0] * wpt[0]); + int numRepK = ceil(shapePerCTA[1], instrShape[2]); + + Value warp = udiv(thread, i32_val(32)); + Value warpM = urem(warp, i32_val(wpt[0])); + Value warpId = urem(warpM, i32_val(shapePerCTA[0] / instrShape[0])); + + return {tensor, + smemObj, + shapePerCTA, + warpId, + wpt[0], + transA, + {instrShape[0], instrShape[2]}, + rewriter, + loc}; +} + +DotOpMmaV3SmemLoader loadB(TritonGPUToLLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, Location loc, + MmaEncodingAttr &mmaEncoding, Value tensor, + const SharedMemoryObject &smemObj, Value thread) { + auto bTensorTy = tensor.getType().cast(); + auto bSharedLayout = bTensorTy.getEncoding().cast(); + assert(bSharedLayout && "only support load B from shared."); + auto instrShape = mmaEncoding.getInstrShape(); + auto wpt = mmaEncoding.getWarpsPerCTA(); + auto bOrd = bSharedLayout.getOrder(); + bool transB = bOrd[0] == 1; + auto shapePerCTA = triton::gpu::getShapePerCTA(bTensorTy); + + int numRepK = ceil(shapePerCTA[0], instrShape[2]); + int numRepN = ceil(shapePerCTA[1], instrShape[1] * wpt[1]); + + Value warp = udiv(thread, i32_val(32)); + Value warpMN = udiv(warp, i32_val(wpt[0])); + Value warpN = urem(warpMN, i32_val(wpt[1])); + Value warpId = urem(warpN, i32_val(shapePerCTA[1] / instrShape[1])); + + return {tensor, + smemObj, + shapePerCTA, + warpId, + wpt[1], + transB, + {instrShape[1], instrShape[2]}, + rewriter, + loc}; +} + +// Return a vector of Value of the accumulator start at startIndex and pack the +// values into 32bits in case the accumulator is fp16. +llvm::SmallVector loadC(ConversionPatternRewriter &rewriter, + Location loc, const SmallVector &elements, + int startIndex, int numElements) { + if (!elements[0].getType().isF16()) { + llvm::SmallVector mmaOut(numElements); + for (int i = 0; i < numElements; ++i) + mmaOut[i] = elements[startIndex + i]; + return mmaOut; + } + // For FP16 we need to pack accumulator into 32-bit integers. + llvm::SmallVector mmaOut(numElements / 2); + for (int i = 0; i < numElements / 2; ++i) { + Value a0 = elements[startIndex + 2 * i]; + Value a1 = elements[startIndex + 2 * i + 1]; + Type cPackTy = vec_ty(rewriter.getF16Type(), 2); + Value pack = rewriter.create(loc, cPackTy); + pack = insert_element(cPackTy, pack, a0, i32_val(0)); + pack = insert_element(cPackTy, pack, a1, i32_val(1)); + pack = bitcast(pack, rewriter.getIntegerType(32)); + mmaOut[i] = pack; + } + return mmaOut; +} + +// If the accumulator is fp16 unpack it from 32-bit integers. +SmallVector unpackAccumulator(ConversionPatternRewriter &rewriter, + Location loc, + const SmallVector &packed, + RankedTensorType tensorTy) { + if (!tensorTy.getElementType().isF16()) + return packed; + // For fp16 the accumualtor is pack into 32-bit integers so we need to unpack + // it. + SmallVector results; + for (Value elem : packed) { + elem = bitcast(elem, vec_ty(rewriter.getF16Type(), 2)); + results.push_back(extract_element(rewriter.getF16Type(), elem, i32_val(0))); + results.push_back(extract_element(rewriter.getF16Type(), elem, i32_val(1))); + } + return results; +} + +LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, Location loc, + Operation *op, Value a, Value b, Value c, Value d, + Value loadedA, Value loadedB, Value loadedC, + bool allowTF32, const SharedMemoryObject &smemObjA, + const SharedMemoryObject &smemObjB, bool sync, + Value thread) { + auto aTensorTy = a.getType().cast(); + auto bTensorTy = b.getType().cast(); + auto dTensorTy = d.getType().cast(); + auto aSharedLayout = aTensorTy.getEncoding().cast(); + auto bSharedLayout = bTensorTy.getEncoding().cast(); + auto mmaEncoding = dTensorTy.getEncoding().cast(); + auto aOrd = aSharedLayout.getOrder(); + auto bOrd = bSharedLayout.getOrder(); + bool transA = aOrd[0] == 0; + bool transB = bOrd[0] == 1; + auto dShapePerCTA = getShapePerCTA(dTensorTy); + auto instrShape = mmaEncoding.getInstrShape(); + auto accSize = 2 * (instrShape[1] / 4); + int M = 4 * instrShape[0]; + int N = instrShape[1]; + int K = instrShape[2]; + + auto shapePerCTATile = getShapePerCTATile(mmaEncoding); + int numRepM = ceil(dShapePerCTA[0], shapePerCTATile[0]); + int numRepN = ceil(dShapePerCTA[1], shapePerCTATile[1]); + int numRepK = ceil(aTensorTy.getShape()[1], instrShape[2]); + + DotOpMmaV3SmemLoader aLoader = + loadA(typeConverter, rewriter, loc, mmaEncoding, a, smemObjA, thread); + DotOpMmaV3SmemLoader bLoader = + loadB(typeConverter, rewriter, loc, mmaEncoding, b, smemObjB, thread); + + auto fc = typeConverter->unpackLLElements(loc, loadedC, rewriter, dTensorTy); + + triton::nvgpu::WGMMAEltType eltTypeC = getMmaRetType(d); + triton::nvgpu::WGMMAEltType eltTypeA = getMmaOperandType(a, allowTF32); + triton::nvgpu::WGMMAEltType eltTypeB = getMmaOperandType(b, allowTF32); + + triton::nvgpu::WGMMALayout layoutA = transA ? triton::nvgpu::WGMMALayout::col + : triton::nvgpu::WGMMALayout::row; + triton::nvgpu::WGMMALayout layoutB = transB ? triton::nvgpu::WGMMALayout::row + : triton::nvgpu::WGMMALayout::col; + + auto func = op->getParentOfType(); + int numTMADescs = + func->getAttr(kAttrNumTMALoadDescsName).cast().getInt(); + if (numTMADescs == 0) + rewriter.create(loc, 0); + rewriter.create(loc); + + SmallVector mmaResults; + for (int m = 0; m < numRepM; ++m) { + for (int n = 0; n < numRepN; ++n) { + llvm::SmallVector mmaOut = + loadC(rewriter, loc, fc, (m * numRepN + n) * accSize, accSize); + llvm::SmallVector elemTypes; + for (Value accEl : mmaOut) + elemTypes.push_back(accEl.getType()); + auto accTy = + LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes); + Value d = typeConverter->packLLElements(loc, mmaOut, rewriter, accTy); + for (int k = 0; k < numRepK; ++k) { + auto a = aLoader.smemLoad(m, k); + auto b = bLoader.smemLoad(n, k); + ValueRange operands{a, b, d}; + d = rewriter.create(loc, accTy, a, b, d, M, N, + K, eltTypeC, eltTypeA, + eltTypeB, layoutA, layoutB); + } + auto acc = typeConverter->unpackLLElements(loc, d, rewriter, accTy); + for (int i = 0; i < acc.size(); ++i) { + mmaResults.push_back(acc[i]); + } + } + } + rewriter.create(loc); + + if (sync) + rewriter.create(loc, 0); + + SmallVector results = + unpackAccumulator(rewriter, loc, mmaResults, dTensorTy); + + // replace with new packed result + Type structTy = LLVM::LLVMStructType::getLiteral( + mmaEncoding.getContext(), + SmallVector(results.size(), dTensorTy.getElementType())); + auto res = typeConverter->packLLElements(loc, results, rewriter, structTy); + rewriter.replaceOp(op, res); + return success(); +} + +// Loading $c to registers, returns a Value. +Value loadC(Value tensor, Value llTensor) { + auto tensorTy = tensor.getType().cast(); + auto mmaEncoding = tensorTy.getEncoding().dyn_cast(); + assert(mmaEncoding && "Currently, we only support $c with a mma layout."); + auto instrShape = mmaEncoding.getInstrShape(); + auto wpt = mmaEncoding.getWarpsPerCTA(); + auto shapePerCTA = getShapePerCTA(tensorTy); + auto shapePerCTATile = getShapePerCTATile(mmaEncoding); + + int numRepM = ceil(shapePerCTA[0], shapePerCTATile[0]); + int numRepN = ceil(shapePerCTA[1], shapePerCTATile[1]); + + size_t fcSize = 2 * (instrShape[1] / 4) * numRepM * numRepN; + + auto structTy = llTensor.getType().cast(); + assert(structTy.getBody().size() == fcSize && + "DotOp's $c operand should pass the same number of values as $d in " + "mma layout."); + return llTensor; +} + +LogicalResult convertWGMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor, + TritonGPUToLLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, Value thread) { + auto loc = op.getLoc(); + Value A = op.getA(); + Value B = op.getB(); + Value C = op.getC(); + auto ATensorTy = A.getType().cast(); + auto BTensorTy = B.getType().cast(); + + assert(ATensorTy.getEncoding().isa() && + BTensorTy.getEncoding().isa() && + "Both $a and %b should be Shared layout."); + + Value llA, llB, llC; + llA = adaptor.getA(); + llB = adaptor.getB(); + llC = loadC(C, adaptor.getC()); + + auto smemObjA = getSharedMemoryObjectFromStruct(loc, llA, rewriter); + auto smemObjB = getSharedMemoryObjectFromStruct(loc, llB, rewriter); + return convertDot(typeConverter, rewriter, loc, op.getOperation(), A, B, C, + op.getD(), llA, llB, llC, op.getAllowTF32(), smemObjA, + smemObjB, true, thread); +} + +LogicalResult convertAsyncWGMMA(triton::nvidia_gpu::DotAsyncOp op, + triton::nvidia_gpu::DotAsyncOp::Adaptor adaptor, + TritonGPUToLLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + Value thread) { + auto loc = op.getLoc(); + Value A = op.getA(); + Value B = op.getB(); + Value C = op.getC(); + auto ATensorTy = A.getType().cast(); + auto BTensorTy = B.getType().cast(); + + assert(ATensorTy.getEncoding().isa() && + BTensorTy.getEncoding().isa() && + "Both $a and %b should be Shared layout."); + + Value llA, llB, llC; + llA = adaptor.getA(); + llB = adaptor.getB(); + llC = loadC(C, adaptor.getC()); + + auto smemObjA = getSharedMemoryObjectFromStruct(loc, llA, rewriter); + auto smemObjB = getSharedMemoryObjectFromStruct(loc, llB, rewriter); + return convertDot(typeConverter, rewriter, loc, op.getOperation(), A, B, C, + op.getD(), llA, llB, llC, op.getAllowTF32(), smemObjA, + smemObjB, false, thread); +} diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index 7c6e95affe63..6fb006bb8bcd 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -10,15 +10,14 @@ using ::mlir::triton::gpu::getTotalElemsPerThread; #ifdef USE_ROCM static SmallVector Fp16_to_Fp8E5M2(Location loc, ConversionPatternRewriter &rewriter, - const Value &v0, const Value &v1, const Value &v2, - const Value &v3) { + const SmallVector &v) { auto fp16x2VecTy = vec_ty(f16_ty, 2); Value fp16x2Vec0 = undef(fp16x2VecTy); Value fp16x2Vec1 = undef(fp16x2VecTy); - fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v0, i32_val(0)); - fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v1, i32_val(1)); - fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v2, i32_val(0)); - fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v3, i32_val(1)); + fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[0], i32_val(0)); + fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[1], i32_val(1)); + fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v[2], i32_val(0)); + fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v[3], i32_val(1)); Value a0 = bitcast(fp16x2Vec0, i32_ty); Value a1 = bitcast(fp16x2Vec1, i32_ty); @@ -47,34 +46,30 @@ Fp16_to_Fp8E5M2(Location loc, ConversionPatternRewriter &rewriter, const std::string Fp16_to_Fp8E5M2 = "{ \n" ".reg .b32 a<2>; \n" - "and.b32 a0, $1, 0x7fff7fff; \n" // a0 &= 0x7fff7fff - "and.b32 a1, $2, 0x7fff7fff; \n" // (strip sign) - "add.u32 a0, a0, 0x00800080; \n" // a0 += 0x00800080 - "add.u32 a1, a1, 0x00800080; \n" // (round to nearest) - "lop3.b32 a0, $1, 0x80008000, a0, 0xea; \n" // a0 = a0|(0x80008000&in0) - "lop3.b32 a1, $2, 0x80008000, a1, 0xea; \n" // (restore sign) - "prmt.b32 $0, a0, a1, 0x7531; \n\t" // output = a1a0 + "and.b32 a0, $1, 0xfffefffe; \n" // a0 &= 0xfffefffe + "and.b32 a1, $2, 0xfffefffe; \n" // (strip lowest bit) + "add.u32 a0, a0, 0x00800080; \n" // a0 += 0x00800080 + "add.u32 a1, a1, 0x00800080; \n" // (round to nearest) + "prmt.b32 $0, a0, a1, 0x7531; \n\t" // output = a1a0 "}"; #endif #ifdef USE_ROCM static SmallVector Fp8E5M2_to_Fp16(Location loc, ConversionPatternRewriter &rewriter, - const Value &v0, const Value &v1, const Value &v2, - const Value &v3) { + const SmallVector &v) { auto fp8x4VecTy = vec_ty(i8_ty, 4); Value a0 = undef(fp8x4VecTy); a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(0)); - a0 = insert_element(fp8x4VecTy, a0, v0, i32_val(1)); + a0 = insert_element(fp8x4VecTy, a0, v[0], i32_val(1)); a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(2)); - a0 = insert_element(fp8x4VecTy, a0, v1, i32_val(3)); + a0 = insert_element(fp8x4VecTy, a0, v[1], i32_val(3)); a0 = bitcast(a0, i32_ty); - Value a1 = undef(fp8x4VecTy); a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(0)); - a1 = insert_element(fp8x4VecTy, a1, v2, i32_val(1)); + a1 = insert_element(fp8x4VecTy, a1, v[2], i32_val(1)); a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(2)); - a1 = insert_element(fp8x4VecTy, a1, v3, i32_val(3)); + a1 = insert_element(fp8x4VecTy, a1, v[3], i32_val(3)); a1 = bitcast(a1, i32_ty); auto fp16x2VecTy = vec_ty(f16_ty, 2); @@ -97,21 +92,20 @@ const std::string Fp8E5M2_to_Fp16 = "{ \n" #ifdef USE_ROCM static SmallVector Fp8E5M2_to_Bf16(Location loc, ConversionPatternRewriter &rewriter, - const Value &v0, const Value &v1, const Value &v2, - const Value &v3) { + const SmallVector &v) { auto fp8x4VecTy = vec_ty(i8_ty, 4); Value a0 = undef(fp8x4VecTy); a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(0)); - a0 = insert_element(fp8x4VecTy, a0, v0, i32_val(1)); + a0 = insert_element(fp8x4VecTy, a0, v[0], i32_val(1)); a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(2)); - a0 = insert_element(fp8x4VecTy, a0, v1, i32_val(3)); + a0 = insert_element(fp8x4VecTy, a0, v[1], i32_val(3)); a0 = bitcast(a0, i32_ty); Value a1 = undef(fp8x4VecTy); a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(0)); - a1 = insert_element(fp8x4VecTy, a1, v2, i32_val(1)); + a1 = insert_element(fp8x4VecTy, a1, v[2], i32_val(1)); a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(2)); - a1 = insert_element(fp8x4VecTy, a1, v3, i32_val(3)); + a1 = insert_element(fp8x4VecTy, a1, v[3], i32_val(3)); a1 = bitcast(a1, i32_ty); Value b0 = and_(i32_ty, a0, i32_val(0x7fff7fff)); @@ -158,15 +152,14 @@ const std::string Fp8E5M2_to_Bf16 = #ifdef USE_ROCM static SmallVector Bf16_to_Fp8E5M2(Location loc, ConversionPatternRewriter &rewriter, - const Value &v0, const Value &v1, const Value &v2, - const Value &v3) { + const SmallVector &v) { auto bf16x2VecTy = vec_ty(i16_ty, 2); Value bf16x2Vec0 = undef(bf16x2VecTy); Value bf16x2Vec1 = undef(bf16x2VecTy); - bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v0, i32_val(0)); - bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v1, i32_val(1)); - bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v2, i32_val(0)); - bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v3, i32_val(1)); + bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v[0], i32_val(0)); + bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v[1], i32_val(1)); + bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v[2], i32_val(0)); + bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v[3], i32_val(1)); bf16x2Vec0 = bitcast(bf16x2Vec0, i32_ty); bf16x2Vec1 = bitcast(bf16x2Vec1, i32_ty); @@ -279,21 +272,20 @@ const std::string Bf16_to_Fp8E5M2 = #ifdef USE_ROCM static SmallVector Fp8E4M3B15_to_Fp16(Location loc, ConversionPatternRewriter &rewriter, - const Value &v0, const Value &v1, const Value &v2, - const Value &v3) { + const SmallVector &v) { auto fp8x4VecTy = vec_ty(i8_ty, 4); Value a0 = undef(fp8x4VecTy); a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(0)); - a0 = insert_element(fp8x4VecTy, a0, v0, i32_val(1)); + a0 = insert_element(fp8x4VecTy, a0, v[0], i32_val(1)); a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(2)); - a0 = insert_element(fp8x4VecTy, a0, v1, i32_val(3)); + a0 = insert_element(fp8x4VecTy, a0, v[1], i32_val(3)); a0 = bitcast(a0, i32_ty); Value a1 = undef(fp8x4VecTy); a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(0)); - a1 = insert_element(fp8x4VecTy, a1, v2, i32_val(1)); + a1 = insert_element(fp8x4VecTy, a1, v[2], i32_val(1)); a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(2)); - a1 = insert_element(fp8x4VecTy, a1, v3, i32_val(3)); + a1 = insert_element(fp8x4VecTy, a1, v[3], i32_val(3)); a1 = bitcast(a1, i32_ty); Value b0 = and_(i32_ty, a0, i32_val(0x7fff7fff)); @@ -319,30 +311,29 @@ Fp8E4M3B15_to_Fp16(Location loc, ConversionPatternRewriter &rewriter, const std::string Fp8E4M3B15_to_Fp16 = "{ \n" ".reg .b32 a<2>, b<2>; \n" - "prmt.b32 a0, 0, $2, 0x5040; \n" - "prmt.b32 a1, 0, $2, 0x7060; \n" - "lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" - "lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n" + "prmt.b32 a0, 0, $2, 0x5746; \n" + "and.b32 b0, a0, 0x7f007f00; \n" + "and.b32 b1, a0, 0x00ff00ff; \n" + "and.b32 a1, a0, 0x00800080; \n" "shr.b32 b0, b0, 1; \n" - "shr.b32 b1, b1, 1; \n" + "add.u32 b1, b1, a1; \n" "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" - "lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n" + "shl.b32 $1, b1, 7; \n" "} \n"; #endif #ifdef USE_ROCM static SmallVector Fp16_to_Fp8E4M3B15(Location loc, ConversionPatternRewriter &rewriter, - const Value &v0, const Value &v1, const Value &v2, - const Value &v3) { + const SmallVector &v) { auto fp16x2VecTy = vec_ty(f16_ty, 2); Value fp16x2Vec0 = undef(fp16x2VecTy); Value fp16x2Vec1 = undef(fp16x2VecTy); - fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v0, i32_val(0)); - fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v1, i32_val(1)); - fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v2, i32_val(0)); - fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v3, i32_val(1)); + fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[0], i32_val(0)); + fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[1], i32_val(1)); + fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v[2], i32_val(0)); + fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v[3], i32_val(1)); Value fp16x2VecMin = i32_val(0xBF80BF80); Value fp16x2VecMax = i32_val(0x3F803F80); @@ -376,26 +367,40 @@ Fp16_to_Fp8E4M3B15(Location loc, ConversionPatternRewriter &rewriter, }; } #else -const std::string Fp16_to_Fp8E4M3B15 = - "{ \n" - ".reg .b32 a<2>, b<2>; \n" - ".reg .b32 min_val, max_val; \n" - "mov.b32 min_val, 0xBF80BF80; \n" - "mov.b32 max_val, 0x3F803F80; \n" - "max.f16x2 $1, $1, min_val; \n" - "min.f16x2 $1, $1, max_val; \n" - "max.f16x2 $2, $2, min_val; \n" - "min.f16x2 $2, $2, max_val; \n" - "shl.b32 a0, $1, 1; \n" - "shl.b32 a1, $2, 1; \n" - "lop3.b32 a0, a0, 0x7fff7fff, 0, 0xc0; \n" - "lop3.b32 a1, a1, 0x7fff7fff, 0, 0xc0; \n" - "add.u32 a0, a0, 0x00800080; \n" - "add.u32 a1, a1, 0x00800080; \n" - "lop3.b32 b0, $1, 0x80008000, a0, 0xea; \n" - "lop3.b32 b1, $2, 0x80008000, a1, 0xea; \n" - "prmt.b32 $0, b0, b1, 0x7531; \n" - "}"; +const std::string Fp16_to_Fp8E4M3B15(bool has_minx2) { + std::string ret; + ret += "{ \n" + ".reg .pred p<4>; \n" + ".reg .b32 a<2>, b<2>; \n" + ".reg .b16 c<4>; \n" + ".reg .b16 max_val_f16; \n" + ".reg .b32 max_val_f16x2; \n" + "mov.b16 max_val_f16, 0x3F80; \n" + "mov.b32 max_val_f16x2, 0x3F803F80; \n" + "and.b32 a0, $1, 0x7fff7fff; \n" + "and.b32 a1, $2, 0x7fff7fff; \n"; + if (has_minx2) + ret += "min.f16x2 a0, a0, max_val_f16x2; \n" + "min.f16x2 a1, a1, max_val_f16x2; \n"; + else + ret += "setp.lt.f16x2 p0|p1, a0, max_val_f16x2; \n" + "setp.lt.f16x2 p2|p3, a1, max_val_f16x2; \n" + "mov.b32 {c0, c1}, a0; \n" + "mov.b32 {c2, c3}, a1; \n" + "selp.b16 c0, c0, max_val_f16, p0; \n" + "selp.b16 c1, c1, max_val_f16, p1; \n" + "selp.b16 c2, c2, max_val_f16, p2; \n" + "selp.b16 c3, c3, max_val_f16, p3; \n" + "mov.b32 a0, {c0, c1}; \n" + "mov.b32 a1, {c2, c3}; \n"; + ret += "mad.lo.u32 a0, a0, 2, 0x00800080; \n" + "mad.lo.u32 a1, a1, 2, 0x00800080; \n" + "lop3.b32 b0, $1, 0x80008000, a0, 0xea; \n" + "lop3.b32 b1, $2, 0x80008000, a1, 0xea; \n" + "prmt.b32 $0, b0, b1, 0x7531; \n" + "}"; + return ret; +} #endif /* ----- FP8E4M3B15X4 ------ */ @@ -411,20 +416,41 @@ const std::string Fp16_to_Fp8E4M3B15 = #ifdef USE_ROCM static SmallVector Fp8E4M3B15x4_to_Fp16(Location loc, ConversionPatternRewriter &rewriter, - const Value &v0, const Value &v1, const Value &v2, - const Value &v3) { - return {}; + const SmallVector &v) { + auto fp8x4VecTy = vec_ty(i8_ty, 4); + Value fp8x4Vec = undef(fp8x4VecTy); + fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v[0], i32_val(0)); + fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v[1], i32_val(1)); + fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v[2], i32_val(2)); + fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v[3], i32_val(3)); + fp8x4Vec = bitcast(fp8x4Vec, i32_ty); + + Value a0 = add(i32_ty, fp8x4Vec, fp8x4Vec); + Value a1 = shl(i32_ty, fp8x4Vec, i32_val(7)); + + Value fp16x2Vec0 = and_(i32_ty, a0, i32_val(0x80008000)); + fp16x2Vec0 = or_(i32_ty, fp16x2Vec0, and_(i32_ty, a1, i32_val(0x3f803f80)) ); + Value fp16x2Vec1 = and_(i32_ty, fp8x4Vec, i32_val(0xbf80bf80)); + + auto fp16x2VecTy = vec_ty(f16_ty, 2); + fp16x2Vec0 = bitcast(fp16x2Vec0, fp16x2VecTy); + fp16x2Vec1 = bitcast(fp16x2Vec1, fp16x2VecTy); + + return { extract_element(f16_ty, fp16x2Vec0, i32_val(0)), + extract_element(f16_ty, fp16x2Vec0, i32_val(1)), + extract_element(f16_ty, fp16x2Vec1, i32_val(0)), + extract_element(f16_ty, fp16x2Vec1, i32_val(1)) + }; } #else const std::string Fp8E4M3B15x4_to_Fp16 = "{ \n" ".reg .b32 a<2>; \n" - "shl.b32 a0, $2, 1; \n" + "add.u32 a0, $2, $2; \n" "shl.b32 a1, $2, 7; \n" "and.b32 $0, a0, 0x80008000; \n" "lop3.b32 $0, $0, a1, 0x3f803f80, 0xf8; \n" - "and.b32 $1, $2, 0x80008000; \n" - "lop3.b32 $1, $1, $2, 0x3f803f80, 0xf8; \n" + "and.b32 $1, $2, 0xbf80bf80; \n" "}"; #endif @@ -438,9 +464,34 @@ const std::string Fp8E4M3B15x4_to_Fp16 = #ifdef USE_ROCM static SmallVector Fp16_to_Fp8E4M3B15x4(Location loc, ConversionPatternRewriter &rewriter, - const Value &v0, const Value &v1, const Value &v2, - const Value &v3) { - return {}; + const SmallVector &v) { + auto fp16x2VecTy = vec_ty(f16_ty, 2); + Value fp16x2Vec0 = undef(fp16x2VecTy); + Value fp16x2Vec1 = undef(fp16x2VecTy); + + fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[0], i32_val(0)); + fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[1], i32_val(1)); + fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v[2], i32_val(0)); + fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v[3], i32_val(1)); + + fp16x2Vec0 = bitcast(fp16x2Vec0, i32_ty); + fp16x2Vec1 = bitcast(fp16x2Vec1, i32_ty); + + Value a0 = lshr(i32_ty, fp16x2Vec0, i32_val(1)); + Value a1 = lshr(i32_ty, fp16x2Vec0, i32_val(7)); + + Value fp8x4Vec = and_(i32_ty, a0, i32_val(0x40004000)); + fp8x4Vec = or_(i32_ty, fp8x4Vec, and_(i32_ty, a1, i32_val(0x007f007f)) ); + fp8x4Vec = or_(i32_ty, fp8x4Vec, and_(i32_ty, fp16x2Vec1, i32_val(0xbf80bf80)) ); + + auto fp8x4VecTy = vec_ty(i8_ty, 4); + fp8x4Vec = bitcast(fp8x4Vec, fp8x4VecTy); + + return {extract_element(i8_ty, fp8x4Vec, i32_val(0)), + extract_element(i8_ty, fp8x4Vec, i32_val(1)), + extract_element(i8_ty, fp8x4Vec, i32_val(2)), + extract_element(i8_ty, fp8x4Vec, i32_val(3)) + }; } #else const std::string Fp16_to_Fp8E4M3B15x4 = @@ -450,8 +501,7 @@ const std::string Fp16_to_Fp8E4M3B15x4 = "shr.b32 a1, $1, 7; \n" "and.b32 $0, a0, 0x40004000; \n" "lop3.b32 $0, $0, a1, 0x007f007f, 0xf8; \n" - "lop3.b32 $0, $0, $2, 0x80008000, 0xf8; \n" - "lop3.b32 $0, $0, $2, 0x3f803f80, 0xf8; \n" + "lop3.b32 $0, $0, $2, 0xbf80bf80, 0xf8; \n" "}"; #endif @@ -464,43 +514,28 @@ const std::string Fp16_to_Fp8E4M3B15x4 = #ifdef USE_ROCM static SmallVector Fp8E4M3_to_Fp16(Location loc, ConversionPatternRewriter &rewriter, - const Value &v0, const Value &v1, const Value &v2, - const Value &v3) { + const SmallVector &v) { auto fp8x4VecTy = vec_ty(i8_ty, 4); Value a0 = undef(fp8x4VecTy); a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(0)); - a0 = insert_element(fp8x4VecTy, a0, v0, i32_val(1)); + a0 = insert_element(fp8x4VecTy, a0, v[0], i32_val(1)); a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(2)); - a0 = insert_element(fp8x4VecTy, a0, v1, i32_val(3)); + a0 = insert_element(fp8x4VecTy, a0, v[1], i32_val(3)); a0 = bitcast(a0, i32_ty); - Value a1 = undef(fp8x4VecTy); - a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(0)); - a1 = insert_element(fp8x4VecTy, a1, v2, i32_val(1)); - a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(2)); - a1 = insert_element(fp8x4VecTy, a1, v3, i32_val(3)); - a1 = bitcast(a1, i32_ty); - Value b0 = and_(i32_ty, a0, i32_val(0x7fff7fff)); - Value b1 = and_(i32_ty, a1, i32_val(0x7fff7fff)); b0 = lshr(i32_ty, b0, i32_val(1)); - b1 = lshr(i32_ty, b1, i32_val(1)); b0 = add(i32_ty, b0, i32_val(0x20002000)); - b1 = add(i32_ty, b1, i32_val(0x20002000)); b0 = or_( i32_ty, b0, and_(i32_ty, a0, i32_val(0x80008000)) ); - b1 = or_( i32_ty, b1, and_(i32_ty, a1, i32_val(0x80008000)) ); auto fp16x2VecTy = vec_ty(f16_ty, 2); auto fp16x2Vec0 = bitcast(b0, fp16x2VecTy); - auto fp16x2Vec1 = bitcast(b1, fp16x2VecTy); return { extract_element(f16_ty, fp16x2Vec0, i32_val(0)), - extract_element(f16_ty, fp16x2Vec0, i32_val(1)), - extract_element(f16_ty, fp16x2Vec1, i32_val(0)), - extract_element(f16_ty, fp16x2Vec1, i32_val(1)) + extract_element(f16_ty, fp16x2Vec0, i32_val(1)) }; } #else @@ -525,39 +560,26 @@ const std::string Fp8E4M3_to_Fp16 = #ifdef USE_ROCM static SmallVector Fp16_to_Fp8E4M3(Location loc, ConversionPatternRewriter &rewriter, - const Value &v0, const Value &v1, const Value &v2, - const Value &v3) { + const SmallVector &v) { auto fp16x2VecTy = vec_ty(f16_ty, 2); Value fp16x2Vec0 = undef(fp16x2VecTy); - Value fp16x2Vec1 = undef(fp16x2VecTy); - fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v0, i32_val(0)); - fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v1, i32_val(1)); - fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v2, i32_val(0)); - fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v3, i32_val(1)); + fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[0], i32_val(0)); + fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[1], i32_val(1)); fp16x2Vec0 = bitcast(fp16x2Vec0, i32_ty); - fp16x2Vec1 = bitcast(fp16x2Vec1, i32_ty); fp16x2Vec0 = sub(i32_ty, fp16x2Vec0, i32_val(0x20002000)); - fp16x2Vec1 = sub(i32_ty, fp16x2Vec1, i32_val(0x20002000)); Value a0 = shl(i32_ty, fp16x2Vec0, i32_val(1)); - Value a1 = shl(i32_ty, fp16x2Vec1, i32_val(1)); a0 = and_(i32_ty, a0, i32_val(0x7fff7fff)); - a1 = and_(i32_ty, a1, i32_val(0x7fff7fff)); a0 = add(i32_ty, a0, i32_val(0x00800080)); - a1 = add(i32_ty, a1, i32_val(0x00800080)); Value b0 = or_( i32_ty, and_(i32_ty, fp16x2Vec0, i32_val(0x80008000)), a0 ); - Value b1 = or_( i32_ty, and_(i32_ty, fp16x2Vec1, i32_val(0x80008000)), a1 ); auto fp8x4VecTy = vec_ty(i8_ty, 4); b0 = bitcast(b0, fp8x4VecTy); - b1 = bitcast(b1, fp8x4VecTy); return {extract_element(i8_ty, b0, i32_val(1)), - extract_element(i8_ty, b0, i32_val(3)), - extract_element(i8_ty, b1, i32_val(1)), - extract_element(i8_ty, b1, i32_val(3)) + extract_element(i8_ty, b0, i32_val(3)) }; } #else @@ -584,21 +606,20 @@ const std::string Fp16_to_Fp8E4M3 = #ifdef USE_ROCM static SmallVector Fp8E4M3_to_Bf16(Location loc, ConversionPatternRewriter &rewriter, - const Value &v0, const Value &v1, const Value &v2, - const Value &v3) { + const SmallVector &v) { auto fp8x4VecTy = vec_ty(i8_ty, 4); Value a0 = undef(fp8x4VecTy); a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(0)); - a0 = insert_element(fp8x4VecTy, a0, v0, i32_val(1)); + a0 = insert_element(fp8x4VecTy, a0, v[0], i32_val(1)); a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(2)); - a0 = insert_element(fp8x4VecTy, a0, v1, i32_val(3)); + a0 = insert_element(fp8x4VecTy, a0, v[1], i32_val(3)); a0 = bitcast(a0, i32_ty); Value a1 = undef(fp8x4VecTy); a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(0)); - a1 = insert_element(fp8x4VecTy, a1, v2, i32_val(1)); + a1 = insert_element(fp8x4VecTy, a1, v[2], i32_val(1)); a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(2)); - a1 = insert_element(fp8x4VecTy, a1, v3, i32_val(3)); + a1 = insert_element(fp8x4VecTy, a1, v[3], i32_val(3)); a1 = bitcast(a1, i32_ty); Value b0 = and_(i32_ty, a0, i32_val(0x7fff7fff)); @@ -645,15 +666,14 @@ const std::string Fp8E4M3_to_Bf16 = #ifdef USE_ROCM static SmallVector Bf16_to_Fp8E4M3(Location loc, ConversionPatternRewriter &rewriter, - const Value &v0, const Value &v1, const Value &v2, - const Value &v3) { + const SmallVector &v) { auto bf16x2VecTy = vec_ty(i16_ty, 2); Value bf16x2Vec0 = undef(bf16x2VecTy); Value bf16x2Vec1 = undef(bf16x2VecTy); - bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v0, i32_val(0)); - bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v1, i32_val(1)); - bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v2, i32_val(0)); - bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v3, i32_val(1)); + bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v[0], i32_val(0)); + bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v[1], i32_val(1)); + bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v[2], i32_val(0)); + bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v[3], i32_val(1)); bf16x2Vec0 = bitcast(bf16x2Vec0, i32_ty); bf16x2Vec1 = bitcast(bf16x2Vec1, i32_ty); @@ -755,6 +775,15 @@ const std::string Bf16_to_Fp8E4M3 = "}"; #endif +// Fp8E4M3 (x2) -> Fp16 (x2) (packed) +const std::string Fp8E4M3Nv_to_Fp16 = "{ \n" + "cvt.rn.f16x2.e4m3x2 $0, $1; \n" + "}"; +// Fp16 (x2) -> Fp8E4M3 (x2) (packed) +const std::string Fp16_to_Fp8E4M3Nv = "{ \n" + "cvt.rn.satfinite.e4m3x2.f16x2 $0, $1; \n" + "}"; + /* ----- Packed integer to BF16 ------ */ #ifndef USE_ROCM const std::string S8_to_Bf16 = @@ -853,6 +882,13 @@ static SmallVector reorderValues(const SmallVector &values, llvm_unreachable("unimplemented code path"); } +inline Type getElementType(Value value) { + auto type = value.getType(); + if (auto tensorType = type.dyn_cast()) + return tensorType.getElementType(); + return type; +} + inline SmallVector unpackI32(const SmallVector &inValues, Type srcTy, ConversionPatternRewriter &rewriter, @@ -862,8 +898,10 @@ inline SmallVector unpackI32(const SmallVector &inValues, if (!tensorTy) return inValues; auto encoding = tensorTy.getEncoding().dyn_cast(); - if (!(encoding && encoding.getParent().isa())) + if (!(encoding && (encoding.getParent().isa() or + encoding.getParent().isa()))) { return inValues; + } SmallVector outValues; for (auto v : inValues) { // cast i32 to appropriate eltType vector and extract elements @@ -902,40 +940,49 @@ inline SmallVector packI32(const SmallVector &inValues, } typedef std::function(Location, ConversionPatternRewriter &, - const Value &, const Value &, - const Value &, const Value &)> + const SmallVector &)> ConverterT; static ConverterT makeConverterFromPtx(const std::string &ptxAsm, Type inType, - Type outType) { + Type outType, + const int inVecWidthBits = 32, + const int outVecWidthBits = 32) { + + ConverterT converter = + [ptxAsm, inType, outType, inVecWidthBits, + outVecWidthBits](Location loc, ConversionPatternRewriter &rewriter, + const SmallVector &v) -> SmallVector { + int numElements = v.size(); + assert(numElements == 4 || numElements == 2 && "invalid vector size"); - ConverterT converter = [ptxAsm, inType, outType]( - Location loc, ConversionPatternRewriter &rewriter, - const Value &v0, const Value &v1, const Value &v2, - const Value &v3) -> SmallVector { - SmallVector v = {v0, v1, v2, v3}; auto ctx = rewriter.getContext(); int inBitwidth = inType.getIntOrFloatBitWidth(); int outBitwidth = outType.getIntOrFloatBitWidth(); // first, we pack `v` into 32-bit ints - int inVecWidth = 32 / inBitwidth; + int inVecWidth = inVecWidthBits / inBitwidth; auto inVecTy = vec_ty(inType, inVecWidth); - SmallVector inPacked(4 / inVecWidth, undef(inVecTy)); - for (size_t i = 0; i < 4; i++) + SmallVector inPacked(numElements / inVecWidth, undef(inVecTy)); + for (size_t i = 0; i < numElements; i++) inPacked[i / inVecWidth] = insert_element( inVecTy, inPacked[i / inVecWidth], v[i], i32_val(i % inVecWidth)); for (size_t i = 0; i < inPacked.size(); i++) - inPacked[i] = bitcast(inPacked[i], i32_ty); + inPacked[i] = bitcast(inPacked[i], int_ty(inVecWidthBits)); // then, we run the provided inline PTX - int outVecWidth = 32 / outBitwidth; - int outNums = 4 / outVecWidth; + int outVecWidth = outVecWidthBits / outBitwidth; + int outNums = numElements / outVecWidth; PTXBuilder builder; SmallVector operands; - for (int i = 0; i < outNums; i++) - operands.push_back(builder.newOperand("=r")); - for (Value inVal : inPacked) - operands.push_back(builder.newOperand(inVal, "r")); + auto outConstriant = outVecWidthBits == 16 ? "=h" : "=r"; + auto inConstraint = inVecWidthBits == 16 ? "h" : "r"; + for (int i = 0; i < outNums; i++) { + operands.push_back(builder.newOperand(outConstriant)); + } + + for (Value inVal : inPacked) { + operands.push_back(builder.newOperand(inVal, inConstraint)); + } + auto &ptxOp = *builder.create(ptxAsm); ptxOp(operands, /*onlyAttachMLIRArgs=*/true); auto outVecTy = vec_ty(outType, outVecWidth); @@ -950,7 +997,7 @@ static ConverterT makeConverterFromPtx(const std::string &ptxAsm, Type inType, } // unpack the output SmallVector ret; - for (size_t i = 0; i < 4; i++) + for (size_t i = 0; i < numElements; i++) ret.push_back(extract_element(outType, outPacked[i / outVecWidth], i32_val(i % outVecWidth))); return ret; @@ -997,6 +1044,7 @@ class ElementwiseOpConversionBase Location loc = op->getLoc(); // element type auto resultElementTy = getElementTypeOrSelf(resultTy); + Type elemTy = this->getTypeConverter()->convertType(resultElementTy); SmallVector> allOperands; for (auto operand : adaptor.getOperands()) { @@ -1025,18 +1073,24 @@ class ElementwiseOpConversionBase } it += curr.size(); } + if (op->getNumOperands() > 0) { auto argTy = op->getOperand(0).getType(); resultVals = reorderValues(resultVals, argTy, resultTy); } resultVals = packI32(resultVals, resultTy, rewriter, loc, this->getTypeConverter()); + resultVals = this->getTypeConverter()->packMfmaOperand(resultVals, resultTy, rewriter, loc); + Value view = this->getTypeConverter()->packLLElements(loc, resultVals, rewriter, resultTy); rewriter.replaceOp(op, view); return success(); } + +private: + int computeCapability; }; template @@ -1049,11 +1103,6 @@ struct ElementwiseOpConversion using Base::Base; using OpAdaptor = typename Base::OpAdaptor; - explicit ElementwiseOpConversion(LLVMTypeConverter &typeConverter, - PatternBenefit benefit = 1) - : ElementwiseOpConversionBase( - typeConverter, benefit) {} - // An interface to support variant DestOp builder. SmallVector createDestOps(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, @@ -1070,6 +1119,11 @@ struct FpToFpOpConversion using ElementwiseOpConversionBase< triton::FpToFpOp, FpToFpOpConversion>::ElementwiseOpConversionBase; + explicit FpToFpOpConversion(TritonGPUToLLVMTypeConverter &typeConverter, + int computeCapability, PatternBenefit benefit = 1) + : ElementwiseOpConversionBase(typeConverter, benefit), + computeCapability(computeCapability) {} + static Value convertBf16ToFp32(Location loc, ConversionPatternRewriter &rewriter, const Value &v) { @@ -1163,6 +1217,7 @@ struct FpToFpOpConversion auto F8E4M3B15TyID = TypeID::get(); auto F8E4M3TyID = TypeID::get(); auto F8E5M2TyID = TypeID::get(); + auto F8E4M3FNTyID = TypeID::get(); auto F16TyID = TypeID::get(); auto BF16TyID = TypeID::get(); auto F32TyID = TypeID::get(); @@ -1174,20 +1229,35 @@ struct FpToFpOpConversion #endif // F8 -> F16 {{F8E4M3B15TyID, F16TyID}, Fp8E4M3B15_to_Fp16}, + {{F8E4M3FNTyID, F16TyID}, Fp8E4M3B15x4_to_Fp16}, {{F8E4M3TyID, F16TyID}, Fp8E4M3_to_Fp16}, {{F8E5M2TyID, F16TyID}, Fp8E5M2_to_Fp16}, // F16 -> F8 +#ifdef USE_ROCM {{F16TyID, F8E4M3B15TyID}, Fp16_to_Fp8E4M3B15}, +#else + {{F16TyID, F8E4M3B15TyID}, Fp16_to_Fp8E4M3B15(computeCapability >= 80)}, +#endif + {{F16TyID, F8E4M3FNTyID}, Fp16_to_Fp8E4M3B15x4}, {{F16TyID, F8E4M3TyID}, Fp16_to_Fp8E4M3}, {{F16TyID, F8E5M2TyID}, Fp16_to_Fp8E5M2}, // F8 -> BF16 - {{F8E4M3TyID, BF16TyID}, Fp8E4M3_to_Bf16}, {{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16}, // BF16 -> F8 - {{BF16TyID, F8E4M3TyID}, Bf16_to_Fp8E4M3}, {{BF16TyID, F8E5M2TyID}, Bf16_to_Fp8E5M2}, }; + int inVecWidthBits = 32; + int outVecWidthBits = 32; + if (srcTy.isFloat8E4M3FNUZ()) { + inVecWidthBits = 16; + outVecWidthBits = 32; + } + if (dstTy.isFloat8E4M3FNUZ()) { + inVecWidthBits = 32; + outVecWidthBits = 16; + } + std::pair key = {srcTy.getTypeID(), dstTy.getTypeID()}; if (srcMap.count(key) == 0) { llvm::errs() << "Unsupported conversion from " << srcTy << " to " << dstTy @@ -1197,9 +1267,17 @@ struct FpToFpOpConversion #ifdef USE_ROCM return srcMap.lookup(key); #else + if (computeCapability < 90 && + (srcTy.isFloat8E4M3FNUZ() || dstTy.isFloat8E4M3FNUZ())) { + llvm::errs() << "Conversion from/to f8e4m3nv is only supported on " + "compute capability >= 90" + << "\n"; + llvm_unreachable(""); + } return makeConverterFromPtx(srcMap.lookup(key), getTypeConverter()->convertType(srcTy), - getTypeConverter()->convertType(dstTy)); + getTypeConverter()->convertType(dstTy), + inVecWidthBits, outVecWidthBits); #endif } @@ -1207,21 +1285,27 @@ struct FpToFpOpConversion ConversionPatternRewriter &rewriter, Type elemTy, MultipleOperandsRange operands, Location loc) const { - assert(operands.size() % 4 == 0 && - "FP8 casting only support tensors with 4-aligned sizes"); auto srcElementType = getElementType(op.getFrom()); auto dstElementType = getElementType(op.getResult()); + int numElements = 4; + if (srcElementType.isFloat8E4M3FNUZ() || + dstElementType.isFloat8E4M3FNUZ()) { + numElements = 2; + } + assert(operands.size() % numElements == 0 && + "FP8 casting only support tensors with aligned sizes"); bool isSrcFP32 = srcElementType.isF32(); bool isDstFP32 = dstElementType.isF32(); auto cvtFunc = getConversionFunc(isSrcFP32 ? f16_ty : srcElementType, isDstFP32 ? f16_ty : dstElementType); - SmallVector inVals = {operands[0][0], operands[1][0], operands[2][0], - operands[3][0]}; + SmallVector inVals; + for (unsigned i = 0; i < numElements; i++) { + inVals.push_back(operands[i][0]); + } if (isSrcFP32) for (Value &v : inVals) v = convertFp32ToFp16(loc, rewriter, v); - SmallVector outVals = - cvtFunc(loc, rewriter, inVals[0], inVals[1], inVals[2], inVals[3]); + SmallVector outVals = cvtFunc(loc, rewriter, inVals); assert(outVals.size() == inVals.size()); if (isDstFP32) for (Value &v : outVals) @@ -1229,6 +1313,9 @@ struct FpToFpOpConversion // Pack values return outVals; } + +private: + int computeCapability; }; template @@ -1331,15 +1418,16 @@ struct CmpFOpConversion } }; -template struct ExternElementwiseOpConversion - : public ElementwiseOpConversionBase> { - using Base = ElementwiseOpConversionBase>; + : public ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; using Base::Base; using Adaptor = typename Base::OpAdaptor; typedef typename Base::OpAdaptor OpAdaptor; - SmallVector createDestOps(T op, OpAdaptor adaptor, + SmallVector createDestOps(ExternElementwiseOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, Type elemTy, MultipleOperandsRange operands, Location loc) const { @@ -1360,8 +1448,9 @@ struct ExternElementwiseOpConversion return LLVM::LLVMFunctionType::get(resultType, operandTypes); } - LLVM::LLVMFuncOp appendOrGetFuncOp(ConversionPatternRewriter &rewriter, T op, - StringRef funcName, Type funcType) const { + LLVM::LLVMFuncOp appendOrGetFuncOp(ConversionPatternRewriter &rewriter, + ExternElementwiseOp op, StringRef funcName, + Type funcType) const { using LLVM::LLVMFuncOp; auto funcAttr = StringAttr::get(op->getContext(), funcName); @@ -1380,6 +1469,86 @@ struct ExternElementwiseOpConversion } }; +struct ElementwiseInlineAsmOpConversion + : public ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + typedef typename Base::OpAdaptor OpAdaptor; + + // If operand size is smaller than 32bits pack by groups of 32bits. + // Otherwise have separate inputs. + SmallVector packOperands(ElementwiseInlineAsmOp op, + MultipleOperandsRange operands, + ConversionPatternRewriter &rewriter, + Location loc) const { + SmallVector packedOperands; + unsigned numPackedElements = op.getPackedElement(); + for (int i = 0, e = op.getNumOperands(); i < e; i++) { + unsigned bitWidth = + getElementType(op.getOperand(i)).getIntOrFloatBitWidth(); + unsigned numElementPerReg = bitWidth < 32 ? 32 / bitWidth : 1; + numElementPerReg = std::min(numElementPerReg, numPackedElements); + for (int j = 0; j < numPackedElements; j += numElementPerReg) { + if (numElementPerReg == 1) { + packedOperands.push_back(operands[j][i]); + continue; + } + Type t = vec_ty( + getTypeConverter()->convertType(getElementType(op.getOperand(i))), + numElementPerReg); + Value packed = undef(t); + for (int k = 0; k < numElementPerReg; k++) { + packed = insert_element(packed, operands[j + k][i], i32_val(k)); + } + packedOperands.push_back(packed); + } + } + return packedOperands; + } + + SmallVector createDestOps(ElementwiseInlineAsmOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + int numPackedElements = op.getPackedElement(); + if (operands.size() % numPackedElements != 0) + llvm::report_fatal_error("Inline asm op has more packed elements than " + "number of elements per thread."); + SmallVector packedOperands = + packOperands(op, operands, rewriter, loc); + Type dstType = + getTypeConverter()->convertType(getElementType(op.getResult())); + Type retType = dstType; + if (numPackedElements > 1) + retType = vec_ty(retType, numPackedElements); + Value result = rewriter + .create( + loc, retType, + packedOperands, // operands + op.getAsmString(), // asm_string + op.getConstraints(), // constraints + !op.getPure(), // has_side_effects + false, // is_align_stack + LLVM::AsmDialectAttr::get( + rewriter.getContext(), + LLVM::AsmDialect::AD_ATT), // asm_dialect + ArrayAttr() // operand_attrs + ) + ->getResult(0); + SmallVector results; + if (numPackedElements > 1) { + for (int i = 0; i < numPackedElements; i++) + results.push_back(extract_element(result, i32_val(i))); + } else { + results = {result}; + } + return results; + } +}; + struct FDivOpConversion : ElementwiseOpConversionBase { using Base = @@ -1527,9 +1696,8 @@ struct FSubOpConversion #ifdef USE_ROCM static SmallVector S8_to_Bf16(Location loc, ConversionPatternRewriter &rewriter, - const Value &v0, const Value &v1, const Value &v2, - const Value &v3) { - SmallVector inValues = {v0, v1, v2, v3}; + const SmallVector &v) { + SmallVector inValues = {v[0], v[1], v[2], v[3]}; SmallVector outValues = {}; for (Value inVal : inValues) { Value i32Val = sext(i32_ty, inVal); @@ -1566,14 +1734,16 @@ struct SIToFPOpConversion Type outElemTy = getElementType(op.getOut()); if (outElemTy.isBF16() && inElemTy.isInteger(8) && operands.size() >= 4) { #if USE_ROCM - auto outVals = S8_to_Bf16(loc, rewriter, operands[0][0], operands[1][0], - operands[2][0], operands[3][0]); + SmallVector inVals = {operands[0][0], operands[1][0], + operands[2][0], operands[3][0]}; + auto outVals = S8_to_Bf16(loc, rewriter, inVals); #else - auto cvtFunc = makeConverterFromPtx( - S8_to_Bf16, getTypeConverter()->convertType(inElemTy), - getTypeConverter()->convertType(outElemTy)); - auto outVals = cvtFunc(loc, rewriter, operands[0][0], operands[1][0], - operands[2][0], operands[3][0]); + auto cvtFunc = makeConverterFromPtx( + S8_to_Bf16, getTypeConverter()->convertType(inElemTy), + getTypeConverter()->convertType(outElemTy)); + SmallVector inVals = {operands[0][0], operands[1][0], + operands[2][0], operands[3][0]}; + auto outVals = cvtFunc(loc, rewriter, inVals); #endif assert(outVals.size() == 4); return outVals; @@ -1763,7 +1933,10 @@ struct IndexCastOpLowering void populateElementwiseOpToLLVMPatterns( TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - PatternBenefit benefit) { + int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, + ModuleAllocation &allocation, + ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, + int computeCapability, PatternBenefit benefit) { #define POPULATE_TERNARY_OP(SRC_OP, DST_OP) \ patterns.add>(typeConverter, benefit); POPULATE_TERNARY_OP(triton::gpu::SelectOp, LLVM::SelectOp) @@ -1787,7 +1960,11 @@ void populateElementwiseOpToLLVMPatterns( POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >> POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >> POPULATE_BINARY_OP(arith::MinFOp, LLVM::MinNumOp) // fmin + POPULATE_BINARY_OP(arith::MaxFOp, LLVM::MaxNumOp) // fmax POPULATE_BINARY_OP(arith::MinSIOp, LLVM::SMinOp) // smin + POPULATE_BINARY_OP(arith::MaxSIOp, LLVM::SMaxOp) // smax + POPULATE_BINARY_OP(arith::MinUIOp, LLVM::UMinOp) // umin + POPULATE_BINARY_OP(arith::MaxUIOp, LLVM::UMaxOp) // umax #undef POPULATE_BINARY_OP #define POPULATE_UNARY_OP(SRC_OP, DST_OP) \ @@ -1823,16 +2000,171 @@ void populateElementwiseOpToLLVMPatterns( patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); + patterns.add(typeConverter, computeCapability, benefit); - patterns.add>( - typeConverter, benefit); - patterns - .add>( - typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); // ExpOpConversionApprox will try using ex2.approx if the input type is // FP32. For other input types, ExpOpConversionApprox will return failure and // ElementwiseOpConversion defined below will call // __nv_expf for higher-precision calculation patterns.add(typeConverter, benefit); } + +struct FPExtOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + static bool isLegalOp(LLVM::FPExtOp op) { + auto retTy = op.getResult().getType(); + auto srcTy = op.getOperand().getType(); + if (retTy.isF32() && srcTy.isF16()) { + return false; + } + return true; + } + + SmallVector createDestOps(LLVM::FPExtOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + return { + FpToFpOpConversion::convertFp16ToFp32(loc, rewriter, operands[0][0])}; + } +}; + +struct FPTruncOpConversion + : ElementwiseOpConversionBase { + using Base = + ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + static bool isLegalOp(LLVM::FPTruncOp op) { + auto retTy = op.getResult().getType(); + auto srcTy = op.getOperand().getType(); + if (retTy.isF16() && srcTy.isF32()) { + return false; + } + return true; + } + + SmallVector createDestOps(LLVM::FPTruncOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + return { + FpToFpOpConversion::convertFp32ToFp16(loc, rewriter, operands[0][0])}; + } +}; + +struct TruncOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + static bool isLegalOp(LLVM::TruncOp op) { + auto retTy = op.getResult().getType(); + auto srcTy = op.getOperand().getType(); + if (retTy.isInteger(16) && srcTy.isInteger(32)) { + return false; + } + return true; + } + + SmallVector createDestOps(LLVM::TruncOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + PTXBuilder builder; + auto &cvt = *builder.create("cvt.u16.u32"); + auto res = builder.newOperand("=h"); + auto operand = builder.newOperand(operands[0][0], "r"); + cvt(res, operand); + return {builder.launch(rewriter, loc, i16_ty, false)}; + } +}; + +struct SExtOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + static bool isLegalOp(LLVM::SExtOp op) { + auto retTy = op.getResult().getType(); + auto srcTy = op.getOperand().getType(); + if (retTy.isInteger(32) && srcTy.isInteger(16)) { + return false; + } + return true; + } + + SmallVector createDestOps(LLVM::SExtOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + PTXBuilder builder; + auto &cvt = *builder.create("cvt.s32.s16"); + auto res = builder.newOperand("=r"); + auto operand = builder.newOperand(operands[0][0], "h"); + cvt(res, operand); + return {builder.launch(rewriter, loc, i32_ty, false)}; + } +}; + +struct ZExtOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + static bool isLegalOp(LLVM::ZExtOp op) { + auto retTy = op.getResult().getType(); + auto srcTy = op.getOperand().getType(); + if (retTy.isInteger(32) && srcTy.isInteger(16)) { + return false; + } + return true; + } + + SmallVector createDestOps(LLVM::ZExtOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + PTXBuilder builder; + auto &cvt = *builder.create("cvt.u32.u16"); + auto res = builder.newOperand("=r"); + auto operand = builder.newOperand(operands[0][0], "h"); + cvt(res, operand); + return {builder.launch(rewriter, loc, i32_ty, false)}; + } +}; + +bool isLegalElementwiseOp(Operation *op) { + if (isa(op)) { + return FPExtOpConversion::isLegalOp(cast(op)); + } else if (isa(op)) { + return FPTruncOpConversion::isLegalOp(cast(op)); + } else if (isa(op)) { + return TruncOpConversion::isLegalOp(cast(op)); + } else if (isa(op)) { + return SExtOpConversion::isLegalOp(cast(op)); + } else if (isa(op)) { + return ZExtOpConversion::isLegalOp(cast(op)); + } + return true; +} + +void populateElementwiseOpToPTXPatterns( + TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); +} diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.h b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.h index 2187f798b2b3..fbcbe95bd85b 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.h +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.h @@ -8,8 +8,15 @@ using namespace mlir::triton; void populateElementwiseOpToLLVMPatterns( TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - PatternBenefit benefit); + int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, + ModuleAllocation &allocation, + ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, + int computeCapability, PatternBenefit benefit); bool isLegalElementwiseOp(Operation *op); +void populateElementwiseOpToPTXPatterns( + TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit); + #endif diff --git a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp index d64cb43d63cc..5e4fa3c111d5 100644 --- a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -3,11 +3,21 @@ #include "ConvertLayoutOpToLLVM.h" #include "LoadStoreOpToLLVM.h" +#include "Utility.h" + +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" + +#include using namespace mlir; using namespace mlir::triton; +using ::mlir::LLVM::delinearize; using ::mlir::LLVM::getSharedMemoryObjectFromStruct; +using ::mlir::LLVM::linearize; +using ::mlir::triton::gpu::getCTALayout; +using ::mlir::triton::gpu::getShapePerCTA; using ::mlir::triton::gpu::getTotalElemsPerThread; using ::mlir::triton::gpu::SharedEncodingAttr; @@ -64,6 +74,9 @@ struct LoadOpConversion Value other = op.getOther(); // adaptor values + assert(!isTensorPointerType(ptr.getType()) && + "Cannot convert load with a tensor pointer into LLVM; " + "this case should be transformed to normal load before lowering"); Value llPtr = adaptor.getPtr(); Value llMask = adaptor.getMask(); Value llOther = adaptor.getOther(); @@ -433,6 +446,528 @@ struct StoreOpConversion return success(); } }; +// TODO: refactor to save common logic with insertsliceasyncv2 +struct StoreAsyncOpConversion + : public ConvertTritonGPUOpToLLVMPattern { + using ConvertTritonGPUOpToLLVMPattern< + triton::nvidia_gpu::StoreAsyncOp>::ConvertTritonGPUOpToLLVMPattern; + + StoreAsyncOpConversion(TritonGPUToLLVMTypeConverter &converter, + ModuleAllocation &allocation, + mlir::triton::gpu::TMAMetadataTy *tmaMetadata, + const TensorPtrMapT *tensorPtrMap, + PatternBenefit benefit) + : ConvertTritonGPUOpToLLVMPattern( + converter, allocation, tmaMetadata, benefit), + tensorPtrMap(tensorPtrMap) {} + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::StoreAsyncOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto srcTy = op.getSrc().getType().cast(); + auto srcEncoding = srcTy.getEncoding(); + if (srcEncoding.isa()) { + return lowerStoreAsyncWithSlice(op, adaptor, rewriter); + } else { + return lowerStoreAsync(op, adaptor, rewriter); + } + } + + LogicalResult lowerStoreAsync(triton::nvidia_gpu::StoreAsyncOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + MLIRContext *ctx = rewriter.getContext(); + + auto dst = op.getDst(); + auto src = op.getSrc(); + auto srcTy = src.getType().cast(); + auto elemTy = srcTy.getElementType(); + + auto rank = srcTy.getRank(); + // The sotre async op only supports tensor with ranke <= 5. + // Reference: + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tensor-dimension-size-and-format + assert(rank > 0 && rank <= 5); + + auto moduleOp = op->getParentOfType(); + assert(moduleOp && "Parent ModuleOp not found for StoreAsyncOp"); + + auto llFuncOp = op->getParentOfType(); + assert(llFuncOp && "LLVMFuncOp not found for StoreAsyncOp"); + + int numTMADescs = getNumTMADescs(llFuncOp); + assert(numTMADescs > 0); + + auto sharedLayout = srcTy.getEncoding().dyn_cast(); + assert(sharedLayout && "expected shared encoding"); + + mlir::triton::gpu::TMAInfo tmaInfo; + + tmaInfo.tensorDataType = getCUtensorMapDataType(elemTy); + tmaInfo.tensorRank = rank; + assert(tmaMetadata); + + auto inOrder = sharedLayout.getOrder(); + unsigned TMADescIdx = tmaMetadata->size(); + unsigned numFuncArgs = llFuncOp.getBody().front().getNumArguments(); + auto makeTensorPtr = tensorPtrMap->lookup(op.getOperation()); + auto dstOrder = makeTensorPtr.getOrder(); + + unsigned globalAddressArgIdx = getArgIdx(makeTensorPtr.getBase()); + tmaInfo.globalAddressArgIdx = globalAddressArgIdx; + tmaInfo.TMADescArgIdx = numFuncArgs - numTMADescs + TMADescIdx; + + auto getDimOfOrder = [](ArrayRef order, int32_t i) { + auto it = std::find(order.begin(), order.end(), i); + assert(it != order.end()); + return std::distance(order.begin(), it); + }; + + std::vector globalDimsArgIdx; + std::vector globalStridesArgIdx; + // constant values are mapped to (-1 - value) + for (int i = 0; i < rank; ++i) { + int32_t argIdx = -1; + auto dim = getDimOfOrder(dstOrder, i); + argIdx = getArgIdx(makeTensorPtr.getShape()[dim]); + globalDimsArgIdx.emplace_back(argIdx); + // handle constant stride + argIdx = getArgIdx(makeTensorPtr.getStrides()[dim]); + globalStridesArgIdx.emplace_back(argIdx); + } + + tmaInfo.globalDimsArgIdx = globalDimsArgIdx; + tmaInfo.globalStridesArgIdx = globalStridesArgIdx; + std::vector boxDims; + auto CTAsPerCGA = sharedLayout.getCTALayout().getCTAsPerCGA(); + auto CTAOrder = sharedLayout.getCTALayout().getCTAOrder(); + auto CTASplitNum = sharedLayout.getCTALayout().getCTASplitNum(); + auto tensorShape = makeTensorPtr.getResult() + .getType() + .cast() + .getPointeeType() + .cast() + .getShape(); + auto shapePerCTA = getShapePerCTA(CTASplitNum, tensorShape); + const uint32_t bytesPerCacheline = 128; + uint32_t bytesPerElem = elemTy.getIntOrFloatBitWidth() / 8; + uint32_t numBox{1}; + for (int i = 0; i < rank; ++i) { + auto dim = getDimOfOrder(dstOrder, i); + auto tNumElems = shapePerCTA[dim]; + if (i == 0 && tNumElems * bytesPerElem > bytesPerCacheline) { + tNumElems = bytesPerCacheline / bytesPerElem; + numBox = (shapePerCTA[dim] + tNumElems - 1) / tNumElems; + } + boxDims.emplace_back(tNumElems); + } + std::vector elementStrides(rank, 1); + tmaInfo.boxDims = boxDims; + tmaInfo.elementStrides = elementStrides; + + CUtensorMapSwizzle swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE; + assert( + ((elemTy.getIntOrFloatBitWidth() == 16 && sharedLayout.getVec() == 8) or + (elemTy.getIntOrFloatBitWidth() == 32 && + sharedLayout.getVec() == 4)) && + "Unexpected shared layout for StoreAsyncOp"); + if (sharedLayout.getPerPhase() == 4 && sharedLayout.getMaxPhase() == 2) + swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_32B; + else if (sharedLayout.getPerPhase() == 2 && sharedLayout.getMaxPhase() == 4) + swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_64B; + else if (sharedLayout.getPerPhase() == 1 && sharedLayout.getMaxPhase() == 8) + swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B; + else + llvm::report_fatal_error("Unsupported shared layout for StoreAsyncOp"); + tmaInfo.swizzle = swizzle; + tmaInfo.interleave = CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE; + tmaInfo.l2Promotion = + CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B; + tmaInfo.oobFill = + CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; + + tmaMetadata->emplace_back(tmaInfo); + + Value llDst = adaptor.getDst(); + Value llSrc = adaptor.getSrc(); + auto srcShape = srcTy.getShape(); + auto smemObj = getSharedMemoryObjectFromStruct(loc, llSrc, rewriter); + + SmallVector offsetVals; + for (auto i = 0; i < srcShape.size(); ++i) { + offsetVals.emplace_back(i32_val(0)); + } + + Value tmaDesc = + llFuncOp.getBody().front().getArgument(tmaInfo.TMADescArgIdx); + auto ptrI8SharedTy = LLVM::LLVMPointerType::get( + typeConverter->convertType(rewriter.getI8Type()), 3); + + auto threadId = getThreadId(rewriter, loc); + Value pred = icmp_eq(threadId, i32_val(0)); + + auto llCoord = getTypeConverter()->unpackLLElements(loc, llDst, rewriter, + dst.getType()); + uint32_t boxStride = std::accumulate(boxDims.begin(), boxDims.end(), 1, + std::multiplies()); + + Value clusterCTAId = getClusterCTAId(rewriter, loc); + SmallVector multiDimClusterCTAId = + delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder); + + rewriter.create(loc, 0); + + for (uint32_t b = 0; b < numBox; ++b) { + SmallVector coord; + // raw coord + for (int i = 0; i < rank; ++i) { + auto dim = getDimOfOrder(dstOrder, i); + coord.push_back(llCoord[dim]); + } + // coord with box and cta offset + for (int i = 0; i < rank; ++i) { + auto dim = getDimOfOrder(dstOrder, i); + if (i == 0) { + coord[i] = add(coord[i], i32_val(b * boxDims[i])); + auto CTAOffset = + mul(multiDimClusterCTAId[dim], i32_val(numBox * boxDims[i])); + coord[i] = add(coord[i], CTAOffset); + } else { + coord[i] = add(coord[i], + mul(multiDimClusterCTAId[dim], i32_val(boxDims[i]))); + } + } + Value srcOffset = i32_val(b * boxStride); + auto srcPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3); + Value srcPtrBase = gep(srcPtrTy, smemObj.base, srcOffset); + auto addr = bitcast(srcPtrBase, ptrI8SharedTy); + rewriter.create(loc, tmaDesc, addr, pred, + coord); + } + rewriter.eraseOp(op); + return success(); + } + + LogicalResult + lowerStoreAsyncWithSlice(triton::nvidia_gpu::StoreAsyncOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + MLIRContext *ctx = rewriter.getContext(); + + auto dst = op.getDst(); + auto src = op.getSrc(); + auto srcTy = src.getType().cast(); + auto makeTensorPtr = tensorPtrMap->lookup(op.getOperation()); + auto dstTensorTy = makeTensorPtr.getResult() + .getType() + .cast() + .getPointeeType() + .cast(); + auto tensorShape = dstTensorTy.getShape(); + auto dstOrder = makeTensorPtr.getOrder(); + auto dstElemTy = dstTensorTy.getElementType(); + + auto rank = srcTy.getRank(); + // The sotre async op only supports tensor with ranke <= 5. + // Reference: + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tensor-dimension-size-and-format + assert(rank > 0 && rank <= 5); + + auto moduleOp = op->getParentOfType(); + assert(moduleOp && "Parent ModuleOp not found for StoreAsyncOp"); + + auto llFuncOp = op->getParentOfType(); + assert(llFuncOp && "LLVMFuncOp not found for StoreAsyncOp"); + + int numTMADescs = getNumTMADescs(llFuncOp); + assert(numTMADescs > 0); + + auto ctaLayout = getCTALayout(dstTensorTy.getEncoding()); + // The order of smem should be consistent with gmem. + SmallVector sharedOrder; + for (auto o : makeTensorPtr.getOrder()) { + sharedOrder.emplace_back(o); + } + auto sharedLayout = SharedEncodingAttr::get(ctx, tensorShape, sharedOrder, + ctaLayout, dstElemTy); + + mlir::triton::gpu::TMAInfo tmaInfo; + + tmaInfo.tensorDataType = getCUtensorMapDataType(dstElemTy); + tmaInfo.tensorRank = rank; + assert(tmaMetadata); + + unsigned TMADescIdx = tmaMetadata->size(); + unsigned numFuncArgs = llFuncOp.getBody().front().getNumArguments(); + + unsigned globalAddressArgIdx = getArgIdx(makeTensorPtr.getBase()); + tmaInfo.globalAddressArgIdx = globalAddressArgIdx; + tmaInfo.TMADescArgIdx = numFuncArgs - numTMADescs + TMADescIdx; + + auto getDimOfOrder = [](ArrayRef order, int32_t i) { + auto it = std::find(order.begin(), order.end(), i); + assert(it != order.end()); + return std::distance(order.begin(), it); + }; + + std::vector globalDimsArgIdx; + std::vector globalStridesArgIdx; + // constant values are mapped to (-1 - value) + for (int i = 0; i < rank; ++i) { + int32_t argIdx = -1; + auto dim = getDimOfOrder(dstOrder, i); + argIdx = getArgIdx(makeTensorPtr.getShape()[dim]); + globalDimsArgIdx.emplace_back(argIdx); + // handle constant stride + argIdx = getArgIdx(makeTensorPtr.getStrides()[dim]); + globalStridesArgIdx.emplace_back(argIdx); + } + + tmaInfo.globalDimsArgIdx = globalDimsArgIdx; + tmaInfo.globalStridesArgIdx = globalStridesArgIdx; + std::vector boxDims; + auto CTAsPerCGA = sharedLayout.getCTALayout().getCTAsPerCGA(); + auto CTAOrder = sharedLayout.getCTALayout().getCTAOrder(); + auto CTASplitNum = sharedLayout.getCTALayout().getCTASplitNum(); + auto shapePerCTA = getShapePerCTA(CTASplitNum, tensorShape); + + auto srcLayout = srcTy.getEncoding(); + auto mmaLayout = srcLayout.dyn_cast(); + + unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy); + + auto instrShape = mmaLayout.getInstrShape(); + auto warpsPerCTA = mmaLayout.getWarpsPerCTA(); + uint32_t repM = + ceil(shapePerCTA[0], instrShape[0] * warpsPerCTA[0]); + uint32_t numElemsPerRep = numElems / repM; + + const uint32_t bytesPerCacheline = 128; + uint32_t bytesPerElem = dstElemTy.getIntOrFloatBitWidth() / 8; + uint32_t numBox{1}; + for (int i = 0; i < rank; ++i) { + auto dim = getDimOfOrder(dstOrder, i); + auto tNumElems = shapePerCTA[dim]; + if (i == 0 && tNumElems * bytesPerElem > bytesPerCacheline) { + tNumElems = bytesPerCacheline / bytesPerElem; + numBox = (shapePerCTA[dim] + tNumElems - 1) / tNumElems; + } + if (i == 1) { + tNumElems = tNumElems / repM / warpsPerCTA[0]; + } + boxDims.emplace_back(tNumElems); + } + std::vector elementStrides(rank, 1); + tmaInfo.boxDims = boxDims; + tmaInfo.elementStrides = elementStrides; + + CUtensorMapSwizzle swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE; + assert(((dstElemTy.getIntOrFloatBitWidth() == 16 && + sharedLayout.getVec() == 8) or + (dstElemTy.getIntOrFloatBitWidth() == 32 && + sharedLayout.getVec() == 4)) && + "Unexpected shared layout for StoreAsyncOp"); + if (sharedLayout.getPerPhase() == 4 && sharedLayout.getMaxPhase() == 2) + swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_32B; + else if (sharedLayout.getPerPhase() == 2 && sharedLayout.getMaxPhase() == 4) + swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_64B; + else if (sharedLayout.getPerPhase() == 1 && sharedLayout.getMaxPhase() == 8) + swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B; + else + llvm::report_fatal_error("Unsupported shared layout for StoreAsyncOp"); + tmaInfo.swizzle = swizzle; + tmaInfo.interleave = CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE; + tmaInfo.l2Promotion = + CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B; + tmaInfo.oobFill = + CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; + + tmaMetadata->emplace_back(tmaInfo); + + Value llDst = adaptor.getDst(); + Value llSrc = adaptor.getSrc(); + auto srcShape = srcTy.getShape(); + auto dstElemPtrTy = ptr_ty(getTypeConverter()->convertType(dstElemTy), 3); + Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation()); + smemBase = bitcast(smemBase, dstElemPtrTy); + + SmallVector offsetVals; + for (auto i = 0; i < srcShape.size(); ++i) { + offsetVals.emplace_back(i32_val(0)); + } + + Value tmaDesc = + llFuncOp.getBody().front().getArgument(tmaInfo.TMADescArgIdx); + auto ptrI8SharedTy = LLVM::LLVMPointerType::get( + typeConverter->convertType(rewriter.getI8Type()), 3); + + auto threadId = getThreadId(rewriter, loc); + Value pred = icmp_eq(urem(threadId, i32_val(32)), i32_val(0)); + + auto llCoord = getTypeConverter()->unpackLLElements(loc, llDst, rewriter, + dst.getType()); + uint32_t boxStride = std::accumulate(boxDims.begin(), boxDims.end(), 1, + std::multiplies()); + boxStride = boxStride * repM * warpsPerCTA[0]; + + Value clusterCTAId = getClusterCTAId(rewriter, loc); + SmallVector multiDimClusterCTAId = + delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder); + + // rowStride in bytes + uint32_t rowStrideInBytes = shapePerCTA[dstOrder[0]] * bytesPerElem; + uint32_t swizzlingByteWidth = + std::min(rowStrideInBytes, bytesPerCacheline); + + unsigned numElemsPerSwizzlingRow = swizzlingByteWidth / bytesPerElem; + unsigned leadingDimOffset = + numElemsPerSwizzlingRow * shapePerCTA[dstOrder[1]]; + + uint32_t rowsPerRep = getShapePerCTATile(mmaLayout)[0]; + + Value warpId = udiv(threadId, i32_val(32)); + Value warpId0 = urem(urem(warpId, i32_val(warpsPerCTA[0])), + i32_val(srcShape[0] / instrShape[0])); + auto srcOrder = triton::gpu::getOrder(srcLayout); + unsigned inVec = + srcOrder == sharedLayout.getOrder() + ? triton::gpu::getContigPerThread(srcLayout)[srcOrder[0]] + : 1; + unsigned outVec = sharedLayout.getVec(); + unsigned minVec = std::min(outVec, inVec); + assert(minVec == 2); + + auto wordTy = vec_ty(dstElemTy, minVec); + + auto inVals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(), + rewriter, srcTy); + for (uint32_t b = 0; b < numBox; ++b) { + for (int rep = 0; rep < repM; ++rep) { + Value rowOfWarp = add(mul(warpId0, i32_val(instrShape[0])), + i32_val(rep * rowsPerRep)); + uint32_t elemIdxOffset = rep * numElemsPerRep; + + for (unsigned idx = 0; idx < numElemsPerRep / numBox; idx += 8) { + uint32_t elemIdx = elemIdxOffset + b * numElemsPerRep / numBox + idx; + + Value offset = rewriter.create( + loc, i32_ty, threadId, rowOfWarp, + i32_val(b * numElemsPerRep / numBox + idx), leadingDimOffset, + numElemsPerSwizzlingRow, true); + + Value addr = gep(dstElemPtrTy, smemBase, offset); + Value words[4]; + for (unsigned i = 0; i < 8; ++i) { + if (i % minVec == 0) + words[i / 2] = undef(wordTy); + words[i / 2] = insert_element( + wordTy, words[i / 2], inVals[elemIdx + i], i32_val(i % minVec)); + } + + rewriter.create( + loc, bitcast(addr, ptrI8SharedTy), + ValueRange{bitcast(words[0], i32_ty), bitcast(words[1], i32_ty), + bitcast(words[2], i32_ty), bitcast(words[3], i32_ty)}); + } + rewriter.create(loc, 0); + + SmallVector coord; + // raw coord + for (int i = 0; i < rank; ++i) { + auto dim = getDimOfOrder(dstOrder, i); + coord.push_back(llCoord[dim]); + } + // coord with box and cta offset + for (int i = 0; i < rank; ++i) { + auto dim = getDimOfOrder(dstOrder, i); + if (i == 0) { + coord[i] = add(coord[i], i32_val(b * boxDims[i])); + auto CTAOffset = + mul(multiDimClusterCTAId[dim], i32_val(numBox * boxDims[i])); + coord[i] = add(coord[i], CTAOffset); + } else { + Value blockOffset = i32_val(rep * instrShape[0] * warpsPerCTA[0]); + Value warpOffset = mul(warpId0, i32_val(instrShape[0])); + coord[i] = add(add(coord[i], add(blockOffset, warpOffset)), + mul(multiDimClusterCTAId[dim], + i32_val(boxDims[i] * repM * warpsPerCTA[0]))); + } + } + Value srcOffset = + add(i32_val(b * boxStride + rep * instrShape[0] * warpsPerCTA[0] * + instrShape[1] * warpsPerCTA[1] / + numBox), + mul(warpId0, i32_val(instrShape[0] * numElemsPerSwizzlingRow))); + auto srcPtrTy = ptr_ty(getTypeConverter()->convertType(dstElemTy), 3); + Value srcPtrBase = gep(srcPtrTy, smemBase, srcOffset); + auto addr = bitcast(srcPtrBase, ptrI8SharedTy); + rewriter.create(loc, tmaDesc, addr, + pred, coord); + } + } + rewriter.eraseOp(op); + return success(); + } + +private: + CUtensorMapDataType getCUtensorMapDataType(Type ty) const { + if (ty.isF16()) { + return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + } else if (ty.isF32()) { + return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32; + } else { + llvm::report_fatal_error("Unsupported elemTy for StoreAsyncOp"); + return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + } + } + + unsigned getArgIdx(Value v) const { + if (auto op = v.getDefiningOp()) { + return -1 - + op.getValue().dyn_cast().getValue().getZExtValue(); + } + if (v.getDefiningOp() && + isa(v.getDefiningOp())) { + return getArgIdx(v.getDefiningOp()->getOperand(0)); + } else if (v.getParentBlock()->isEntryBlock() && v.isa()) { + // in entryblock and is BlockArgument; Because argument of func are + // arugments of entryblock bb0 in MLIR + return v.cast().getArgNumber(); + } else if (v.getParentBlock()->isEntryBlock() && + (!v.isa())) { + // in entryblock but not BlockArgument + return getArgIdx(v.getDefiningOp()->getOperand(0)); + } else if (!v.getParentBlock()->isEntryBlock()) { + // in non-entryblock + return getArgIdx(v.getDefiningOp()->getOperand(0)); + } else { + llvm::report_fatal_error( + "Operand of `MakeTensorPtrOp` is not the function's argument"); + return 0; + } + } + + int getNumTMADescs(LLVM::LLVMFuncOp func) const { + if (!func->hasAttr(kAttrNumTMALoadDescsName)) { + llvm::report_fatal_error("TritonGPU module should contain a " + "triton_gpu.num-tma-load attribute"); + return -1; + } + if (!func->hasAttr(kAttrNumTMAStoreDescsName)) { + llvm::report_fatal_error("TritonGPU module should contain a " + "triton_gpu.num-tma-store attribute"); + return -1; + } + return func->getAttr(kAttrNumTMAStoreDescsName) + .cast() + .getInt() + + func->getAttr(kAttrNumTMALoadDescsName).cast().getInt(); + } + + const TensorPtrMapT *tensorPtrMap; +}; struct AtomicCASOpConversion : public ConvertTritonGPUOpToLLVMPattern, @@ -1126,11 +1661,389 @@ struct InsertSliceAsyncOpConversion } }; +struct InsertSliceAsyncV2OpConversion + : public ConvertTritonGPUOpToLLVMPattern< + triton::nvidia_gpu::InsertSliceAsyncV2Op> { + using ConvertTritonGPUOpToLLVMPattern< + triton::nvidia_gpu::InsertSliceAsyncV2Op>:: + ConvertTritonGPUOpToLLVMPattern; + + InsertSliceAsyncV2OpConversion(TritonGPUToLLVMTypeConverter &converter, + + ModuleAllocation &allocation, + mlir::triton::gpu::TMAMetadataTy *tmaMetadata, + const TensorPtrMapT *tensorPtrMap, + PatternBenefit benefit) + : ConvertTritonGPUOpToLLVMPattern< + triton::nvidia_gpu::InsertSliceAsyncV2Op>(converter, allocation, + tmaMetadata, benefit), + tensorPtrMap(tensorPtrMap) {} + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::InsertSliceAsyncV2Op op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + Location loc = op->getLoc(); + auto resultTy = op.getResult().getType().cast(); + auto elemTy = resultTy.getElementType(); + auto rank = resultTy.getRank() - 1; + + // TODO: support any valid rank in (3, 4, 5) + // The sotre async op only supports tensor with ranke <= 5. + // Reference: + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tensor-dimension-size-and-format + assert(rank > 0 && rank <= 5); + SmallVector shape; + auto axis = op->getAttrOfType("axis").getInt(); + auto moduleOp = op->getParentOfType(); + assert(moduleOp && "Parent ModuleOp not found for InsertSliceAsyncV2Op"); + auto llFuncOp = op->getParentOfType(); + assert(llFuncOp && "LLVMFuncOp not found for InsertSliceAsyncV2Op"); + int numTMADescs = getNumTMADescs(llFuncOp); + assert(numTMADescs > 0); + auto sharedLayout = resultTy.getEncoding().dyn_cast(); + assert(sharedLayout && "unexpected layout of InsertSliceAsyncV2Op"); + auto CTAsPerCGA = sharedLayout.getCTALayout().getCTAsPerCGA(); + auto CTAOrder = sharedLayout.getCTALayout().getCTAOrder(); + auto CTASplitNum = sharedLayout.getCTALayout().getCTASplitNum(); + + mlir::triton::gpu::TMAInfo tmaInfo; + + tmaInfo.tensorDataType = getCUtensorMapDataType(elemTy); + tmaInfo.tensorRank = rank; + + assert(tmaMetadata); + unsigned TMADescIdx = tmaMetadata->size(); + unsigned numFuncArgs = llFuncOp.getBody().front().getNumArguments(); + auto makeTensorPtr = tensorPtrMap->lookup(op.getOperation()); + auto inOrder = makeTensorPtr.getOrder(); + unsigned globalAddressArgIdx = getArgIdx(makeTensorPtr.getBase()); + tmaInfo.globalAddressArgIdx = globalAddressArgIdx; + tmaInfo.TMADescArgIdx = numFuncArgs - numTMADescs + TMADescIdx; + + auto getDimOfOrder = [](ArrayRef order, int32_t i) { + auto it = std::find(order.begin(), order.end(), i); + assert(it != order.end()); + return std::distance(order.begin(), it); + }; + + std::vector globalDimsArgIdx; + std::vector globalStridesArgIdx; + // constant values are mapped to (-1 - value) + for (int i = 0; i < rank; ++i) { + int32_t argIdx = -1; + auto dim = getDimOfOrder(inOrder, i); + argIdx = getArgIdx(makeTensorPtr.getShape()[dim]); + globalDimsArgIdx.emplace_back(argIdx); + // handle constant stride + argIdx = getArgIdx(makeTensorPtr.getStrides()[dim]); + globalStridesArgIdx.emplace_back(argIdx); + } + + tmaInfo.globalDimsArgIdx = globalDimsArgIdx; + tmaInfo.globalStridesArgIdx = globalStridesArgIdx; + + std::vector boxDims; + auto tensorShape = makeTensorPtr.getResult() + .getType() + .cast() + .getPointeeType() + .cast() + .getShape(); + + SmallVector numMcast(rank); + unsigned accNumMcast = 1; + for (unsigned i = 0; i < rank; ++i) { + numMcast[i] = CTAsPerCGA[i] / CTASplitNum[i]; + accNumMcast *= numMcast[i]; + } + auto shapePerCTA = getShapePerCTA(CTASplitNum, tensorShape); + for (size_t i = 0; i < rank; ++i) { + auto dim = getDimOfOrder(inOrder, i); + // in case of TMA multicast, we should always slice along higher order + // dimensions + if (i == rank - 1) { + assert(shapePerCTA[dim] >= accNumMcast && + "cases when the size of the highest order is smaller " + "than numMcasts is not implemented"); + boxDims.emplace_back(shapePerCTA[dim] / accNumMcast); + } else { + boxDims.emplace_back(shapePerCTA[dim]); + } + } + + std::vector elementStrides(rank, 1); + tmaInfo.elementStrides = elementStrides; + + CUtensorMapSwizzle swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE; + if (sharedLayout.getPerPhase() == 4 && sharedLayout.getMaxPhase() == 2) + swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_32B; + else if (sharedLayout.getPerPhase() == 2 && sharedLayout.getMaxPhase() == 4) + swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_64B; + else if (sharedLayout.getPerPhase() == 1 && sharedLayout.getMaxPhase() == 8) + swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B; + else + llvm::report_fatal_error( + "Unsupported shared layout for InsertSliceAsyncV2Op"); + + tmaInfo.swizzle = swizzle; + tmaInfo.interleave = CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE; + tmaInfo.l2Promotion = + CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B; + tmaInfo.oobFill = + CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; + + uint32_t numBoxes = 1; + uint32_t elemSizeOfBytes = elemTy.getIntOrFloatBitWidth() / 8; + if (swizzle == CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) { + while (elemSizeOfBytes * boxDims[0] > 128) { + boxDims[0] = boxDims[0] / 2; + numBoxes *= 2; + } + } + tmaInfo.boxDims = boxDims; + tmaMetadata->emplace_back(tmaInfo); + + uint32_t elemsPerBox = + std::accumulate(boxDims.begin(), boxDims.end(), 1, std::multiplies{}); + + Value clusterCTAId = getClusterCTAId(rewriter, loc); + SmallVector multiDimClusterCTAId = + delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder); + + Value llDst = adaptor.getDst(); + Value llIndex = adaptor.getIndex(); + Value src = op.getSrc(); + Value dst = op.getDst(); + auto dstTy = dst.getType().cast(); + auto dstShape = dstTy.getShape(); + auto smemObj = getSharedMemoryObjectFromStruct(loc, llDst, rewriter); + + // the offset of coord considering multicast slicing + SmallVector mcastOffsetVals; + // The index of slice is this CTAId is responsible for + SmallVector multiDimSliceIdx(rank); + for (auto i = 0; i < rank; ++i) + multiDimSliceIdx[i] = + udiv(multiDimClusterCTAId[i], i32_val(CTASplitNum[i])); + Value sliceIdx = + linearize(rewriter, loc, multiDimSliceIdx, numMcast, CTAOrder); + + Value sliceCoord; + for (auto i = 0; i < rank; ++i) { + if (inOrder[i] == rank - 1) { + // TODO[goostavz]: Cases when the size of the highest order is smaller + // than numMcasts is not implemented. + sliceCoord = mul(sliceIdx, i32_val(shapePerCTA[i] / accNumMcast)); + mcastOffsetVals.emplace_back( + mul(sliceIdx, i32_val(shapePerCTA[i] / accNumMcast))); + } else { + mcastOffsetVals.emplace_back(i32_val(0)); + } + } + + uint32_t elemsPerSlice = std::accumulate( + shapePerCTA.begin(), shapePerCTA.end(), 1, std::multiplies{}); + Value dstOffsetCommon = mul(llIndex, i32_val(elemsPerSlice)); + // [benzh] sliceCoord should be higher dimension's multiplier accumulate. + // currently only support rank == 2. + dstOffsetCommon = + add(dstOffsetCommon, mul(sliceCoord, i32_val(boxDims[0]))); + auto dstPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3); + + Value tmaDesc = + llFuncOp.getBody().front().getArgument(tmaInfo.TMADescArgIdx); + // TODO: sink this logic into Triton::NVGPU dialect and support more + // cache-policy modes + Value l2Desc = int_val(64, 0x1000000000000000ll); + + auto ptrI8SharedTy = LLVM::LLVMPointerType::get( + typeConverter->convertType(rewriter.getI8Type()), 3); + + SmallVector coordCommon; + auto llCoord = getTypeConverter()->unpackLLElements( + loc, adaptor.getSrc(), rewriter, src.getType()); + + for (int i = 0; i < rank; ++i) { + auto dim = getDimOfOrder(inOrder, i); + Value coordDim = bitcast(llCoord[dim], i32_ty); + if (CTASplitNum[dim] != 1) { + // Add offset for each CTA + // boxDims[i] * (multiDimClusterCTAId[i] % CTASplitNum[i]); + auto CTAOffset = + mul(i32_val(shapePerCTA[dim]), + urem(multiDimClusterCTAId[dim], i32_val(CTASplitNum[dim]))); + coordDim = add(coordDim, CTAOffset); + } + + if (i == rank - 1) + // Add offset in case of multicast slicing + coordCommon.push_back(add(coordDim, mcastOffsetVals[dim])); + else + coordCommon.push_back(coordDim); + } + + auto threadId = getThreadId(rewriter, loc); + Value pred = icmp_eq(threadId, i32_val(0)); + + auto mask = adaptor.getMask(); + if (mask) { + // TODO(thomas): What is the right implementation for this case? + assert(mask.getType().isInteger(1) && + "need to implement cases with tensor mask"); + pred = rewriter.create(loc, pred, mask); + } + + Value mcastMask = getMCastMask(sharedLayout, rewriter, loc, clusterCTAId); + + for (size_t i = 0; i < numBoxes; ++i) { + Value dstOffset = + add(dstOffsetCommon, i32_val(i * elemsPerBox * accNumMcast)); + Value dstPtrBase = gep(dstPtrTy, smemObj.base, dstOffset); + SmallVector coord = coordCommon; + coord[0] = add(coordCommon[0], i32_val(i * boxDims[0])); + rewriter.create( + loc, bitcast(dstPtrBase, ptrI8SharedTy), adaptor.getMbar(), tmaDesc, + l2Desc, pred, coord, mcastMask); + } + + rewriter.replaceOp(op, llDst); + return success(); + } + +private: + Value getMCastMask(const SharedEncodingAttr &sharedLayout, + ConversionPatternRewriter &rewriter, Location loc, + Value clusterCTAId) const { + auto CTAsPerCGA = sharedLayout.getCTALayout().getCTAsPerCGA(); + auto CTAOrder = sharedLayout.getCTALayout().getCTAOrder(); + auto CTASplitNum = sharedLayout.getCTALayout().getCTASplitNum(); + + // Short path when no multicast is needed + if (CTAsPerCGA == CTASplitNum) + return nullptr; + + // Short path when bcastMask is a constant + bool isConstMcastMask = true; + for (unsigned s : CTASplitNum) { + if (s > 1) { + isConstMcastMask = false; + break; + } + } + if (isConstMcastMask) { + unsigned numCTAs = std::accumulate(CTAsPerCGA.begin(), CTAsPerCGA.end(), + 1, std::multiplies{}); + return int_val(/*width*/ 16, (1u << numCTAs) - 1); + } + + SmallVector multiDimCTAId = + delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder); + auto rank = CTAOrder.size(); + SmallVector> multiDimMask(rank); + unsigned accNumMcast = 1; + SmallVector numMcast(rank); + for (unsigned i = 0; i < rank; ++i) { + // For the ith dimension, CTAsPerCGA[i]/CTASplitNum[i] vals is to be + // broadcasted, which for this CTAId is: + // multiDimCTAId[i] % CTASplitNum[i] + (0 .. + // (CTAsPerCGA[i]/CTASplitNum[i] - 1)) * CTASplitNum[i] + // TODO: will there be cases if CTAsPerCGA[i]/CTASplitNum[i] < 1? + Value rem = urem(multiDimCTAId[i], i32_val(CTASplitNum[i])); + numMcast[i] = CTAsPerCGA[i] / CTASplitNum[i]; + accNumMcast *= numMcast[i]; + for (unsigned j = 0; j < numMcast[i]; ++j) { + if (j == 0) { + multiDimMask[i].push_back(rem); + } else { + multiDimMask[i].push_back(add(rem, i32_val(j * CTASplitNum[i]))); + } + } + } + + Value bcastMask = int_val(/*width*/ 16, 0); + Value _1_i16 = int_val(/*width*/ 16, 1); + for (unsigned i = 0; i < accNumMcast; ++i) { + SmallVector multiDimIdx = + getMultiDimIndex(i, numMcast, CTAOrder); + SmallVector multiDimMaskedCTAId(rank); + for (unsigned dim = 0; dim < rank; ++dim) { + multiDimMaskedCTAId[dim] = multiDimMask[dim][multiDimIdx[dim]]; + } + Value bcastCTAId = + linearize(rewriter, loc, multiDimMaskedCTAId, CTAsPerCGA, CTAOrder); + // bcastMask |= 1u << bcastCTAId; + bcastMask = or_(bcastMask, shl(_1_i16, trunc(i16_ty, bcastCTAId))); + } + + return bcastMask; + } + + CUtensorMapDataType getCUtensorMapDataType(Type ty) const { + if (ty.isF16()) { + return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + } else if (ty.isF32()) { + return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32; + } else { + llvm::report_fatal_error("Unsupported elemTy for InsertSliceAsyncV2Op"); + return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + } + } + + unsigned getArgIdx(Value v) const { + if (auto op = v.getDefiningOp()) { + return -1 - + op.getValue().dyn_cast().getValue().getZExtValue(); + } + if (v.getDefiningOp() && + isa(v.getDefiningOp())) { + return getArgIdx(v.getDefiningOp()->getOperand(0)); + } else if (v.getParentBlock()->isEntryBlock() && v.isa()) { + // in entryblock and is BlockArgument; Because argument of func are + // arugments of entryblock bb0 in MLIR + return v.cast().getArgNumber(); + } else if (v.getParentBlock()->isEntryBlock() && + (!v.isa())) { + // in entryblock but not BlockArgument + return getArgIdx(v.getDefiningOp()->getOperand(0)); + } else if (!v.getParentBlock()->isEntryBlock()) { + // in non-entryblock + return getArgIdx(v.getDefiningOp()->getOperand(0)); + } else { + llvm::report_fatal_error( + "Operand of `MakeTensorPtrOp` is not the function's argument"); + return 0; + } + } + + int getNumTMADescs(LLVM::LLVMFuncOp func) const { + if (!func->hasAttr(kAttrNumTMALoadDescsName)) { + llvm::report_fatal_error("TritonGPU module should contain a " + "triton_gpu.num-tma-load attribute"); + return -1; + } + if (!func->hasAttr(kAttrNumTMAStoreDescsName)) { + llvm::report_fatal_error("TritonGPU module should contain a " + "triton_gpu.num-tma-store attribute"); + return -1; + } + return func->getAttr(kAttrNumTMAStoreDescsName) + .cast() + .getInt() + + func->getAttr(kAttrNumTMALoadDescsName).cast().getInt(); + } + + const TensorPtrMapT *tensorPtrMap; +}; + void populateLoadStoreOpToLLVMPatterns( TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation, + int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, + ModuleAllocation &allocation, ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, - PatternBenefit benefit) { + mlir::triton::gpu::TMAMetadataTy *tmaMetadata, + const TensorPtrMapT *tensorPtrMap, PatternBenefit benefit) { patterns.add(typeConverter, axisInfoAnalysis, benefit); patterns.add(typeConverter, axisInfoAnalysis, benefit); patterns.add(typeConverter, allocation, @@ -1141,4 +2054,8 @@ void populateLoadStoreOpToLLVMPatterns( indexCacheInfo, benefit); patterns.add( typeConverter, allocation, indexCacheInfo, axisInfoAnalysis, benefit); + patterns.add( + typeConverter, allocation, tmaMetadata, tensorPtrMap, benefit); + patterns.add(typeConverter, allocation, tmaMetadata, + tensorPtrMap, benefit); } diff --git a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.h b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.h index b3f9f52afff2..197edd77aa7a 100644 --- a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.h +++ b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.h @@ -8,8 +8,10 @@ using namespace mlir::triton; void populateLoadStoreOpToLLVMPatterns( TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation, + int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, + ModuleAllocation &allocation, ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, - PatternBenefit benefit); + mlir::triton::gpu::TMAMetadataTy *tmaMetadata, + const TensorPtrMapT *tensorPtrMap, PatternBenefit benefit); #endif diff --git a/lib/Conversion/TritonGPUToLLVM/PTXAsmFormat.cpp b/lib/Conversion/TritonGPUToLLVM/PTXAsmFormat.cpp index 7c44921bf9eb..034cc0746617 100644 --- a/lib/Conversion/TritonGPUToLLVM/PTXAsmFormat.cpp +++ b/lib/Conversion/TritonGPUToLLVM/PTXAsmFormat.cpp @@ -51,6 +51,14 @@ PTXBuilder::Operand *PTXBuilder::newOperand(StringRef constraint, bool init) { return opr; } +PTXBuilder::Operand *PTXBuilder::newOperand(unsigned operandIndex) { + assert(operandIndex < oprCounter && "operand index out of range"); + auto *opr = newOperand(); + opr->idx = oprCounter++; + opr->constraint = std::to_string(operandIndex); + return opr; +} + PTXBuilder::Operand *PTXBuilder::newConstantOperand(const std::string &v) { argArchive.emplace_back(std::make_unique()); argArchive.back()->repr = [v](int idx) { return v; }; diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp index 1555dfa514de..e3ee4beb84fd 100644 --- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp @@ -1,5 +1,8 @@ #include "ReduceOpToLLVM.h" #include "Utility.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" using namespace mlir; using namespace mlir::triton; @@ -14,8 +17,13 @@ using ::mlir::triton::gpu::getTotalElemsPerThread; struct ReduceOpConversion : public ConvertTritonGPUOpToLLVMPattern { public: - using ConvertTritonGPUOpToLLVMPattern< - triton::ReduceOp>::ConvertTritonGPUOpToLLVMPattern; + ReduceOpConversion( + TritonGPUToLLVMTypeConverter &typeConverter, ModuleAllocation &allocation, + ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, + int computeCapability, PatternBenefit benefit) + : ConvertTritonGPUOpToLLVMPattern( + typeConverter, allocation, indexCacheInfo, benefit), + computeCapability(computeCapability) {} LogicalResult matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, @@ -26,14 +34,12 @@ struct ReduceOpConversion } private: + int computeCapability; + void accumulate(ConversionPatternRewriter &rewriter, Region &combineOp, - llvm::SmallVectorImpl &acc, ValueRange cur, - bool isFirst) const { + SmallVector &acc, ValueRange cur, bool isFirst) const { if (isFirst) { - acc.resize(cur.size()); - for (unsigned i = 0; i < cur.size(); ++i) { - acc[i] = cur[i]; - } + acc = SmallVector(cur.begin(), cur.end()); return; } @@ -120,7 +126,7 @@ struct ReduceOpConversion // writeIdx[originalAxis] = index[originalAxis] / axisSizePerThread writeIdx[originalAxis] = udiv(index[originalAxis], axisSizePerThread); } else if (auto mmaLayout = layout.dyn_cast()) { - if (!mmaLayout.isAmpere()) { + if (!mmaLayout.isAmpere() && !mmaLayout.isHopper()) { llvm::report_fatal_error("Unsupported layout"); } if (originalAxis == 0) { @@ -175,7 +181,6 @@ struct ReduceOpConversion elemPtrTys[i] = LLVM::LLVMPointerType::get(llvmElemTy, 3); } auto llvmIndexTy = getTypeConverter()->getIndexType(); - auto indexPtrTy = LLVM::LLVMPointerType::get(llvmIndexTy, 3); auto smemShape = helper.getScratchConfigBasic(); unsigned elems = product(smemShape); @@ -189,33 +194,10 @@ struct ReduceOpConversion elemPtrTys[i]); } - unsigned srcElems = getTotalElemsPerThread(srcTys[0]); - // Emits indices of the original tensor that each thread - // would own - auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTys[0]); auto srcValues = unpackInputs(loc, op, adaptor, rewriter); - - // Emits offsets (the offset from the base index) - // of the original tensor that each thread would own - // NOTE: Assumes offsets don't actually depend on type - SmallVector> offset = - emitOffsetForLayout(srcLayout, srcTys[0]); - - // Keep track of accumulations and their indices std::map, SmallVector> accs; std::map, SmallVector> indices; - - Region *combineOp = &op.getCombineOp(); - - // reduce within threads - for (unsigned i = 0; i < srcElems; ++i) { - SmallVector key = offset[i]; - key[axis] = 0; - bool isFirst = accs.find(key) == accs.end(); - accumulate(rewriter, *combineOp, accs[key], srcValues[i], isFirst); - if (isFirst) - indices[key] = srcIndices[i]; - } + reduceWithinThreads(helper, srcValues, accs, indices, rewriter); // cached int32 constants std::map ints; @@ -271,15 +253,17 @@ struct ReduceOpConversion readPtrs[i] = gep(elemPtrTys[i], writePtrs[i], readOffset); } - barrier(); + sync(rewriter, loc, op); + // Combine accumulator value from another thread SmallVector cur(op.getNumOperands()); for (unsigned i = 0; i < op.getNumOperands(); ++i) { cur[i] = load(readPtrs[i]); } - accumulate(rewriter, *combineOp, acc, cur, false); + accumulate(rewriter, op.getCombineOp(), acc, cur, false); + + sync(rewriter, loc, op); - barrier(); // Publish our new accumulator value to shared memory for (unsigned i = 0; i < op.getNumOperands(); ++i) { store(acc[i], writePtrs[i]); @@ -287,7 +271,7 @@ struct ReduceOpConversion } } - barrier(); + sync(rewriter, loc, op); // set output values SmallVector results(op.getNumOperands()); @@ -324,79 +308,216 @@ struct ReduceOpConversion return success(); } - // Use warp shuffle for reduction within warps and shared memory for data - // exchange across warps - LogicalResult matchAndRewriteFast(triton::ReduceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - ReduceOpHelper helper(op); - Location loc = op->getLoc(); - unsigned axis = adaptor.getAxis(); - - auto srcTys = op.getInputTypes(); - auto srcLayout = helper.getSrcLayout(); - if (!helper.isSupportedLayout()) { - assert(false && "Unexpected srcLayout in ReduceOpConversion"); - } - auto srcOrd = triton::gpu::getOrder(srcLayout); - auto srcShape = helper.getSrcShape(); - - SmallVector elemPtrTys(srcTys.size()); - for (unsigned i = 0; i < op.getNumOperands(); ++i) { - auto ty = srcTys[i].getElementType(); - auto llvmElemTy = getTypeConverter()->convertType(ty); - elemPtrTys[i] = LLVM::LLVMPointerType::get(llvmElemTy, 3); - } - auto llvmIndexTy = getTypeConverter()->getIndexType(); - auto indexPtrTy = LLVM::LLVMPointerType::get(llvmIndexTy, 3); - - auto smemShapes = helper.getScratchConfigsFast(); - unsigned elems = product(smemShapes[0]); - unsigned maxElems = std::max(elems, product(smemShapes[1])); - - unsigned sizeIntraWarps = helper.getIntraWarpSizeWithUniqueData(); - unsigned sizeInterWarps = helper.getInterWarpSizeWithUniqueData(); - - SmallVector smemBases(op.getNumOperands()); - bool isWarpSync = helper.isWarpSynchronous(); - - if (!isWarpSync) { - smemBases[0] = bitcast( - getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys[0]); - for (unsigned i = 1; i < op.getNumOperands(); ++i) { - smemBases[i] = - bitcast(gep(elemPtrTys[i - 1], smemBases[i - 1], i32_val(maxElems)), - elemPtrTys[i]); - } + void sync(ConversionPatternRewriter &rewriter, Location loc, + triton::ReduceOp op) const { + // TODO[shuhaoj]: change hard code style of numThreads. Hide async_agent + // attr. + if (getWSAgentId(op)) { + barSync(rewriter, op, getAgentIds(op).front(), 128); + } else { + barrier(); } + } - unsigned srcElems = getTotalElemsPerThread(srcTys[0]); - auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTys[0]); - auto srcValues = unpackInputs(loc, op, adaptor, rewriter); - - std::map, SmallVector> accs; - std::map, SmallVector> indices; + // Check if the reduction can use a redux op and return the kind. + std::optional matchReduxKind(triton::ReduceOp op) const { + #ifdef USE_ROCM + return std::nullopt; + #endif + if (computeCapability < 80) + return std::nullopt; + if (op.getNumOperands() != 1 || op.getNumResults() != 1) + return std::nullopt; + Block *block = &(*op.getCombineOp().begin()); + Operation *yield = block->getTerminator(); + Operation *reduceOp = yield->getOperand(0).getDefiningOp(); + if (!reduceOp || reduceOp->getNumOperands() != 2 || + reduceOp->getNumResults() != 1) + return std::nullopt; + auto intType = reduceOp->getResultTypes()[0].dyn_cast(); + if (!intType || intType.getWidth() > 32) + return std::nullopt; + if (reduceOp->getOperand(0) != block->getArgument(0) || + reduceOp->getOperand(1) != block->getArgument(1)) + return std::nullopt; + if (isa(reduceOp)) + return NVVM::ReduxKind::ADD; + if (isa(reduceOp)) + return NVVM::ReduxKind::AND; + if (isa(reduceOp)) + return NVVM::ReduxKind::OR; + if (isa(reduceOp)) + return NVVM::ReduxKind::XOR; + if (isa(reduceOp)) + return NVVM::ReduxKind::MIN; + if (isa(reduceOp)) + return NVVM::ReduxKind::UMIN; + if (isa(reduceOp)) + return NVVM::ReduxKind::MAX; + if (isa(reduceOp)) + return NVVM::ReduxKind::UMAX; + return std::nullopt; + } + // Reduce along op axis for elements that are in the same thread. The + // accumulated value is stored in accs. + void reduceWithinThreads( + ReduceOpHelper &helper, SmallVector> &srcValues, + std::map, SmallVector> &accs, + std::map, SmallVector> &indices, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + RankedTensorType operandType = op.getInputTypes()[0]; // Assumes offsets don't actually depend on type SmallVector> offset = - emitOffsetForLayout(srcLayout, srcTys[0]); - + emitOffsetForLayout(helper.getSrcLayout(), operandType); + unsigned srcElems = getTotalElemsPerThread(operandType); auto *combineOp = &op.getCombineOp(); - + auto srcIndices = + emitIndices(op.getLoc(), rewriter, helper.getSrcLayout(), operandType); // reduce within threads for (unsigned i = 0; i < srcElems; ++i) { SmallVector key = offset[i]; - key[axis] = 0; + key[op.getAxis()] = 0; bool isFirst = accs.find(key) == accs.end(); accumulate(rewriter, *combineOp, accs[key], srcValues[i], isFirst); if (isFirst) indices[key] = srcIndices[i]; } + } + + // Apply warp reduction across the given number of contiguous lanes using op + // region and the accumulator values as source. + void warpReduce(ConversionPatternRewriter &rewriter, Location loc, + SmallVector &acc, triton::ReduceOp op, + unsigned numLaneToReduce) const { + if (auto kind = matchReduxKind(op)) { + // Based on benchmarking on A100 redux op gives a speed up only when doing + // a single reduction (not partioned) and when the mask is static. + // Therefore we currently only enable it to reduce across all the lanes. + if (numLaneToReduce == 32) { + assert(acc.size() == 1); + Value mask = i32_val(0xFFFFFFFF); + // Even though we currently don't use redux for partitioned reduction + // the code below supports it in case we want to tweak the heuristic. + if (numLaneToReduce < 32) { + // For partitioned reduction we need to caluclate the mask so that + // each group of numLaneToReduce threads has the correct mask. + unsigned bitmask = (1 << numLaneToReduce) - 1; + Value threadId = getThreadId(rewriter, loc); + Value laneId = urem(threadId, i32_val(32)); + mask = shl(i32_val(bitmask), + and_(laneId, i32_val(~(numLaneToReduce - 1)))); + } + for (unsigned i = 0; i < acc.size(); ++i) { + unsigned bitwidth = acc[i].getType().cast().getWidth(); + if (bitwidth < 32) { + if (*kind == NVVM::ReduxKind::MIN || *kind == NVVM::ReduxKind::MAX) + acc[i] = sext(i32_ty, acc[i]); + else + acc[i] = zext(i32_ty, acc[i]); + } + acc[i] = rewriter.create(loc, acc[i].getType(), acc[0], + *kind, mask); + if (bitwidth < 32) + acc[i] = trunc(int_ty(bitwidth), acc[i]); + } + return; + } + } + + for (unsigned N = numLaneToReduce / 2; N > 0; N >>= 1) { + SmallVector shfl(acc.size()); + unsigned shuffleIdx = N; +#ifdef USE_ROCM + auto srcTys = op.getInputTypes(); + auto inputTy = srcTys[0].cast(); + auto inMfma = + inputTy.getEncoding().dyn_cast(); + if (inMfma && inMfma.getIsTransposed()) { + assert(numLaneToReduce == 2 || numLaneToReduce == 4); + // for mfma 32x32 adjecant threads in y dimension in transposed MFMA layout are 32 + // apart: [[0 0 0 0 32 32 32 32 ...] [1 1 1 1 33 33 33 33 ...] ...]. + // for mfma 16x16 adjecant threads in y dimension in transposed MFMA layout are 16 + // apart: [[0 0 0 0 16 16 16 16 32 32 32 32 ...] [1 1 1 1 33 33 33 33 ...] ...]. + const int warpSize = 64; + shuffleIdx = warpSize / N / 2; + } +#endif + for (unsigned i = 0; i < acc.size(); ++i) { + shfl[i] = shflSync(loc, rewriter, acc[i], shuffleIdx); + } + accumulate(rewriter, op.getCombineOp(), acc, shfl, false); + } + } + + // Reduce across threads within each warp. + void + reduceWithinWarps(ReduceOpHelper &helper, + std::map, SmallVector> &accs, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + unsigned sizeIntraWarps = helper.getIntraWarpSizeWithUniqueData(); + for (auto it : accs) { + const SmallVector &key = it.first; + SmallVector &acc = accs[key]; + warpReduce(rewriter, op.getLoc(), acc, op, sizeIntraWarps); + } + } + + // Pack the accumualtor values and replace the reduce op with the result. + void packResults(ReduceOpHelper &helper, + std::map, SmallVector> &accs, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + Location loc = op.getLoc(); + unsigned axis = op.getAxis(); + SmallVector results(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + if (auto resultTy = + op.getResult()[i].getType().dyn_cast()) { + auto resultLayout = resultTy.getEncoding().cast(); + unsigned resultElems = getTotalElemsPerThread(resultTy); + SmallVector> resultOffset = + emitOffsetForLayout(resultLayout, resultTy); + SmallVector resultVals; + for (int j = 0; j < resultElems; j++) { + auto key = resultOffset[j]; + key.insert(key.begin() + axis, 0); + resultVals.push_back(accs[key][i]); + } + results[i] = getTypeConverter()->packLLElements(loc, resultVals, + rewriter, resultTy); + } else + results[i] = accs.begin()->second[i]; + } + rewriter.replaceOp(op, results); + } + + // Return the type of the shared memory pointer for operand i. + Type getElementPtrType(triton::ReduceOp op, int i) const { + auto ty = op.getInputTypes()[i].getElementType(); + auto llvmElemTy = getTypeConverter()->convertType(ty); + return LLVM::LLVMPointerType::get(llvmElemTy, 3); + } + void storeWarpReduceToSharedMemory( + ReduceOpHelper &helper, + std::map, SmallVector> &accs, + std::map, SmallVector> &indices, + SmallVector &smemBases, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + Location loc = op.getLoc(); Value threadId = getThreadId(rewriter, loc); + auto srcLayout = helper.getSrcLayout(); unsigned wavefront_size = triton::gpu::getWarpSize(srcLayout); Value warpSize = i32_val(wavefront_size); Value warpId = udiv(threadId, warpSize); Value laneId = urem(threadId, warpSize); + auto srcShape = helper.getSrcShape(); + unsigned axis = op.getAxis(); + auto smemShapes = helper.getScratchConfigsFast(); auto threadsPerWarp = triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout, srcShape); @@ -409,6 +530,7 @@ struct ReduceOpConversion delinearize(rewriter, loc, warpId, warpsPerCTA, order); #ifdef USE_ROCM + auto srcTys = op.getInputTypes(); auto inputTy = srcTys[0].cast(); auto inMfma = inputTy.getEncoding().dyn_cast(); @@ -428,76 +550,38 @@ struct ReduceOpConversion Value zero = i32_val(0); Value laneZero = icmp_eq(laneIdAxis, zero); - std::map, SmallVector> finalAccs; for (auto it : accs) { const SmallVector &key = it.first; SmallVector acc = it.second; - // Reduce within warps - for (unsigned N = sizeIntraWarps / 2; N > 0; N >>= 1) { - SmallVector shfl(op.getNumOperands()); - unsigned shuffleIdx = N; -#ifdef USE_ROCM - if (inMfma && inMfma.getIsTransposed()) { - assert(sizeIntraWarps == 2); - // Adjecant threads in y dimension in transposed MFMA layout are 32 - // apart: [[0 0 0 0 32 32 32 32 ...] [1 1 1 1 33 33 33 33 ...] ...]. - shuffleIdx = 32; - } -#endif - for (unsigned i = 0; i < op.getNumOperands(); ++i) { - shfl[i] = shflSync(loc, rewriter, acc[i], shuffleIdx); - } - accumulate(rewriter, *combineOp, acc, shfl, false); - } - - if (isWarpSync) { - finalAccs[key] = acc; - continue; - } - SmallVector writeIdx = indices[key]; writeIdx[axis] = warpIdAxis; Value writeOffset = linearize(rewriter, loc, writeIdx, smemShapes[0], order); for (unsigned i = 0; i < op.getNumOperands(); ++i) { - Value writePtr = gep(elemPtrTys[i], smemBases[i], writeOffset); + auto elemPtrTy = getElementPtrType(op, i); + Value writePtr = gep(elemPtrTy, smemBases[i], writeOffset); storeShared(rewriter, loc, writePtr, acc[i], laneZero); } } + } - if (isWarpSync) { - SmallVector results(op.getNumOperands()); - for (unsigned i = 0; i < op.getNumOperands(); ++i) { - if (auto resultTy = - op.getResult()[i].getType().dyn_cast()) { - auto resultLayout = resultTy.getEncoding().cast(); - unsigned resultElems = getTotalElemsPerThread(resultTy); - SmallVector> resultOffset = - emitOffsetForLayout(resultLayout, resultTy); - SmallVector resultVals; - for (int j = 0; j < resultElems; j++) { - auto key = resultOffset[j]; - key.insert(key.begin() + axis, 0); - resultVals.push_back(finalAccs[key][i]); - } - results[i] = getTypeConverter()->packLLElements(loc, resultVals, - rewriter, resultTy); - } else - results[i] = finalAccs.begin()->second[i]; - } - rewriter.replaceOp(op, results); - return success(); - } - - barrier(); + // Load the reduction of each warp and accumulate them to a final value and + // store back to shared memory. + void accumulatePartialReductions(ReduceOpHelper &helper, + SmallVector &smemBases, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + auto srcLayout = helper.getSrcLayout(); + auto smemShapes = helper.getScratchConfigsFast(); + unsigned elems = product(smemShapes[0]); + unsigned sizeInterWarps = helper.getInterWarpSizeWithUniqueData(); + Location loc = op.getLoc(); - // The second round of shuffle reduction - // now the problem size: sizeInterWarps, s1, s2, .. , sn - // where sizeInterWarps is 2^m - // - // Each thread needs to process: - // elemsPerThread = sizeInterWarps * s1 * s2 .. Sn / numThreads + Value threadId = getThreadId(rewriter, loc); + Value warpSize = i32_val(32); + Value laneId = urem(threadId, warpSize); + Value zero = i32_val(0); auto mod = op.getOperation()->getParentOfType(); unsigned numThreads = @@ -510,23 +594,18 @@ struct ReduceOpConversion // i32_val(sizeInerWarps)) SmallVector acc(op.getNumOperands()); for (unsigned i = 0; i < op.getNumOperands(); ++i) { - Value readPtr = gep(elemPtrTys[i], smemBases[i], readOffset); + auto elemPtrTy = getElementPtrType(op, i); + Value readPtr = gep(elemPtrTy, smemBases[i], readOffset); acc[i] = load(readPtr); } - - for (unsigned N = sizeInterWarps / 2; N > 0; N >>= 1) { - SmallVector shfl(op.getNumOperands()); - for (unsigned i = 0; i < op.getNumOperands(); ++i) { - shfl[i] = shflSync(loc, rewriter, acc[i], N); - } - accumulate(rewriter, *combineOp, acc, shfl, false); - } + warpReduce(rewriter, loc, acc, op, sizeInterWarps); // only the first thread in each sizeInterWarps is writing Value writeOffset = readOffset; SmallVector writePtrs(op.getNumOperands()); for (unsigned i = 0; i < op.getNumOperands(); ++i) { - writePtrs[i] = gep(elemPtrTys[i], smemBases[i], writeOffset); + auto elemPtrTy = getElementPtrType(op, i); + writePtrs[i] = gep(elemPtrTy, smemBases[i], writeOffset); } Value threadIsNeeded = icmp_slt(threadId, i32_val(elems)); Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps)); @@ -534,6 +613,9 @@ struct ReduceOpConversion icmp_eq(laneIdModSizeInterWarps, zero); Value pred = and_(threadIsNeeded, laneIdModSizeInterWarpsIsZero); + auto srcLayout = helper.getSrcLayout(); + unsigned wavefront_size = triton::gpu::getWarpSize(srcLayout); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { #if USE_ROCM // This barrier is known to be critical for Navi 2x/3x @@ -550,10 +632,17 @@ struct ReduceOpConversion readOffset = add(readOffset, i32_val(numThreads)); } } + } - barrier(); - - // set output values + // Load the final reduction from shared memory and replace the reduce result + // with it. + void loadReductionAndPackResult(ReduceOpHelper &helper, + SmallVector &smemBases, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + Location loc = op.getLoc(); + auto smemShapes = helper.getScratchConfigsFast(); + auto order = getOrder(helper.getSrcLayout()); SmallVector results(op.getNumOperands()); for (unsigned i = 0; i < op.getNumOperands(); ++i) { if (auto resultTy = @@ -567,10 +656,11 @@ struct ReduceOpConversion SmallVector resultVals(resultElems); for (size_t j = 0; j < resultElems; ++j) { SmallVector readIdx = resultIndices[j]; - readIdx.insert(readIdx.begin() + axis, i32_val(0)); + readIdx.insert(readIdx.begin() + op.getAxis(), i32_val(0)); Value readOffset = linearize(rewriter, loc, readIdx, smemShapes[0], order); - Value readPtr = gep(elemPtrTys[i], smemBases[i], readOffset); + Value readPtr = + gep(getElementPtrType(op, i), smemBases[i], readOffset); resultVals[j] = load(readPtr); } @@ -582,6 +672,65 @@ struct ReduceOpConversion } } rewriter.replaceOp(op, results); + } + + // Use warp shuffle for reduction within warps and shared memory for data + // exchange across warps + LogicalResult matchAndRewriteFast(triton::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + ReduceOpHelper helper(op); + assert(helper.isSupportedLayout() && + "Unexpected srcLayout in ReduceOpConversion"); + Location loc = op->getLoc(); + + auto srcValues = unpackInputs(loc, op, adaptor, rewriter); + std::map, SmallVector> accs; + std::map, SmallVector> indices; + // First reduce all the values along axis within each thread. + reduceWithinThreads(helper, srcValues, accs, indices, rewriter); + + // Then reduce across threads within a warp. + reduceWithinWarps(helper, accs, rewriter); + + if (helper.isWarpSynchronous()) { + // If all the values to be reduced are within the same warp there is + // nothing left to do. + packResults(helper, accs, rewriter); + return success(); + } + + // Compute a shared memory base per operand. + auto smemShapes = helper.getScratchConfigsFast(); + unsigned elems = product(smemShapes[0]); + unsigned maxElems = std::max(elems, product(smemShapes[1])); + SmallVector smemBases(op.getNumOperands()); + smemBases[0] = + bitcast(getSharedMemoryBase(loc, rewriter, op.getOperation()), + getElementPtrType(op, 0)); + for (unsigned i = 1; i < op.getNumOperands(); ++i) { + smemBases[i] = bitcast(gep(getElementPtrType(op, i - 1), smemBases[i - 1], + i32_val(maxElems)), + getElementPtrType(op, i)); + } + storeWarpReduceToSharedMemory(helper, accs, indices, smemBases, rewriter); + + sync(rewriter, loc, op); + + // The second round of shuffle reduction + // now the problem size: sizeInterWarps, s1, s2, .. , sn + // where sizeInterWarps is 2^m + // + // Each thread needs to process: + // elemsPerThread = sizeInterWarps * s1 * s2 .. Sn / numThreads + accumulatePartialReductions(helper, smemBases, rewriter); + + // We could avoid this barrier in some of the layouts, however this is not + // the general case. + // TODO: optimize the barrier in case the layouts are accepted. + sync(rewriter, loc, op); + + // set output values + loadReductionAndPackResult(helper, smemBases, rewriter); return success(); } @@ -589,9 +738,10 @@ struct ReduceOpConversion void populateReduceOpToLLVMPatterns( TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation, ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, - PatternBenefit benefit) { + int computeCapability, PatternBenefit benefit) { patterns.add(typeConverter, allocation, indexCacheInfo, - benefit); + computeCapability, benefit); } diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.h b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.h index 677da86e30f2..225ddaec068a 100644 --- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.h +++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.h @@ -8,8 +8,9 @@ using namespace mlir::triton; void populateReduceOpToLLVMPatterns( TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation, ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, - PatternBenefit benefit); + int computeCapability, PatternBenefit benefit); #endif diff --git a/lib/Conversion/TritonGPUToLLVM/RegReallocOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/RegReallocOpToLLVM.cpp new file mode 100644 index 000000000000..e498208428d6 --- /dev/null +++ b/lib/Conversion/TritonGPUToLLVM/RegReallocOpToLLVM.cpp @@ -0,0 +1,43 @@ +#include "RegReallocOpToLLVM.h" + +using namespace mlir; +using namespace mlir::triton; + +struct RegAllocOpConversion + : public ConvertTritonGPUOpToLLVMPattern { + using ConvertTritonGPUOpToLLVMPattern< + triton::nvidia_gpu::RegAllocOp>::ConvertTritonGPUOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::RegAllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + rewriter.replaceOpWithNewOp( + op, adaptor.getRegCount()); + return success(); + } +}; + +struct RegDeallocOpConversion + : public ConvertTritonGPUOpToLLVMPattern { + using ConvertTritonGPUOpToLLVMPattern< + triton::nvidia_gpu::RegDeallocOp>::ConvertTritonGPUOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::RegDeallocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + rewriter.replaceOpWithNewOp( + op, adaptor.getRegCount()); + return success(); + } +}; + +void populateRegReallocOpToLLVMPatterns( + TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, + const ModuleAllocation &allocation, PatternBenefit benefit) { + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + return; +} diff --git a/lib/Conversion/TritonGPUToLLVM/RegReallocOpToLLVM.h b/lib/Conversion/TritonGPUToLLVM/RegReallocOpToLLVM.h new file mode 100644 index 000000000000..ebbcabd26eab --- /dev/null +++ b/lib/Conversion/TritonGPUToLLVM/RegReallocOpToLLVM.h @@ -0,0 +1,14 @@ +#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_REGREALLOC_OP_H +#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_REGREALLOC_OP_H + +#include "TritonGPUToLLVMBase.h" + +using namespace mlir; +using namespace mlir::triton; + +void populateRegReallocOpToLLVMPatterns( + TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, + const ModuleAllocation &allocation, PatternBenefit benefit); + +#endif diff --git a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp index 81c601ccc850..397abc1eae2a 100644 --- a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp @@ -323,6 +323,7 @@ ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor, void populateScanOpToLLVMPatterns( TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation, ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, PatternBenefit benefit) { diff --git a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.h b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.h index 7a0bfbec5e3d..1249896708c8 100644 --- a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.h +++ b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.h @@ -8,6 +8,7 @@ using namespace mlir::triton; void populateScanOpToLLVMPatterns( TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation, ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, PatternBenefit benefit); diff --git a/lib/Conversion/TritonGPUToLLVM/TensorPtrOpsToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TensorPtrOpsToLLVM.cpp new file mode 100644 index 000000000000..3aa43422669e --- /dev/null +++ b/lib/Conversion/TritonGPUToLLVM/TensorPtrOpsToLLVM.cpp @@ -0,0 +1,104 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "TensorPtrOpsToLLVM.h" +using namespace mlir; +using namespace mlir::triton; + +struct MakeTensorPtrOpConversion + : public ConvertTritonGPUOpToLLVMPattern { + using ConvertTritonGPUOpToLLVMPattern< + triton::MakeTensorPtrOp>::ConvertTritonGPUOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::MakeTensorPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + // struct { offset0, offset1, shape0, shape1, stride0, + // stride1, base_ptr}; + auto offsets = adaptor.getOffsets(); + auto shapes = adaptor.getShape(); + auto strides = adaptor.getStrides(); + auto base = adaptor.getBase(); + auto result = op.getResult(); + + SmallVector elems; + for (auto offset : offsets) + elems.push_back(offset); + for (auto shape : shapes) + elems.push_back(shape); + for (auto stride : strides) + elems.push_back(stride); + + elems.push_back(base); + + auto newValue = getTypeConverter()->packLLElements( + op.getLoc(), elems, rewriter, result.getType()); + rewriter.replaceOp(op, newValue); + return success(); + } +}; + +struct AdvanceOpConversion + : public ConvertTritonGPUOpToLLVMPattern { + using ConvertTritonGPUOpToLLVMPattern< + triton::AdvanceOp>::ConvertTritonGPUOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::AdvanceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // struct { offset0, offset1, shape0, shape1, stride0, + // stride1, base_ptr}; + auto loc = op.getLoc(); + auto ptrType = op.getPtr().getType(); + auto tensorPtr = adaptor.getPtr(); + + auto offsets = adaptor.getOffsets(); + auto elems = + getTypeConverter()->unpackLLElements(loc, tensorPtr, rewriter, ptrType); + + SmallVector newOffsets; + + for (auto [offset, oldOffset] : llvm::zip_first(offsets, elems)) { + newOffsets.push_back((add(offset, oldOffset))); + } + + for (size_t i = 0; i < newOffsets.size(); ++i) { + elems[i] = newOffsets[i]; + } + + auto newValue = getTypeConverter()->packLLElements(op.getLoc(), elems, + rewriter, ptrType); + rewriter.replaceOp(op, newValue); + return success(); + } +}; + +void populateTensorPtrOpsToLLVMPatterns( + TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, + ModuleAllocation &allocation, PatternBenefit benefit) { + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + return; +} diff --git a/lib/Conversion/TritonGPUToLLVM/TensorPtrOpsToLLVM.h b/lib/Conversion/TritonGPUToLLVM/TensorPtrOpsToLLVM.h new file mode 100644 index 000000000000..2bf5eb082b88 --- /dev/null +++ b/lib/Conversion/TritonGPUToLLVM/TensorPtrOpsToLLVM.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_TENSOR_PTR_OPS_H +#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_TENSOR_PTR_OPS_H + +#include "TritonGPUToLLVMBase.h" + +using namespace mlir; +using namespace mlir::triton; + +void populateTensorPtrOpsToLLVMPatterns( + TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, + ModuleAllocation &allocation, PatternBenefit benefit); + +#endif diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 7625064ab062..d72f2684217c 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -6,6 +6,7 @@ using namespace mlir; using namespace mlir::triton; using ::mlir::LLVM::getSharedMemoryObjectFromStruct; +using ::mlir::LLVM::getSRegValue; using ::mlir::triton::gpu::getTotalElemsPerThread; using ::mlir::triton::gpu::SharedEncodingAttr; @@ -445,6 +446,8 @@ struct GetProgramIdOpConversion LogicalResult matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + +#ifdef USE_ROCM Location loc = op->getLoc(); assert(op.getAxisAsInt() < 3); @@ -452,8 +455,25 @@ struct GetProgramIdOpConversion rewriter.create<::mlir::gpu::BlockIdOp>(loc, dims[op.getAxisAsInt()]); rewriter.replaceOpWithNewOp(op, i32_ty, blockId); return success(); - } +#else + // It is not easy to get the compute capability here, so we use numCTAs to + // decide the semantic of GetProgramIdOp. If numCTAs = 1, then + // GetProgramIdOp is converted to "%ctaid", otherwise it is converted to + // "%clusterid". + auto moduleOp = op->getParentOfType(); + assert(moduleOp && "Parent ModuleOp not found for GetProgramIdOp"); + int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp); + + Location loc = op->getLoc(); + assert(op.getAxisAsInt() < 3); + std::string sreg = numCTAs == 1 ? "%ctaid." : "%clusterid."; + sreg.append(1, 'x' + op.getAxisAsInt()); // 0 -> 'x', 1 -> 'y', 2 -> 'z' + Value programId = getSRegValue(rewriter, loc, sreg); + rewriter.replaceOp(op, programId); + return success(); +#endif + } static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x, mlir::gpu::Dimension::y, mlir::gpu::Dimension::z}; @@ -467,13 +487,14 @@ struct GetNumProgramsOpConversion LogicalResult matchAndRewrite(triton::GetNumProgramsOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Location loc = op->getLoc(); - assert(op.getAxis() < 3); #ifdef USE_ROCM - // Seem like GridDimOp returns the number of threads (not the number of - // workgroups) in a kernel (a bug in llvm https://reviews.llvm.org/D156009), - // so as a workaround here, we divide by the number of threads + + Location loc = op->getLoc(); + assert(op.getAxis() < 3); + // Seem like GridDimOp returns the number of threads (not the number of + // workgroups) in a kernel (a bug in llvm https://reviews.llvm.org/D156009), + // so as a workaround here, we divide by the number of threads // per workgroup to get the number of workgroups in a kernel. // TODO: when we do upstream to include llvm fix, we can remove this workaround // The unit test added in this PR can guarantee that. @@ -482,15 +503,28 @@ struct GetNumProgramsOpConversion Value threadsPerBlock = rewriter.create<::mlir::gpu::BlockDimOp>(loc, dims[op.getAxis()]); Value threadNumPerGrid = rewriter.create(loc, i32_ty, threadsPerGrid); - Value threadNumPerBlock = rewriter.create(loc, i32_ty, threadsPerBlock); + Value threadNumPerBlock = rewriter.create(loc, i32_ty, threadsPerBlock); rewriter.replaceOpWithNewOp(op, threadNumPerGrid, threadNumPerBlock); + return success(); #else - Value blockId = - rewriter.create<::mlir::gpu::GridDimOp>(loc, dims[op.getAxis()]); - rewriter.replaceOpWithNewOp(op, i32_ty, blockId); -#endif // USE_ROCM + // It is not easy to get the compute capability here, so we use numCTAs to + // decide the semantic of GetNumProgramsOp. If numCTAs = 1, then + // GetNumProgramsOp is converted to "%nctaid", otherwise it is converted to + // "%nclusterid". + auto moduleOp = op->getParentOfType(); + assert(moduleOp && "Parent ModuleOp not found for GetProgramIdOp"); + int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp); + Location loc = op->getLoc(); + assert(op.getAxis() < 3); + std::string sreg = numCTAs == 1 ? "%nctaid." : "%nclusterid."; + sreg.append(1, 'x' + op.getAxis()); // 0 -> 'x', 1 -> 'y', 2 -> 'z' + + Value numPrograms = getSRegValue(rewriter, loc, sreg); + rewriter.replaceOp(op, numPrograms); return success(); + +#endif } static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x, @@ -498,6 +532,37 @@ struct GetNumProgramsOpConversion mlir::gpu::Dimension::z}; }; +// TODO[goostavz]: GetThreadIdOp/GetClusterCTAIdOp is a temporary solution +// before async dialect is done. These concepts should appear in ttgpu +// level, and they are planned to be deprecated along with ttgpu.mbarrier_xxx +// ops. +struct GetThreadIdOpConversion : public ConvertTritonGPUOpToLLVMPattern< + triton::nvidia_gpu::GetThreadIdOp> { + using ConvertTritonGPUOpToLLVMPattern< + triton::nvidia_gpu::GetThreadIdOp>::ConvertTritonGPUOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::GetThreadIdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOp(op, getThreadId(rewriter, op->getLoc())); + return success(); + } +}; + +struct GetClusterCTAIdOpConversion + : public ConvertTritonGPUOpToLLVMPattern< + triton::nvidia_gpu::GetClusterCTAIdOp> { + using ConvertTritonGPUOpToLLVMPattern< + triton::nvidia_gpu::GetClusterCTAIdOp>::ConvertTritonGPUOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::GetClusterCTAIdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOp(op, getClusterCTAId(rewriter, op->getLoc())); + return success(); + } +}; + struct AddPtrOpConversion : public ConvertTritonGPUOpToLLVMPattern { using ConvertTritonGPUOpToLLVMPattern< @@ -551,7 +616,8 @@ struct AllocTensorOpConversion getTypeConverter()->convertType(resultTy.getElementType()); auto elemPtrTy = ptr_ty(llvmElemTy, 3); smemBase = bitcast(smemBase, elemPtrTy); - auto order = resultTy.getEncoding().cast().getOrder(); + auto sharedLayout = resultTy.getEncoding().cast(); + auto order = sharedLayout.getOrder(); // Workaround for 3D tensors // TODO: we need to modify the pipeline pass to give a proper shared // encoding to 3D tensors @@ -561,8 +627,9 @@ struct AllocTensorOpConversion else newOrder = SmallVector(order.begin(), order.end()); - auto smemObj = SharedMemoryObject(smemBase, resultTy.getShape(), newOrder, - loc, rewriter); + auto shapePerCTA = getShapePerCTA(sharedLayout, resultTy.getShape()); + auto smemObj = + SharedMemoryObject(smemBase, shapePerCTA, newOrder, loc, rewriter); auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); rewriter.replaceOp(op, retVal); return success(); @@ -605,7 +672,7 @@ struct ExtractSliceOpConversion // newShape = rank_reduce(shape) // Triton only supports static tensor sizes SmallVector strideVals; - for (auto i = 0; i < op.static_sizes().size(); ++i) { + for (auto i = 0; i < op.getStaticSizes().size(); ++i) { if (op.getStaticSize(i) == 1) { offsetVals.erase(offsetVals.begin() + i); } else { @@ -665,6 +732,49 @@ struct AsyncCommitGroupOpConversion } }; +struct AsyncBulkWaitOpConversion + : public ConvertTritonGPUOpToLLVMPattern { + using ConvertTritonGPUOpToLLVMPattern< + triton::gpu::AsyncBulkWaitOp>::ConvertTritonGPUOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::gpu::AsyncBulkWaitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + PTXBuilder ptxBuilder; + auto &asyncBulkWaitOp = *ptxBuilder.create<>("cp.async.bulk.wait_group"); + auto num = op->getAttrOfType("num").getInt(); + asyncBulkWaitOp(ptxBuilder.newConstantOperand(num)); + + auto ctx = op.getContext(); + auto loc = op.getLoc(); + auto voidTy = void_ty(ctx); + ptxBuilder.launch(rewriter, loc, voidTy); + + // Safe to remove the op since it doesn't have any return value. + rewriter.eraseOp(op); + return success(); + } +}; + +struct AsyncBulkCommitGroupOpConversion + : public ConvertTritonGPUOpToLLVMPattern< + triton::gpu::AsyncBulkCommitGroupOp> { + using ConvertTritonGPUOpToLLVMPattern< + triton::gpu::AsyncBulkCommitGroupOp>::ConvertTritonGPUOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::gpu::AsyncBulkCommitGroupOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + PTXBuilder ptxBuilder; + ptxBuilder.create<>("cp.async.bulk.commit_group")->operator()(); + ptxBuilder.launch(rewriter, op.getLoc(), void_ty(op.getContext())); + // Safe to remove the op since it doesn't have any return value. + rewriter.eraseOp(op); + return success(); + } +}; + namespace mlir { namespace LLVM { @@ -690,6 +800,7 @@ void vprintf_array(Value thread, ArrayRef arr, std::string info, void populateTritonGPUToLLVMPatterns( TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &moduleAllocation, ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, PatternBenefit benefit) { @@ -698,12 +809,15 @@ void populateTritonGPUToLLVMPatterns( benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); - patterns.add(typeConverter, moduleAllocation, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); patterns.add(typeConverter, indexCacheInfo, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h index 50ba1355d75e..9019073584c0 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h @@ -8,6 +8,7 @@ using namespace mlir::triton; void populateTritonGPUToLLVMPatterns( TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation, ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, PatternBenefit benefit); diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h index 689dc498a388..df3e877d1283 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h @@ -11,17 +11,31 @@ #include "Utility.h" #include "mlir/IR/TypeUtilities.h" #include "triton/Analysis/AxisInfo.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Target/PTX/TmaMetadata.h" #include + +#define DEBUG_TYPE "ttgpu_to_llvm" + +constexpr ::llvm::StringLiteral kAttrNumTMALoadDescsName = + "triton_gpu.num-tma-load"; +constexpr ::llvm::StringLiteral kAttrNumTMAStoreDescsName = + "triton_gpu.num-tma-store"; using namespace mlir; using namespace mlir::triton; using ::mlir::LLVM::delinearize; using ::mlir::LLVM::SharedMemoryObject; using ::mlir::triton::gpu::BlockedEncodingAttr; +using ::mlir::triton::gpu::CTALayoutAttr; using ::mlir::triton::gpu::DotOperandEncodingAttr; using ::mlir::triton::gpu::MfmaEncodingAttr; using ::mlir::triton::gpu::MmaEncodingAttr; using ::mlir::triton::gpu::SliceEncodingAttr; +using ::mlir::triton::gpu::TMAMetadataTy; +namespace ttng = ::mlir::triton::nvidia_gpu; + +typedef DenseMap TensorPtrMapT; namespace mlir { namespace LLVM { @@ -142,36 +156,39 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern { } }; -using IndexCacheKeyT = std::pair; +struct IndexCacheKeyT { + Attribute layout; + RankedTensorType type; + bool withCTAOffset; +}; struct CacheKeyDenseMapInfo { static IndexCacheKeyT getEmptyKey() { auto *pointer = llvm::DenseMapInfo::getEmptyKey(); - return std::make_pair( - mlir::Attribute(static_cast(pointer)), - RankedTensorType{}); + return {mlir::Attribute(static_cast(pointer)), + RankedTensorType{}, true}; } static IndexCacheKeyT getTombstoneKey() { auto *pointer = llvm::DenseMapInfo::getTombstoneKey(); auto tombstone = llvm::DenseMapInfo::getTombstoneKey(); - return std::make_pair( - mlir::Attribute(static_cast(pointer)), - tombstone); + return {mlir::Attribute(static_cast(pointer)), + tombstone, true}; } static unsigned getHashValue(IndexCacheKeyT key) { - auto shape = key.second.getShape(); - return llvm::hash_combine(mlir::hash_value(key.first), - mlir::hash_value(key.second)); + return llvm::hash_combine(mlir::hash_value(key.layout), + mlir::hash_value(key.type), + llvm::hash_value(key.withCTAOffset)); } static bool isEqual(IndexCacheKeyT LHS, IndexCacheKeyT RHS) { - return LHS == RHS; + return LHS.layout == RHS.layout && LHS.type == RHS.type && + LHS.withCTAOffset == RHS.withCTAOffset; } }; class ConvertTritonGPUOpToLLVMPatternBase { public: // Two levels of value cache in emitting indices calculation: - // Key: pair + // Key: {layout, shape, withCTAOffset} struct IndexCacheInfo { DenseMap, CacheKeyDenseMapInfo> *baseIndexCache; @@ -199,6 +216,12 @@ class ConvertTritonGPUOpToLLVMPatternBase { : converter(&typeConverter), allocation(&allocation), indexCacheInfo(indexCacheInfo) {} + explicit ConvertTritonGPUOpToLLVMPatternBase( + TritonGPUToLLVMTypeConverter &typeConverter, ModuleAllocation &allocation, + TMAMetadataTy *tmaMetadata) + : converter(&typeConverter), allocation(&allocation), + tmaMetadata(tmaMetadata) {} + TritonGPUToLLVMTypeConverter *getTypeConverter() const { return converter; } static Value @@ -218,12 +241,32 @@ class ConvertTritonGPUOpToLLVMPatternBase { return llvmStruct; } - Value getThreadId(ConversionPatternRewriter &rewriter, Location loc) const { - auto tid = rewriter.create<::mlir::gpu::ThreadIdOp>( + // Returns CTA level thread idx + Value getThreadIdInCTA(ConversionPatternRewriter &rewriter, + Location loc) const { + Value tid = rewriter.create<::mlir::gpu::ThreadIdOp>( loc, ::mlir::gpu::Dimension::x); return rewriter.create(loc, i32_ty, tid); } + // Returns CTA level thread idx for not ws mode. + // Returns agent level thread idx for ws mode. + Value getThreadId(ConversionPatternRewriter &rewriter, Location loc) const { + Value tid = getThreadIdInCTA(rewriter, loc); + auto mod = rewriter.getBlock()->getParent()->getParentOfType(); + if (ttng::TritonNvidiaGPUDialect::getWSSupportedAttr(mod)) { + Value _128 = rewriter.create(loc, 128, 32); + tid = rewriter.create(loc, tid, _128); + } + return tid; + } + + Value getClusterCTAId(ConversionPatternRewriter &rewriter, + Location loc) const { + return rewriter.create( + loc, rewriter.getI32Type()); + } + // ----------------------------------------------------------------------- // Shared memory utilities // ----------------------------------------------------------------------- @@ -260,7 +303,7 @@ class ConvertTritonGPUOpToLLVMPatternBase { // for all indices (row, col) of `srcEncoding` such that idx % inVec = 0, // the pointer: ptr[(row, col)] = base + (rowOff * strides[ord[1]] + // colOff) where : - // compute phase = (row // perPhase) % maxPhase + // phase = (row // perPhase) % maxPhase // rowOff = row // colOff = colOffSwizzled + colOffOrdered // colOffSwizzled = ((col // outVec) ^ phase) * outVec @@ -281,60 +324,90 @@ class ConvertTritonGPUOpToLLVMPatternBase { // then (x + y) XOR z = 0byyyyxxxx XOR 0b00000zzzz = (x XOR z) + y // This means that we can use some immediate offsets for shared memory // operations. - auto dstPtrTy = ptr_ty(resElemTy, 3); + resElemTy = getTypeConverter()->convertType(resElemTy); + auto dstPtrTy = ptr_ty(getTypeConverter()->convertType(resElemTy), 3); auto dstOffset = dot(rewriter, loc, offsetVals, smemObj.strides); Value dstPtrBase = gep(dstPtrTy, smemObj.base, dstOffset); auto srcEncoding = srcTy.getEncoding(); auto srcShape = srcTy.getShape(); + auto srcShapePerCTA = triton::gpu::getShapePerCTA(srcTy); unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy); // swizzling params as described in TritonGPUAttrDefs.td unsigned outVec = resSharedLayout.getVec(); unsigned perPhase = resSharedLayout.getPerPhase(); unsigned maxPhase = resSharedLayout.getMaxPhase(); - // order + // Order auto inOrder = triton::gpu::getOrder(srcEncoding); auto outOrder = triton::gpu::getOrder(resSharedLayout); - // tensor indices held by the current thread, as LLVM values - auto srcIndices = emitIndices(loc, rewriter, srcEncoding, srcTy); - // return values + // Tensor indices held by the current thread, as LLVM values + auto srcIndices = emitIndices(loc, rewriter, srcEncoding, srcTy, false); + // Swizzling with leading offsets (e.g. Hopper GMMA) + unsigned swizzlingByteWidth = 0; + if (resSharedLayout.getHasLeadingOffset()) { + if (perPhase == 4 && maxPhase == 2) + swizzlingByteWidth = 32; + else if (perPhase == 2 && maxPhase == 4) + swizzlingByteWidth = 64; + else if (perPhase == 1 && maxPhase == 8) + swizzlingByteWidth = 128; + else + llvm::report_fatal_error("Unsupported shared layout."); + } + unsigned numElemsPerSwizzlingRow = + swizzlingByteWidth * 8 / resElemTy.getIntOrFloatBitWidth(); + Value numElemsPerSwizzlingRowVal = i32_val(numElemsPerSwizzlingRow); + unsigned leadingDimOffset = + numElemsPerSwizzlingRow * srcShapePerCTA[outOrder[1]]; + Value leadingDimOffsetVal = i32_val(leadingDimOffset); + // Return values DenseMap ret; // cache for non-immediate offsets DenseMap cacheCol, cacheRow; unsigned minVec = std::min(outVec, inVec); for (unsigned elemIdx = 0; elemIdx < numElems; elemIdx += minVec) { - // extract multi dimensional index for current element + Value offset = i32_val(0); + // Extract multi dimensional index for current element auto idx = srcIndices[elemIdx]; Value idxCol = idx[outOrder[0]]; // contiguous dimension Value idxRow = idx[outOrder[1]]; // discontiguous dimension Value strideCol = srcStrides[outOrder[0]]; Value strideRow = srcStrides[outOrder[1]]; + // compute phase = (row // perPhase) % maxPhase + Value phase = urem(udiv(idxRow, i32_val(perPhase)), i32_val(maxPhase)); // extract dynamic/static offset for immediate offsetting unsigned immedateOffCol = 0; - if (auto add = dyn_cast_or_null(idxCol.getDefiningOp())) - if (auto _cst = dyn_cast_or_null( - add.getRhs().getDefiningOp())) { - unsigned cst = - _cst.getValue().cast().getValue().getSExtValue(); - unsigned key = cst % (outVec * maxPhase); - cacheCol.insert({key, idxCol}); - idxCol = cacheCol[key]; - immedateOffCol = cst / (outVec * maxPhase) * (outVec * maxPhase); - } - // extract dynamic/static offset for immediate offsetting unsigned immedateOffRow = 0; - if (auto add = dyn_cast_or_null(idxRow.getDefiningOp())) - if (auto _cst = dyn_cast_or_null( - add.getRhs().getDefiningOp())) { - unsigned cst = - _cst.getValue().cast().getValue().getSExtValue(); - unsigned key = cst % (perPhase * maxPhase); - cacheRow.insert({key, idxRow}); - idxRow = cacheRow[key]; - immedateOffRow = cst / (perPhase * maxPhase) * (perPhase * maxPhase); - } - // compute phase = (row // perPhase) % maxPhase - Value phase = urem(udiv(idxRow, i32_val(perPhase)), i32_val(maxPhase)); + if (leadingDimOffset) { + // hopper + offset = + mul(udiv(idxCol, numElemsPerSwizzlingRowVal), leadingDimOffsetVal); + // Shrink by swizzling blocks + idxCol = urem(idxCol, numElemsPerSwizzlingRowVal); + strideRow = numElemsPerSwizzlingRowVal; + } else { + if (auto add = dyn_cast_or_null(idxCol.getDefiningOp())) + if (auto _cst = dyn_cast_or_null( + add.getRhs().getDefiningOp())) { + unsigned cst = + _cst.getValue().cast().getValue().getSExtValue(); + unsigned key = cst % (outVec * maxPhase); + cacheCol.insert({key, idxCol}); + idxCol = cacheCol[key]; + immedateOffCol = cst / (outVec * maxPhase) * (outVec * maxPhase); + } + if (auto add = dyn_cast_or_null(idxRow.getDefiningOp())) + if (auto _cst = dyn_cast_or_null( + add.getRhs().getDefiningOp())) { + unsigned cst = + _cst.getValue().cast().getValue().getSExtValue(); + unsigned key = cst % (perPhase * maxPhase); + cacheRow.insert({key, idxRow}); + idxRow = cacheRow[key]; + immedateOffRow = + cst / (perPhase * maxPhase) * (perPhase * maxPhase); + } + } // row offset is simply row index Value rowOff = mul(idxRow, strideRow); // because swizzling happens at a granularity of outVec, we need to @@ -348,7 +421,7 @@ class ConvertTritonGPUOpToLLVMPatternBase { colOffOrdered = mul(colOffOrdered, i32_val(minVec)); Value colOff = add(colOffSwizzled, colOffOrdered); // compute non-immediate offset - Value offset = add(rowOff, mul(colOff, strideCol)); + offset = add(offset, add(rowOff, mul(colOff, strideCol))); Value currPtr = gep(dstPtrTy, dstPtrBase, offset); // compute immediate offset Value immedateOff = @@ -478,7 +551,7 @@ class ConvertTritonGPUOpToLLVMPatternBase { auto threadsPerWarp = triton::gpu::getThreadsPerWarp(layout); auto warpsPerCTA = triton::gpu::getWarpsPerCTA(layout); auto order = triton::gpu::getOrder(layout); - auto shapePerCTA = triton::gpu::getShapePerCTA(layout, shape); + auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout, shape); Value warpSize = i32_val(triton::gpu::getWarpSize(layout)); Value laneId = urem(tid, warpSize); Value warpId = udiv(tid, warpSize); @@ -488,7 +561,7 @@ class ConvertTritonGPUOpToLLVMPatternBase { delinearize(rewriter, loc, laneId, threadsPerWarp, order); for (unsigned dim = 0; dim < rank; ++dim) { // if there is no data replication across threads on this dimension - if (shape[dim] >= shapePerCTA[dim]) + if (shape[dim] >= shapePerCTATile[dim]) continue; // Otherwise, we need to mask threads that will replicate data on this // dimension. Calculate the thread index on this dimension for the CTA @@ -498,6 +571,28 @@ class ConvertTritonGPUOpToLLVMPatternBase { mask = and_(mask, icmp_slt(mul(threadDim, i32_val(sizePerThread[dim])), i32_val(shape[dim]))); } + // Do not write duplicated data when multicast is enabled + if (triton::gpu::getNumCTAs(layout) > 1) { + auto _0 = i32_val(0); + auto CTAsPerCGA = triton::gpu::getCTAsPerCGA(layout); + auto CTASplitNum = triton::gpu::getCTASplitNum(layout); + auto CTAOrder = triton::gpu::getCTAOrder(layout); + + auto clusterCTAId = getClusterCTAId(rewriter, loc); + auto multiDimClusterCTAId = + delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder); + + for (unsigned dim = 0; dim < rank; ++dim) { + // Skip when multicast is not enabled in this dimension + if (CTAsPerCGA[dim] == CTASplitNum[dim]) + continue; + // This wrapping rule must be consistent with emitCTAOffsetForLayout + unsigned splitNum = std::min(shape[dim], CTASplitNum[dim]); + multiDimClusterCTAId[dim] = + urem(multiDimClusterCTAId[dim], i32_val(splitNum)); + mask = and_(mask, icmp_eq(multiDimClusterCTAId[dim], _0)); + } + } } else { // If the tensor is not ranked, then it is a scalar and only thread 0 can // write @@ -536,13 +631,48 @@ class ConvertTritonGPUOpToLLVMPatternBase { // Get offsets / indices for any layout // ----------------------------------------------------------------------- + SmallVector emitCTAOffsetForLayout(Location loc, + ConversionPatternRewriter &rewriter, + Attribute layout, + ArrayRef shape) const { + unsigned rank = shape.size(); + SmallVector CTAsPerCGA = triton::gpu::getCTAsPerCGA(layout); + SmallVector CTASplitNum = triton::gpu::getCTASplitNum(layout); + SmallVector CTAOrder = triton::gpu::getCTAOrder(layout); + SmallVector shapePerCTA = + triton::gpu::getShapePerCTA(CTASplitNum, shape); + + // Delinearize clusterCTAId + Value clusterCTAId = getClusterCTAId(rewriter, loc); + SmallVector multiDimClusterCTAId = + delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder); + + // CTA Wrapping + for (unsigned i = 0; i < rank; ++i) { + // This wrapping rule must be consistent with getShapePerCTA + unsigned splitNum = std::min(shape[i], CTASplitNum[i]); + multiDimClusterCTAId[i] = + urem(multiDimClusterCTAId[i], i32_val(splitNum)); + } + + SmallVector CTAOffset(rank); + for (unsigned i = 0; i < rank; ++i) + CTAOffset[i] = mul(multiDimClusterCTAId[i], i32_val(shapePerCTA[i])); + + return CTAOffset; + } + SmallVector emitBaseIndexForLayout(Location loc, ConversionPatternRewriter &rewriter, Attribute layout, - RankedTensorType type) const { - IndexCacheKeyT key = std::make_pair(layout, type); + RankedTensorType type, + bool withCTAOffset) const { + auto shape = type.getShape(); + IndexCacheKeyT key{layout, type, withCTAOffset}; auto cache = indexCacheInfo.baseIndexCache; auto insertPt = indexCacheInfo.indexInsertPoint; + + SmallVector baseIndex; if (cache && cache->count(key) > 0) { return cache->lookup(key); } else { @@ -551,13 +681,15 @@ class ConvertTritonGPUOpToLLVMPatternBase { restoreInsertionPointIfSet(insertPt, rewriter); SmallVector result; if (auto blockedLayout = layout.dyn_cast()) { - result = - emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, type); + result = emitBaseIndexWithinCTAForBlockedLayout(loc, rewriter, + blockedLayout, type); } else if (auto mmaLayout = layout.dyn_cast()) { if (mmaLayout.isVolta()) - result = emitBaseIndexForMmaLayoutV1(loc, rewriter, mmaLayout, type); - if (mmaLayout.isAmpere()) - result = emitBaseIndexForMmaLayoutV2(loc, rewriter, mmaLayout, type); + result = emitBaseIndexWithinCTAForMmaLayoutV1(loc, rewriter, + mmaLayout, type); + if (mmaLayout.isAmpere() || mmaLayout.isHopper()) + result = emitBaseIndexWithinCTAForMmaLayoutV2V3(loc, rewriter, + mmaLayout, type); } else if (auto mfmaLayout = layout.dyn_cast()) { result = emitBaseIndexForMfmaLayout(loc, rewriter, mfmaLayout, type); } else if (auto sliceLayout = layout.dyn_cast()) { @@ -565,11 +697,20 @@ class ConvertTritonGPUOpToLLVMPatternBase { auto parentShape = sliceLayout.paddedShape(type.getShape()); RankedTensorType parentTy = RankedTensorType::get( parentShape, type.getElementType(), parentLayout); - result = emitBaseIndexForLayout(loc, rewriter, parentLayout, parentTy); + result = emitBaseIndexForLayout(loc, rewriter, parentLayout, parentTy, + withCTAOffset); result.erase(result.begin() + sliceLayout.getDim()); + // CTAOffset has been added in emitBaseIndexForLayout of parentLayout + return result; } else { llvm_unreachable("unsupported emitBaseIndexForLayout"); } + if (withCTAOffset) { + auto CTAOffset = emitCTAOffsetForLayout(loc, rewriter, layout, shape); + assert(CTAOffset.size() == result.size() && "Rank mismatch"); + for (unsigned k = 0; k < result.size(); ++k) + result[k] = add(result[k], CTAOffset[k]); + } if (cache) { cache->insert(std::make_pair(key, result)); *insertPt = rewriter.saveInsertionPoint(); @@ -587,6 +728,8 @@ class ConvertTritonGPUOpToLLVMPatternBase { return emitOffsetForMmaLayoutV1(mmaLayout, type); if (mmaLayout.isAmpere()) return emitOffsetForMmaLayoutV2(mmaLayout, type); + if (mmaLayout.isHopper()) + return emitOffsetForMmaLayoutV3(mmaLayout, type); } if (auto mfmaLayout = layout.dyn_cast()) { return emitOffsetForMfmaLayout(mfmaLayout, type); @@ -600,15 +743,17 @@ class ConvertTritonGPUOpToLLVMPatternBase { void emitMfmaOffsetForCTA(const MfmaEncodingAttr &mfmaLayout, SmallVector> &offsets, unsigned ctaOffsetX, unsigned ctaOffsetY) const { + auto nonKDim = mfmaLayout.getNonKDim(); // MFMA output tile consists of repeated "dot operand B" layout groups along // row axis. This variable defines number of these groups. - const unsigned numGroups = 4; + const unsigned numGroups = (nonKDim == 32 ? 4 : 1); const unsigned elemsPerThreadPerGroup = 4; auto warpSize = getWarpSize(mfmaLayout); assert(warpSize == 64); - auto shapePerCta = getShapePerCTA(mfmaLayout); + auto shapePerCta = getShapePerCTATile(mfmaLayout); for (unsigned block = 0; block < numGroups; block++) { - unsigned rowOrColOffset = block * elemsPerThreadPerGroup * warpSize / 32; + unsigned rowOrColOffset = + block * elemsPerThreadPerGroup * warpSize / nonKDim; for (unsigned elem = 0; elem < elemsPerThreadPerGroup; elem++) { if (mfmaLayout.getIsTransposed()) { offsets.push_back( @@ -627,11 +772,10 @@ class ConvertTritonGPUOpToLLVMPatternBase { // ----------------------------------------------------------------------- // Emit indices // ----------------------------------------------------------------------- - SmallVector> emitIndices(Location loc, - ConversionPatternRewriter &b, - Attribute layout, - RankedTensorType type) const { - IndexCacheKeyT key(layout, type); + SmallVector> + emitIndices(Location loc, ConversionPatternRewriter &b, Attribute layout, + RankedTensorType type, bool withCTAOffset = true) const { + IndexCacheKeyT key{layout, type, withCTAOffset}; auto cache = indexCacheInfo.indexCache; auto insertPt = indexCacheInfo.indexInsertPoint; if (cache && cache->count(key) > 0) { @@ -642,13 +786,17 @@ class ConvertTritonGPUOpToLLVMPatternBase { restoreInsertionPointIfSet(insertPt, b); SmallVector> result; if (auto blocked = layout.dyn_cast()) { - result = emitIndicesForDistributedLayout(loc, b, blocked, type); + result = emitIndicesForDistributedLayout(loc, b, blocked, type, + withCTAOffset); } else if (auto mma = layout.dyn_cast()) { - result = emitIndicesForDistributedLayout(loc, b, mma, type); + result = + emitIndicesForDistributedLayout(loc, b, mma, type, withCTAOffset); } else if (auto mfma = layout.dyn_cast()) { - result = emitIndicesForDistributedLayout(loc, b, mfma, type); + result = + emitIndicesForDistributedLayout(loc, b, mfma, type, withCTAOffset); } else if (auto slice = layout.dyn_cast()) { - result = emitIndicesForDistributedLayout(loc, b, slice, type); + result = + emitIndicesForDistributedLayout(loc, b, slice, type, withCTAOffset); } else { llvm_unreachable( "emitIndices for layouts other than blocked & slice not " @@ -678,19 +826,20 @@ class ConvertTritonGPUOpToLLVMPatternBase { // Blocked layout indices // ----------------------------------------------------------------------- - // Get an index-base for each dimension for a \param blocked_layout. - SmallVector emitBaseIndexForBlockedLayout( + // Get an index-base for each dimension for a \param blockedLayout. + SmallVector emitBaseIndexWithinCTAForBlockedLayout( Location loc, ConversionPatternRewriter &rewriter, - const BlockedEncodingAttr &blocked_layout, RankedTensorType type) const { + const BlockedEncodingAttr &blockedLayout, RankedTensorType type) const { auto shape = type.getShape(); Value threadId = getThreadId(rewriter, loc); - Value warpSize = i32_val(triton::gpu::getWarpSize(blocked_layout)); + Value warpSize = i32_val(triton::gpu::getWarpSize(blockedLayout)); Value laneId = urem(threadId, warpSize); Value warpId = udiv(threadId, warpSize); - auto sizePerThread = blocked_layout.getSizePerThread(); - auto threadsPerWarp = blocked_layout.getThreadsPerWarp(); - auto warpsPerCTA = blocked_layout.getWarpsPerCTA(); - auto order = blocked_layout.getOrder(); + auto sizePerThread = blockedLayout.getSizePerThread(); + auto threadsPerWarp = blockedLayout.getThreadsPerWarp(); + auto warpsPerCTA = blockedLayout.getWarpsPerCTA(); + auto order = blockedLayout.getOrder(); + auto shapePerCTA = triton::gpu::getShapePerCTA(blockedLayout, shape); unsigned rank = shape.size(); // delinearize threadId to get the base index @@ -702,10 +851,10 @@ class ConvertTritonGPUOpToLLVMPatternBase { SmallVector multiDimBase(rank); for (unsigned k = 0; k < rank; ++k) { // Wrap around multiDimWarpId/multiDimThreadId in case - // shape[k] > shapePerCTA[k] + // shapePerCTATile[k] > shapePerCTA[k] auto maxWarps = - ceil(shape[k], sizePerThread[k] * threadsPerWarp[k]); - auto maxThreads = ceil(shape[k], sizePerThread[k]); + ceil(shapePerCTA[k], sizePerThread[k] * threadsPerWarp[k]); + auto maxThreads = ceil(shapePerCTA[k], sizePerThread[k]); multiDimWarpId[k] = urem(multiDimWarpId[k], i32_val(maxWarps)); multiDimThreadId[k] = urem(multiDimThreadId[k], i32_val(maxThreads)); // multiDimBase[k] = (multiDimThreadId[k] + @@ -728,16 +877,17 @@ class ConvertTritonGPUOpToLLVMPatternBase { auto threadsPerWarp = blockedLayout.getThreadsPerWarp(); auto warpsPerCTA = blockedLayout.getWarpsPerCTA(); auto order = blockedLayout.getOrder(); + auto shapePerCTATile = getShapePerCTATile(blockedLayout); + auto shapePerCTA = triton::gpu::getShapePerCTA(blockedLayout, shape); unsigned rank = shape.size(); - SmallVector shapePerCTA = getShapePerCTA(blockedLayout); SmallVector tilesPerDim(rank); for (unsigned k = 0; k < rank; ++k) - tilesPerDim[k] = ceil(shape[k], shapePerCTA[k]); + tilesPerDim[k] = ceil(shapePerCTA[k], shapePerCTATile[k]); SmallVector> offset(rank); for (unsigned k = 0; k < rank; ++k) { - // 1 block in minimum if shape[k] is less than shapePerCTA[k] + // 1 CTA tile in minimum if shapePerCTA[k] is less than shapePerCTATile[k] for (unsigned blockOffset = 0; blockOffset < tilesPerDim[k]; ++blockOffset) for (unsigned warpOffset = 0; warpOffset < warpsPerCTA[k]; ++warpOffset) @@ -777,12 +927,10 @@ class ConvertTritonGPUOpToLLVMPatternBase { // Mma layout indices // ----------------------------------------------------------------------- - SmallVector - emitBaseIndexForMmaLayoutV1(Location loc, ConversionPatternRewriter &rewriter, - const MmaEncodingAttr &mmaLayout, - RankedTensorType type) const { + SmallVector emitBaseIndexWithinCTAForMmaLayoutV1( + Location loc, ConversionPatternRewriter &rewriter, + const MmaEncodingAttr &mmaLayout, RankedTensorType type) const { auto shape = type.getShape(); - auto wpt = mmaLayout.getWarpsPerCTA(); static constexpr std::array fpw{{2, 2, 1}}; auto [isARow, isBRow, isAVec4, isBVec4, _] = @@ -900,28 +1048,75 @@ class ConvertTritonGPUOpToLLVMPatternBase { return ret; } - SmallVector - emitBaseIndexForMmaLayoutV2(Location loc, ConversionPatternRewriter &rewriter, - const MmaEncodingAttr &mmaLayout, - RankedTensorType type) const { + SmallVector> + emitOffsetForMmaLayoutV2(const MmaEncodingAttr &mmaLayout, + RankedTensorType type) const { auto shape = type.getShape(); - auto warpsPerCTA = mmaLayout.getWarpsPerCTA(); - assert(warpsPerCTA.size() == 2); + auto shapePerCTA = getShapePerCTA(mmaLayout, shape); + SmallVector> ret; + + for (unsigned i = 0; i < shapePerCTA[0]; + i += getShapePerCTATile(mmaLayout)[0]) { + for (unsigned j = 0; j < shapePerCTA[1]; + j += getShapePerCTATile(mmaLayout)[1]) { + ret.push_back({i, j}); + ret.push_back({i, j + 1}); + ret.push_back({i + 8, j}); + ret.push_back({i + 8, j + 1}); + } + } + return ret; + } + + SmallVector emitBaseIndexWithinCTAForMmaLayoutV2V3( + Location loc, ConversionPatternRewriter &rewriter, + const MmaEncodingAttr &mmaLayout, RankedTensorType type) const { + auto shape = type.getShape(); + auto _warpsPerCTA = mmaLayout.getWarpsPerCTA(); + assert(_warpsPerCTA.size() == 2); auto order = triton::gpu::getOrder(mmaLayout); + ArrayRef instrShape = mmaLayout.getInstrShape(); + SmallVector warpsPerCTA = {i32_val(_warpsPerCTA[0]), + i32_val(_warpsPerCTA[1])}; + auto shapePerCTA = getShapePerCTA(mmaLayout, shape); + Value threadId = getThreadId(rewriter, loc); Value warpSize = i32_val(triton::gpu::getWarpSize(mmaLayout)); Value laneId = urem(threadId, warpSize); Value warpId = udiv(threadId, warpSize); - SmallVector multiDimWarpId = - delinearize(rewriter, loc, warpId, warpsPerCTA, order); - unsigned lastAxis = order[order.size() - 1]; - multiDimWarpId[lastAxis] = - urem(multiDimWarpId[lastAxis], i32_val(warpsPerCTA[lastAxis])); - multiDimWarpId[0] = urem(multiDimWarpId[0], i32_val(shape[0] / 16)); - multiDimWarpId[1] = urem(multiDimWarpId[1], i32_val(shape[1] / 8)); - Value offWarp0 = mul(multiDimWarpId[0], i32_val(16)); - Value offWarp1 = mul(multiDimWarpId[1], i32_val(8)); + uint32_t repM = (_warpsPerCTA[0] * instrShape[0]) / shapePerCTA[0]; + uint32_t repN = (_warpsPerCTA[1] * instrShape[1]) / shapePerCTA[1]; + + uint32_t warpsM; + if (repM > 1) + warpsM = _warpsPerCTA[0] / repM; + else + warpsM = shape[0] / instrShape[0]; + + uint32_t warpsN; + if (repN > 1) + warpsN = _warpsPerCTA[1] / repN; + else + warpsN = shape[1] / instrShape[1]; + + SmallVector multiDimWarpId(2); + if (mmaLayout.isHopper()) { + // TODO[goostavz]: the tiling order from CTA->warp level is different for + // MMAv2/3. This is a workaround since we don't explicitly have warpGrp + // level in the layout definition, and the tiling order of warpGrp->warp + // must be fixed to meet the HW's needs. We may need to consider to + // explicitly define warpGrpPerCTA for MMAv3 layout. + multiDimWarpId[0] = urem(warpId, warpsPerCTA[0]); + multiDimWarpId[1] = urem(udiv(warpId, warpsPerCTA[0]), warpsPerCTA[1]); + } else { + multiDimWarpId = delinearize(rewriter, loc, warpId, _warpsPerCTA, order); + } + Value warpId0 = urem(multiDimWarpId[0], i32_val(warpsM)); + Value warpId1 = urem(multiDimWarpId[1], i32_val(warpsN)); + + Value offWarp0 = mul(warpId0, i32_val(instrShape[0])); + Value offWarp1 = mul(warpId1, i32_val(instrShape[1])); SmallVector multiDimBase(2); multiDimBase[0] = add(udiv(laneId, i32_val(4)), offWarp0); @@ -930,17 +1125,23 @@ class ConvertTritonGPUOpToLLVMPatternBase { } SmallVector> - emitOffsetForMmaLayoutV2(const MmaEncodingAttr &mmaLayout, + emitOffsetForMmaLayoutV3(const MmaEncodingAttr &mmaLayout, RankedTensorType type) const { auto shape = type.getShape(); + auto shapePerCTA = getShapePerCTA(mmaLayout, shape); SmallVector> ret; - - for (unsigned i = 0; i < shape[0]; i += getShapePerCTA(mmaLayout)[0]) { - for (unsigned j = 0; j < shape[1]; j += getShapePerCTA(mmaLayout)[1]) { - ret.push_back({i, j}); - ret.push_back({i, j + 1}); - ret.push_back({i + 8, j}); - ret.push_back({i + 8, j + 1}); + ArrayRef instrShape = mmaLayout.getInstrShape(); + + for (unsigned i = 0; i < shapePerCTA[0]; + i += getShapePerCTATile(mmaLayout)[0]) { + for (unsigned j = 0; j < shapePerCTA[1]; + j += getShapePerCTATile(mmaLayout)[1]) { + for (unsigned k = 0; k < instrShape[1]; k += 8) { + ret.push_back({i, j + k}); + ret.push_back({i, j + k + 1}); + ret.push_back({i + 8, j + k}); + ret.push_back({i + 8, j + k + 1}); + } } } return ret; @@ -959,28 +1160,30 @@ class ConvertTritonGPUOpToLLVMPatternBase { assert(_warpsPerCTA.size() == 2); SmallVector warpsPerCTA = {i32_val(_warpsPerCTA[0]), i32_val(_warpsPerCTA[1])}; + int nonKDim = mfmaLayout.getNonKDim(); Value threadId = getThreadId(rewriter, loc); Value warpSize = i32_val(triton::gpu::getWarpSize(mfmaLayout)); Value laneId = urem(threadId, warpSize); Value warpId = udiv(threadId, warpSize); - Value warpId0 = urem(urem(warpId, warpsPerCTA[0]), i32_val(shape[0] / 32)); + Value warpId0 = + urem(urem(warpId, warpsPerCTA[0]), i32_val(shape[0] / nonKDim)); Value warpId1 = urem(urem(udiv(warpId, warpsPerCTA[0]), warpsPerCTA[1]), - i32_val(shape[1] / 32)); + i32_val(shape[1] / nonKDim)); - Value offWarp0 = mul(warpId0, i32_val(32)); - Value offWarp1 = mul(warpId1, i32_val(32)); + Value offWarp0 = mul(warpId0, i32_val(nonKDim)); + Value offWarp1 = mul(warpId1, i32_val(nonKDim)); SmallVector multiDimBase(2); if (mfmaLayout.getIsTransposed()) { multiDimBase[1] = - add(mul(i32_val(4), udiv(laneId, i32_val(32))), offWarp1); - multiDimBase[0] = add(urem(laneId, i32_val(32)), offWarp0); + add(mul(i32_val(4), udiv(laneId, i32_val(nonKDim))), offWarp1); + multiDimBase[0] = add(urem(laneId, i32_val(nonKDim)), offWarp0); } else { multiDimBase[0] = - add(mul(i32_val(4), udiv(laneId, i32_val(32))), offWarp0); - multiDimBase[1] = add(urem(laneId, i32_val(32)), offWarp1); + add(mul(i32_val(4), udiv(laneId, i32_val(nonKDim))), offWarp0); + multiDimBase[1] = add(urem(laneId, i32_val(nonKDim)), offWarp1); } return multiDimBase; } @@ -988,19 +1191,20 @@ class ConvertTritonGPUOpToLLVMPatternBase { SmallVector> emitOffsetForMfmaLayout(const MfmaEncodingAttr &mfmaLayout, RankedTensorType type) const { - auto tensorShape = type.getShape(); SmallVector> offsets; - auto shapePerCta = getShapePerCTA(mfmaLayout); + auto shapePerCTA = getShapePerCTA(mfmaLayout, tensorShape); + auto warpsPerCTA = mfmaLayout.getWarpsPerCTA(); - SmallVector numCTAPerDim(2); + SmallVector numWarpsPerDim(2); for (unsigned d = 0; d < 2; ++d) { - unsigned inPerCTA = std::min(tensorShape[d], shapePerCta[d]); - numCTAPerDim[d] = ceil(tensorShape[d], inPerCTA); + unsigned inPerCTA = std::min(tensorShape[d], shapePerCTA[d]); + unsigned inPerWarp = ceil(inPerCTA, warpsPerCTA[d]); + numWarpsPerDim[d] = ceil(inPerWarp, mfmaLayout.getNonKDim()); } - for (unsigned i = 0; i < numCTAPerDim[0]; ++i) { - for (unsigned j = 0; j < numCTAPerDim[1]; ++j) { + for (unsigned i = 0; i < numWarpsPerDim[0]; ++i) { + for (unsigned j = 0; j < numWarpsPerDim[1]; ++j) { emitMfmaOffsetForCTA(mfmaLayout, offsets, i, j); } } @@ -1011,13 +1215,15 @@ class ConvertTritonGPUOpToLLVMPatternBase { // [elemsPerThread X rank] index matrix. SmallVector> emitIndicesForDistributedLayout( Location loc, ConversionPatternRewriter &rewriter, Attribute layout, - RankedTensorType type) const { + RankedTensorType type, bool withCTAOffset) const { // step 1, delinearize threadId to get the base index - auto multiDimBase = emitBaseIndexForLayout(loc, rewriter, layout, type); + auto multiDimBase = + emitBaseIndexForLayout(loc, rewriter, layout, type, withCTAOffset); // step 2, get offset of each element auto offset = emitOffsetForLayout(layout, type); - // step 3, add offset to base, and reorder the sequence of indices to - // guarantee that elems in the same sizePerThread are adjacent in order + // step 3, add offset to base, and reorder the sequence + // of indices to guarantee that elems in the same + // sizePerThread are adjacent in order auto shape = type.getShape(); unsigned rank = shape.size(); unsigned elemsPerThread = offset.size(); @@ -1186,6 +1392,7 @@ class ConvertTritonGPUOpToLLVMPatternBase { TritonGPUToLLVMTypeConverter *converter; ModuleAllocation *allocation; IndexCacheInfo indexCacheInfo; + mlir::triton::gpu::TMAMetadataTy *tmaMetadata; }; template @@ -1200,6 +1407,12 @@ class ConvertTritonGPUOpToLLVMPattern : ConvertOpToLLVMPattern(typeConverter, benefit), ConvertTritonGPUOpToLLVMPatternBase(typeConverter) {} + explicit ConvertTritonGPUOpToLLVMPattern( + TritonGPUToLLVMTypeConverter &typeConverter, ModuleAllocation &allocation, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), + ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation) {} + explicit ConvertTritonGPUOpToLLVMPattern( TritonGPUToLLVMTypeConverter &typeConverter, IndexCacheInfo indexCacheInfo, PatternBenefit benefit = 1) @@ -1208,16 +1421,17 @@ class ConvertTritonGPUOpToLLVMPattern explicit ConvertTritonGPUOpToLLVMPattern( TritonGPUToLLVMTypeConverter &typeConverter, ModuleAllocation &allocation, - PatternBenefit benefit = 1) + IndexCacheInfo indexCacheInfo, PatternBenefit benefit = 1) : ConvertOpToLLVMPattern(typeConverter, benefit), - ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation) {} + ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation, + indexCacheInfo) {} explicit ConvertTritonGPUOpToLLVMPattern( TritonGPUToLLVMTypeConverter &typeConverter, ModuleAllocation &allocation, - IndexCacheInfo indexCacheInfo, PatternBenefit benefit = 1) + mlir::triton::gpu::TMAMetadataTy *tmaMetadata, PatternBenefit benefit = 1) : ConvertOpToLLVMPattern(typeConverter, benefit), ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation, - indexCacheInfo) {} + tmaMetadata) {} protected: TritonGPUToLLVMTypeConverter *getTypeConverter() const { diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp index 9400b5ef3142..4bba88ae4079 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp @@ -14,48 +14,103 @@ #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "triton/Analysis/Allocation.h" #include "triton/Analysis/AxisInfo.h" #include "triton/Analysis/Membar.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#ifndef USE_ROCM +#else +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#endif +#include "triton/Tools/Sys/GetPlatform.hpp" +#include "BarrierOpToLLVM.h" +#include "ClusterOpsToLLVM.h" #include "ConvertLayoutOpToLLVM.h" #include "DotOpToLLVM.h" #include "ElementwiseOpToLLVM.h" #include "LoadStoreOpToLLVM.h" #include "ReduceOpToLLVM.h" +#include "RegReallocOpToLLVM.h" #include "ScanOpToLLVM.h" +#include "TensorPtrOpsToLLVM.h" #include "TritonGPUToLLVM.h" +#include "TritonGPUToLLVMBase.h" #include "TypeConverter.h" #include "ViewOpToLLVM.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTTRITONGPUTOLLVM +#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc" +} // namespace triton +} // namespace mlir + using namespace mlir; using namespace mlir::triton; - -#define GEN_PASS_CLASSES -#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc" +namespace ttng = mlir::triton::nvidia_gpu; namespace { +// pass ws related named attrs. +static void addWSNamedAttrs(Operation *op, + ArrayRef attrs) { + for (const NamedAttribute attr : attrs) + if (attr.getName() == "async_agent" || attr.getName() == "agent.mutex_role") + op->setAttr(attr.getName(), attr.getValue()); +} + +#ifdef USE_ROCM +constexpr int LDSSize = 65536; +constexpr int kPtrBitWidth = 64; +#endif class TritonLLVMFunctionConversionTarget : public ConversionTarget { public: - explicit TritonLLVMFunctionConversionTarget(MLIRContext &ctx, bool isROCM) + explicit TritonLLVMFunctionConversionTarget(MLIRContext &ctx, Target target) : ConversionTarget(ctx) { addLegalDialect(); addLegalDialect(); - if (isROCM) { + switch (target) { + case Target::NVVM: + addLegalDialect(); + break; + case Target::ROCDL: addLegalDialect(); addLegalDialect(); - } else { - addLegalDialect(); + break; } addLegalOp(); } }; +class FoldSplatMaskInInsertAsync : public mlir::RewritePattern { + +public: + FoldSplatMaskInInsertAsync(mlir::MLIRContext *context) + : mlir::RewritePattern( + triton::nvidia_gpu::InsertSliceAsyncV2Op::getOperationName(), 1, + context) {} + + LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + auto insertOp = cast(op); + if (!insertOp.getMask()) + return failure(); + auto splatOp = insertOp.getMask().getDefiningOp(); + if (!splatOp) + return failure(); + rewriter.updateRootInPlace(insertOp, [&]() { + insertOp.getMaskMutable().assign(splatOp->getOperand(0)); + }); + return mlir::success(); + } +}; + struct ReturnOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -146,6 +201,18 @@ struct FuncOpConversion : public FuncOpConversionBase { if (!allocation.isRoot(funcOp)) amendedFuncOp = amendFuncOp(funcOp, rewriter); + // Collect TMA informations. + unsigned numTMALoad = 0; + funcOp.walk( + [&numTMALoad](triton::nvidia_gpu::InsertSliceAsyncV2Op insertSliceOp) { + numTMALoad++; + }); + unsigned numTMAStore = 0; + funcOp.walk([&numTMAStore](triton::nvidia_gpu::StoreAsyncOp storeAsyncOp) { + numTMAStore++; + }); + unsigned numTMA = numTMALoad + numTMAStore; + auto newFuncOp = convertFuncOpToLLVMFuncOp(amendedFuncOp, rewriter); if (!newFuncOp) { return failure(); @@ -173,6 +240,30 @@ struct FuncOpConversion : public FuncOpConversionBase { // The call graph is updated by mapping the old function to the new one. allocation.mapFuncOp(funcOp, newFuncOp); + // Append arguments to receive TMADesc in global memory in the runtime + auto i8PtrTy = LLVM::LLVMPointerType::get( + this->getTypeConverter()->convertType(rewriter.getI8Type()), 1); + auto numArgs = newFuncOp.getBody().front().getNumArguments(); + auto funcTy = newFuncOp.getFunctionType().cast(); + SmallVector newInputsTy(funcTy.getParams().begin(), + funcTy.getParams().end()); + for (unsigned i = 0; i < numTMA; ++i) { + newFuncOp.getBody().front().addArgument(i8PtrTy, funcOp.getLoc()); + newInputsTy.push_back(i8PtrTy); + } + newFuncOp.setType( + LLVM::LLVMFunctionType::get(funcTy.getReturnType(), newInputsTy)); + // required by AxisInfoAnalysis + for (unsigned i = 0; i < numTMA; ++i) { + newFuncOp.setArgAttr(numArgs + i, "tt.divisibility", + rewriter.getIntegerAttr(i32_ty, 1)); + } + + newFuncOp->setAttr(kAttrNumTMALoadDescsName, + rewriter.getIntegerAttr(i32_ty, numTMALoad)); + newFuncOp->setAttr(kAttrNumTMAStoreDescsName, + rewriter.getIntegerAttr(i32_ty, numTMAStore)); + rewriter.eraseOp(funcOp); return success(); } @@ -249,7 +340,6 @@ struct CallOpConversion : public ConvertOpToLLVMPattern { this->getTypeConverter()->packFunctionResults(resultTypes))) return nullptr; } - auto newCallOp = rewriter.create( callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(), promotedOperands, callOp->getAttrs()); @@ -282,28 +372,31 @@ struct CallOpConversion : public ConvertOpToLLVMPattern { class TritonLLVMConversionTarget : public ConversionTarget { public: - explicit TritonLLVMConversionTarget(MLIRContext &ctx, bool isROCM) + explicit TritonLLVMConversionTarget(MLIRContext &ctx, Target target) : ConversionTarget(ctx) { addLegalDialect(); - if (isROCM) { + switch (target) { + case Target::NVVM: + addLegalDialect(); + break; + case Target::ROCDL: addLegalDialect(); addLegalDialect(); - } else { - addLegalDialect(); + break; } + addLegalDialect(); addIllegalDialect(); addIllegalDialect(); + addIllegalDialect(); addIllegalDialect(); addLegalOp(); } }; -class ConvertTritonGPUToLLVM - : public ConvertTritonGPUToLLVMBase { - -public: - explicit ConvertTritonGPUToLLVM(int computeCapability, bool isROCM) - : computeCapability(computeCapability), isROCM(isROCM) {} +struct ConvertTritonGPUToLLVM + : public triton::impl::ConvertTritonGPUToLLVMBase { + using ConvertTritonGPUToLLVMBase< + ConvertTritonGPUToLLVM>::ConvertTritonGPUToLLVMBase; void runOnOperation() override { MLIRContext *context = &getContext(); @@ -311,29 +404,65 @@ class ConvertTritonGPUToLLVM mlir::LowerToLLVMOptions option(context); option.overrideIndexBitwidth(32); TritonGPUToLLVMTypeConverter typeConverter(context, option); - TritonLLVMConversionTarget target(*context, isROCM); + TritonLLVMConversionTarget convTarget(*context, target); int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); + int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod); int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); // Preprocess decomposeFp8e4b15Convert(mod); - decomposeMmaToDotOperand(mod, numWarps, threadsPerWarp); + decomposeMmaToDotOperand(mod, numWarps, threadsPerWarp, numCTAs); #ifdef USE_ROCM - decomposeMfmaToDotOperand(mod, numWarps, threadsPerWarp); + decomposeMfmaToDotOperand(mod, numWarps, threadsPerWarp, numCTAs); + reduceCvtOpLDSUsage(mod); #endif decomposeBlockedToDotOperand(mod); decomposeInsertSliceAsyncOp(mod); + decomposeMixedModeDotOp(mod); // Allocate shared memory and set barrier ModuleAllocation allocation(mod); ModuleMembarAnalysis membarPass(&allocation); membarPass.run(); + /* Get tensorPtrMap before conversion */ + TensorPtrMapT tensorPtrMap; + mod.walk([&tensorPtrMap]( + mlir::triton::nvidia_gpu::InsertSliceAsyncV2Op insertOp) { + auto src = insertOp.getSrc(); + auto ptrTy = src.getType().dyn_cast(); + if (ptrTy && ptrTy.getPointeeType().isa()) { + auto makeTensorPtrOp = getMakeTensorPtrOp(insertOp.getSrc()); + tensorPtrMap[insertOp.getOperation()] = makeTensorPtrOp; + } + }); + + mod.walk([&tensorPtrMap](mlir::triton::nvidia_gpu::StoreAsyncOp storeOp) { + auto dst = storeOp.getDst(); + auto ptrTy = dst.getType().dyn_cast(); + if (ptrTy && ptrTy.getPointeeType().isa()) { + auto makeTensorPtrOp = getMakeTensorPtrOp(storeOp.getDst()); + tensorPtrMap[storeOp.getOperation()] = makeTensorPtrOp; + } + }); + + // Hack: cleanup + { + RewritePatternSet patterns(context); + patterns.add(context); + SmallVector insertSlices; + mod.walk([&insertSlices](triton::nvidia_gpu::InsertSliceAsyncV2Op op) { + insertSlices.push_back(op); + }); + if (applyOpPatternsAndFold(insertSlices, std::move(patterns)).failed()) + signalPassFailure(); + } + // Lower functions { mlir::LowerToLLVMOptions option(context); TritonGPUToLLVMTypeConverter typeConverter(context, option); - TritonLLVMFunctionConversionTarget funcTarget(*context, isROCM); + TritonLLVMFunctionConversionTarget funcTarget(*context, target); RewritePatternSet funcPatterns(context); funcPatterns.add(typeConverter, numWarps, allocation, /*benefit=*/1); @@ -353,7 +482,7 @@ class ConvertTritonGPUToLLVM { mlir::LowerToLLVMOptions option(context); TritonGPUToLLVMTypeConverter typeConverter(context, option); - TritonLLVMFunctionConversionTarget funcTarget(*context, isROCM); + TritonLLVMFunctionConversionTarget funcTarget(*context, target); RewritePatternSet funcPatterns(context); funcPatterns.add(typeConverter, numWarps, allocation, /*benefit=*/1); @@ -364,9 +493,14 @@ class ConvertTritonGPUToLLVM } ModuleAxisInfoAnalysis axisInfoAnalysis(mod); - // Rewrite ops - RewritePatternSet patterns(context); - // TritonGPU lowering patterns + + // Emit logics to get threadId/blockIds/linearized clusterCTAId etc. and + // cache the values. The reason to do it here is that cluster_ctaid is + // currently implemented via inline asm, and thus cannot be CSEed. + // clusterCTAId will be emitted only when numCTAs is larger than 1, and + // other values will be DCEed if not used hereafter. + bool isWarpSpecialization = + ttng::TritonNvidiaGPUDialect::getWSSupportedAttr(mod); OpBuilder::InsertPoint indexInsertPoint; ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo indexCacheInfo{ &baseIndexCache, &indexCache, &indexInsertPoint}; @@ -374,47 +508,90 @@ class ConvertTritonGPUToLLVM if (axisInfoAnalysis.getNumFunctions() > 1) { indexCacheInfo = {nullptr, nullptr, nullptr}; } - populateTritonGPUToLLVMPatterns(typeConverter, patterns, allocation, - indexCacheInfo, /*benefit=*/1); - populateConvertLayoutOpToLLVMPatterns(typeConverter, patterns, allocation, - indexCacheInfo, /*benefit=*/1); - populateDotOpToLLVMPatterns(typeConverter, patterns, allocation, - /*benefit=*/1); - populateElementwiseOpToLLVMPatterns(typeConverter, patterns, /*benefit=*/1); - populateLoadStoreOpToLLVMPatterns(typeConverter, patterns, axisInfoAnalysis, - allocation, indexCacheInfo, - /*benefit=*/1); - populateReduceOpToLLVMPatterns(typeConverter, patterns, allocation, - indexCacheInfo, /*benefit=*/1); - populateScanOpToLLVMPatterns(typeConverter, patterns, allocation, - indexCacheInfo, /*benefit=*/1); - populateViewOpToLLVMPatterns(typeConverter, patterns, /*benefit=*/1); + + // tmaMetadata is absent in a triton-opt unit test, in this case, create a + // local one and dump it after this pass is done. + mlir::triton::gpu::TMAMetadataTy tmaMetaDataDebug; + if (tmaMetadata == nullptr) + tmaMetadata = &tmaMetaDataDebug; + + RewritePatternSet patterns(context); + + auto populatePatterns1 = [&](auto populateFunc) { + populateFunc(typeConverter, patterns, numWarps, axisInfoAnalysis, + allocation, indexCacheInfo, + /*benefit*/ 10); + }; + + auto populatePatterns2 = [&](auto populateFunc) { + populateFunc(typeConverter, patterns, numWarps, axisInfoAnalysis, + allocation, /*benefit*/ 10); + }; + + auto populatePatterns3 = [&](auto populateFunc) { + populateFunc(typeConverter, patterns, numWarps, axisInfoAnalysis, + allocation, indexCacheInfo, tmaMetadata, &tensorPtrMap, + /*benefit*/ 10); + }; + + auto populatePatterns4 = [&](auto populateFunc) { + populateFunc(typeConverter, patterns, numWarps, axisInfoAnalysis, + allocation, indexCacheInfo, computeCapability, + /*benefit*/ 10); + }; + + populatePatterns1(populateTritonGPUToLLVMPatterns); + populatePatterns1(populateConvertLayoutOpToLLVMPatterns); + populatePatterns2(populateDotOpToLLVMPatterns); + populatePatterns4(populateElementwiseOpToLLVMPatterns); + populatePatterns3(populateLoadStoreOpToLLVMPatterns); + populatePatterns4(populateReduceOpToLLVMPatterns); + populatePatterns1(populateScanOpToLLVMPatterns); + populatePatterns2(populateViewOpToLLVMPatterns); + populatePatterns2(populateBarrierOpToLLVMPatterns); + populatePatterns2(populateTensorPtrOpsToLLVMPatterns); + populatePatterns2(populateClusterOpsToLLVMPatterns); + populatePatterns2(populateRegReallocOpToLLVMPatterns); + + // TODO(thomas): this should probably be done in a separate step to not + // interfere with our own lowering of arith ops. Add arith/math's patterns + // to help convert scalar expression to LLVM. + mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns); + mlir::populateMathToLLVMConversionPatterns(typeConverter, patterns); // Native lowering patterns - if (isROCM) { + switch (target) { + case Target::NVVM: + mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns); + break; + case Target::ROCDL: mlir::populateGpuToROCDLConversionPatterns(typeConverter, patterns, mlir::gpu::amd::HIP); - } else { - mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns); + break; } mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns); - if (failed(applyPartialConversion(mod, target, std::move(patterns)))) + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); + + // Fold CTAId when there is only 1 CTA. + if (numCTAs == 1) { + mod.walk([](triton::nvgpu::ClusterCTAIdOp id) { + OpBuilder b(id); + Value zero = LLVM::createConstantI32(id->getLoc(), b, 0); + id.replaceAllUsesWith(zero); + }); + } } private: - using IndexCacheKeyT = std::pair; DenseMap, CacheKeyDenseMapInfo> baseIndexCache; DenseMap>, CacheKeyDenseMapInfo> indexCache; - int computeCapability{}; - bool isROCM{}; - void initSharedMemory(ModuleAllocation &allocation, TritonGPUToLLVMTypeConverter &typeConverter) { ModuleOp mod = getOperation(); @@ -452,7 +629,8 @@ class ConvertTritonGPUToLLVM void decomposeFp8e4b15Convert(ModuleOp mod) const { mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void { OpBuilder builder(cvtOp); - if (!getElementTypeOrSelf(cvtOp).isa()) + if (!getElementTypeOrSelf(cvtOp) + .isa()) return; auto shape = cvtOp.getType().cast().getShape(); auto argEncoding = @@ -467,17 +645,20 @@ class ConvertTritonGPUToLLVM auto newCvtType = RankedTensorType::get(shape, F16Ty, cvtEncoding); auto newArg = builder.create( cvtOp.getLoc(), newArgType, cvtOp.getOperand()); + addWSNamedAttrs(newArg, cvtOp->getAttrs()); auto newCvt = builder.create( cvtOp.getLoc(), newCvtType, newArg); + addWSNamedAttrs(newCvt, cvtOp->getAttrs()); auto newRet = builder.create( cvtOp.getLoc(), cvtOp.getType(), newCvt.getResult()); + addWSNamedAttrs(newRet, cvtOp->getAttrs()); cvtOp.replaceAllUsesWith(newRet.getResult()); cvtOp.erase(); }); } - void decomposeMmaToDotOperand(ModuleOp mod, int numWarps, - int threadsPerWarp) const { + void decomposeMmaToDotOperand(ModuleOp mod, int numWarps, int threadsPerWarp, + int numCTAs) const { // Replace `mma -> dot_op` with `mma -> blocked -> dot_op` // unless certain conditions are met mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void { @@ -493,11 +674,13 @@ class ConvertTritonGPUToLLVM dstType.getShape(), dstType.getElementType(), triton::gpu::BlockedEncodingAttr::get( mod.getContext(), srcType.getShape(), getSizePerThread(srcMma), - getOrder(srcMma), numWarps, threadsPerWarp)); + getOrder(srcMma), numWarps, threadsPerWarp, numCTAs)); auto tmp = builder.create( cvtOp.getLoc(), tmpType, cvtOp.getOperand()); + addWSNamedAttrs(tmp, cvtOp->getAttrs()); auto newConvert = builder.create( cvtOp.getLoc(), dstType, tmp); + addWSNamedAttrs(newConvert, cvtOp->getAttrs()); cvtOp.replaceAllUsesWith(newConvert.getResult()); cvtOp.erase(); } @@ -505,8 +688,8 @@ class ConvertTritonGPUToLLVM } #ifdef USE_ROCM - void decomposeMfmaToDotOperand(ModuleOp mod, int numWarps, - int threadsPerWarp) const { + void decomposeMfmaToDotOperand(ModuleOp mod, int numWarps, int threadsPerWarp, + int numCTAs) const { // Replace `mfma -> dot_op` with `mfma -> blocked -> dot_op` // unless certain conditions are met mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void { @@ -522,7 +705,7 @@ class ConvertTritonGPUToLLVM dstType.getShape(), dstType.getElementType(), triton::gpu::BlockedEncodingAttr::get( mod.getContext(), srcType.getShape(), getSizePerThread(srcMfma), - getOrder(srcMfma), numWarps, threadsPerWarp)); + getOrder(srcMfma), numWarps, threadsPerWarp, numCTAs)); auto tmp = builder.create( cvtOp.getLoc(), tmpType, cvtOp.getOperand()); auto newConvert = builder.create( @@ -532,6 +715,151 @@ class ConvertTritonGPUToLLVM } }); } + + int getCvtOpLDSUsage(triton::gpu::ConvertLayoutOp &cvtOp) const { + unsigned inVec = 0; + unsigned outVec = 0; + auto smemShape = getScratchConfigForCvtLayout(cvtOp, inVec, outVec); + unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1, + std::multiplies{}); + auto srcType = cvtOp.getOperand().getType().cast(); + auto bytes = + srcType.getElementType().isa() + ? elems * kPtrBitWidth / 8 + : elems * std::max(8, srcType.getElementTypeBitWidth()) / 8; + + return bytes; + } + + bool isPowerOfTwo(unsigned x) const { return x && (x & (x - 1)) == 0; } + + std::vector> factorizePowerOf2(int n) const { + assert(isPowerOfTwo(n)); + int x = log2(n); + std::vector> pairs; + + for (int i = 0; i <= x / 2; ++i) { + int j = x - i; + pairs.push_back({pow(2, i), pow(2, j)}); + pairs.push_back({pow(2, j), pow(2, i)}); + } + + return pairs; + } + + std::pair + createNewConvertOps(ModuleOp &mod, OpBuilder &builder, + triton::gpu::ConvertLayoutOp &cvtOp, + std::pair warpsPerCta) const { + unsigned warpsPerCtaX = warpsPerCta.first; + unsigned warpsPerCtaY = warpsPerCta.second; + auto srcType = cvtOp.getOperand().getType().cast(); + auto dstType = cvtOp.getType().cast(); + + auto srcMfma = + srcType.getEncoding().dyn_cast(); + auto newMfmaEnc = triton::gpu::MfmaEncodingAttr::get( + mod.getContext(), srcMfma.getNonKDim(), {warpsPerCtaX, warpsPerCtaY}, + srcMfma.getIsTransposed(), srcMfma.getCTALayout()); + + auto newDstType = RankedTensorType::get( + dstType.getShape(), dstType.getElementType(), dstType.getEncoding()); + auto newSrcType = RankedTensorType::get( + srcType.getShape(), srcType.getElementType(), newMfmaEnc); + + auto tmpCvt = builder.create( + cvtOp.getLoc(), newSrcType, cvtOp.getOperand()); + auto newEpliogueCvt = builder.create( + cvtOp.getLoc(), newDstType, tmpCvt); + + return std::make_pair(tmpCvt, newEpliogueCvt); + } + + // Try to reduce LDS usage of cvt(mfma->blocked) op by changing the shape of + // WarpsPerCta attribute in mfma layout. The implicit LDS usage of + // cvt(mfma->blocked) op depends on the number of warps per CTA that mfma + // layout uses along x dimension and block layout uses across y dimension. + // + // clang-format off + // + // LDS usage of this op is roughly calculated as: + // LDS_USAGE = getShapePerCTA(mfma_layout)[0] * getShapePerCTA(blocked_layoput)[1] * sizeof(data_type) + // LDS_USAGE = warpsPerCTA(mfma_layout)[0] * warpsPerCta(blocked_layout)[1] * C, + // where C = 32 * sizePerWarp(blocked_layout)[1] * threadsPerWarp(blocked_layout)[1] * sizeof(data_type) + // + // clang-format on + // + // When LDS_USAGE exceeds the size of LDS, try to lower LDS usage by + // decomposing cvt(mfma->blocked) op into 2 conversions: cvt(mfma->mfma_tmp) + // and cvt(mfma_tmp->blocked), where mfma_tmp has WarpsPerCta attribute that + // minimizes uses of LDS for these conversions. + void reduceCvtOpLDSUsage(ModuleOp mod) const { + mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void { + OpBuilder builder(cvtOp); + + auto srcType = cvtOp.getOperand().getType().cast(); + auto dstType = cvtOp.getType().cast(); + + auto srcMfma = + srcType.getEncoding().dyn_cast(); + auto dstBlocked = + dstType.getEncoding().dyn_cast(); + + if (!srcMfma || !dstBlocked) { + return; + } + + auto currLDSUsage = getCvtOpLDSUsage(cvtOp); + if (currLDSUsage <= LDSSize) { + return; + } + + unsigned numWarps = + srcMfma.getWarpsPerCTA()[0] * srcMfma.getWarpsPerCTA()[1]; + + triton::gpu::ConvertLayoutOp tmpCvt; + triton::gpu::ConvertLayoutOp newEpliogueCvt; + + // Find all possible shapes of WarpsPerCTA by finding all possible + // factorizations of numWarps. Pick shape for which both conversions in + // decomposition use LDS less than LDSSize and for which sum of LDS usage + // is minimal. If no such shape exists, do not decompose. + unsigned minLDSUsage = 2 * LDSSize; + int minIdx = -1; + auto factorizedNumWarps = factorizePowerOf2(numWarps); + + for (int i = 0; i < factorizedNumWarps.size(); i++) { + auto warpsPerCTAPair = factorizedNumWarps[i]; + std::tie(tmpCvt, newEpliogueCvt) = + createNewConvertOps(mod, builder, cvtOp, warpsPerCTAPair); + + int tmpCvtLDS = getCvtOpLDSUsage(tmpCvt); + int newCvtLDS = getCvtOpLDSUsage(newEpliogueCvt); + if (tmpCvtLDS <= LDSSize && newCvtLDS <= LDSSize) { + int LDSUsage = tmpCvtLDS + newCvtLDS; + if (LDSUsage < minLDSUsage) { + minLDSUsage = LDSUsage; + minIdx = i; + } + } + newEpliogueCvt.erase(); + tmpCvt.erase(); + } + + if (minIdx == -1) { + return; + } + + assert(minIdx >= 0 && minIdx < factorizedNumWarps.size()); + auto warpsPerCTAPair = factorizedNumWarps[minIdx]; + std::tie(tmpCvt, newEpliogueCvt) = + createNewConvertOps(mod, builder, cvtOp, warpsPerCTAPair); + + cvtOp.replaceAllUsesWith(newEpliogueCvt.getResult()); + cvtOp.erase(); + }); + } + #endif void decomposeBlockedToDotOperand(ModuleOp mod) const { @@ -550,11 +878,14 @@ class ConvertTritonGPUToLLVM dstType.getShape(), dstType.getElementType(), triton::gpu::SharedEncodingAttr::get( mod.getContext(), dstDotOp, srcType.getShape(), - getOrder(srcBlocked), srcType.getElementType())); + srcBlocked.getOrder(), srcBlocked.getCTALayout(), + srcType.getElementType())); auto tmp = builder.create( cvtOp.getLoc(), tmpType, cvtOp.getOperand()); + addWSNamedAttrs(tmp, cvtOp->getAttrs()); auto newConvert = builder.create( cvtOp.getLoc(), dstType, tmp); + addWSNamedAttrs(newConvert, cvtOp->getAttrs()); cvtOp.replaceAllUsesWith(newConvert.getResult()); cvtOp.erase(); } @@ -631,6 +962,7 @@ class ConvertTritonGPUToLLVM /*boundaryCheck=*/nullptr, /*padding=*/nullptr, insertSliceAsyncOp.getCache(), insertSliceAsyncOp.getEvict(), insertSliceAsyncOp.getIsVolatile()); + addWSNamedAttrs(loadOp, insertSliceAsyncOp->getAttrs()); // insert_slice auto axis = insertSliceAsyncOp.getAxis(); @@ -646,6 +978,7 @@ class ConvertTritonGPUToLLVM auto insertSliceOp = builder.create( insertSliceAsyncOp.getLoc(), loadOp, insertSliceAsyncOp.getDst(), offsets, sizes, strides); + addWSNamedAttrs(insertSliceOp, insertSliceAsyncOp->getAttrs()); // Replace insertSliceAsyncOp.replaceAllUsesWith(insertSliceOp.getResult()); @@ -660,8 +993,7 @@ class ConvertTritonGPUToLLVM mod.walk([&](triton::gpu::AsyncWaitOp asyncWaitOp) -> void { #ifdef USE_ROCM - assert(decomposed && - "AsyncWait is not supported for ROCM and should be removed"); + // AsyncWait is not supported for ROCM and should be removed asyncWaitOp.erase(); #else if (!triton::gpu::AsyncWaitOp::isSupported(computeCapability)) { @@ -670,12 +1002,71 @@ class ConvertTritonGPUToLLVM } else if (decomposed) { // Wait for all previous async ops OpBuilder builder(asyncWaitOp); - builder.create(asyncWaitOp.getLoc(), 0); + auto newWaitOp = + builder.create(asyncWaitOp.getLoc(), 0); + addWSNamedAttrs(newWaitOp, asyncWaitOp->getAttrs()); asyncWaitOp.erase(); } #endif }); } + + static Value promoteOperand(OpBuilder &builder, Location loc, Value operand, + Type promotedType) { + Type tensorPromotedType = + operand.getType().cast().cloneWith(std::nullopt, + promotedType); + return builder.create(loc, tensorPromotedType, operand); + } + + // promote operands of dot op if the existing combination is not natively + // supported. + void decomposeMixedModeDotOp(ModuleOp mod) const { + mod.walk([](triton::DotOp dotOp) -> void { + Value D = dotOp.getResult(); + OpBuilder builder(dotOp); + Type AElType = + dotOp.getA().getType().cast().getElementType(); + Type promoteType; + MmaEncodingAttr mmaLayout = D.getType() + .cast() + .getEncoding() + .dyn_cast(); + if (mmaLayout) { + bool isNativeHopperFP8 = + AElType.isFloat8E5M2() || AElType.isFloat8E4M3FNUZ(); + bool isFP8 = isNativeHopperFP8 || AElType.isFloat8E5M2FNUZ() || + AElType.isFloat8E4M3FN(); + if (!isFP8 || (isNativeHopperFP8 && mmaLayout.isHopper())) + return; + promoteType = builder.getF16Type(); +#ifdef USE_ROCM + } else if (MfmaEncodingAttr mfmaLayout = + D.getType() + .cast() + .getEncoding() + .dyn_cast()) { + if (AElType.isBF16() || AElType.isF16() || AElType.isF32() || + AElType.isInteger(8)) + return; + promoteType = builder.getF16Type(); +#endif + } else { + // FMA case. + Type AElType = + dotOp.getA().getType().cast().getElementType(); + Type DElType = D.getType().cast().getElementType(); + if (AElType == DElType) + return; + promoteType = DElType; + } + Location loc = dotOp.getLoc(); + Value promotedA = promoteOperand(builder, loc, dotOp.getA(), promoteType); + Value promotedB = promoteOperand(builder, loc, dotOp.getB(), promoteType); + dotOp.setOperand(0, promotedA); + dotOp.setOperand(1, promotedB); + }); + } }; } // anonymous namespace @@ -683,9 +1074,12 @@ class ConvertTritonGPUToLLVM namespace mlir { namespace triton { +std::unique_ptr> createConvertTritonGPUToLLVMPass() { + return std::make_unique(); +} std::unique_ptr> -createConvertTritonGPUToLLVMPass(int computeCapability, bool isROCM) { - return std::make_unique<::ConvertTritonGPUToLLVM>(computeCapability, isROCM); +createConvertTritonGPUToLLVMPass(const ConvertTritonGPUToLLVMOptions &options) { + return std::make_unique(options); } } // namespace triton diff --git a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp index a094f5d3f1ba..6afcd26d5932 100644 --- a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp @@ -28,6 +28,9 @@ TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter( addConversion([&](mlir::Float8E4M3B11FNUZType type) -> std::optional { return IntegerType::get(type.getContext(), 8); }); + addConversion([&](mlir::Float8E4M3FNType type) -> std::optional { + return IntegerType::get(type.getContext(), 8); + }); addConversion([&](mlir::Float8E4M3FNUZType type) -> std::optional { return IntegerType::get(type.getContext(), 8); }); @@ -42,7 +45,27 @@ TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter( Type TritonGPUToLLVMTypeConverter::convertTritonPointerType( triton::PointerType type) { - // Recursively translate pointee type + auto ctx = type.getContext(); + auto pointeeType = type.getPointeeType(); + if (pointeeType.isa()) { + auto rankedTensorType = pointeeType.cast(); + // struct { offset0, offset1, shape0, shape1, stride0, + // stride1, base_ptr}; + auto eleType = rankedTensorType.getElementType(); + auto shape = rankedTensorType.getShape(); + SmallVector types; + // offsets + for (size_t i = 0; i < shape.size(); ++i) + types.push_back(IntegerType::get(ctx, 32)); + // shapes, strides + for (size_t i = 0; i < 2 * shape.size(); ++i) + types.push_back(IntegerType::get(ctx, 64)); + + types.push_back( + LLVM::LLVMPointerType::get(eleType, type.getAddressSpace())); + + return LLVM::LLVMStructType::getLiteral(ctx, types); + } return LLVM::LLVMPointerType::get(convertType(type.getPointeeType()), type.getAddressSpace()); } @@ -79,6 +102,37 @@ Value TritonGPUToLLVMTypeConverter::packLLElements( return llvmStruct; } +SmallVector TritonGPUToLLVMTypeConverter::packMfmaOperand( + const SmallVector &inValues, Type srcTy, + ConversionPatternRewriter &rewriter, Location loc) { + auto tensorTy = srcTy.dyn_cast(); + if (!tensorTy) + return inValues; + auto encoding = tensorTy.getEncoding().dyn_cast(); + if (!(encoding && encoding.getParent().isa())) { + return inValues; + } + + auto structType = this->convertType(srcTy).dyn_cast(); + auto elementTypes = structType.getBody(); + assert(elementTypes.size() > 0); + mlir::VectorType vecTy = elementTypes[0].dyn_cast(); + if (!vecTy) return inValues; + + unsigned size = vecTy.getNumElements(); + + SmallVector result; + for (int i = 0; i < inValues.size(); i += size) { + Value valVec = undef(vecTy); + for (unsigned j = 0; j < size; ++j) { + valVec = insert_element(vecTy, valVec, inValues[i + j], i32_val(j)); + } + result.push_back(valVec); + } + + return result; +} + SmallVector TritonGPUToLLVMTypeConverter::unpackLLElements( Location loc, Value llvmStruct, ConversionPatternRewriter &rewriter, Type type) { @@ -111,7 +165,7 @@ Type TritonGPUToLLVMTypeConverter::getElementTypeForStruct( if (elemTy.isF32()) return elemTy; if (elemTy.isInteger(16)) // aka BF16 - return vec_ty(elemTy, dotOpLayout.getKWidth() / 2); + return vec_ty(elemTy, dotOpLayout.getKWidth()); if (elemTy.isF16()) return vec_ty(elemTy, 4); if (elemTy.isInteger(8)) @@ -122,14 +176,9 @@ Type TritonGPUToLLVMTypeConverter::getElementTypeForStruct( auto mmaParent = dotOpLayout.getParent().dyn_cast(); if (!mmaParent) return elemTy; - if (mmaParent.isAmpere()) { - int bitwidth = elemTy.getIntOrFloatBitWidth(); - assert(bitwidth <= 32); - return IntegerType::get(ctx, 32); - } else { - assert(mmaParent.isVolta()); - return vec_ty(elemTy, 2); - } + int bitwidth = elemTy.getIntOrFloatBitWidth(); + assert(bitwidth <= 32); + return IntegerType::get(ctx, 32); } Type TritonGPUToLLVMTypeConverter::convertTritonTensorType( diff --git a/lib/Conversion/TritonGPUToLLVM/TypeConverter.h b/lib/Conversion/TritonGPUToLLVM/TypeConverter.h index 038363754b40..975808bb17b8 100644 --- a/lib/Conversion/TritonGPUToLLVM/TypeConverter.h +++ b/lib/Conversion/TritonGPUToLLVM/TypeConverter.h @@ -26,6 +26,10 @@ class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter { Type type); Type convertTritonTensorType(RankedTensorType type); + + SmallVector packMfmaOperand( + const SmallVector &inValues, Type srcTy, + ConversionPatternRewriter &rewriter, Location loc); }; #endif diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 351d31d71855..ff547d6f9514 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -6,19 +6,19 @@ namespace mlir { namespace LLVM { using namespace mlir::triton; -Value createConstantI32(Location loc, PatternRewriter &rewriter, int32_t v) { +Value createConstantI32(Location loc, OpBuilder &rewriter, int32_t v) { auto i32ty = rewriter.getIntegerType(32); return rewriter.create(loc, i32ty, IntegerAttr::get(i32ty, v)); } -Value createConstantF32(Location loc, PatternRewriter &rewriter, float v) { +Value createConstantF32(Location loc, OpBuilder &rewriter, float v) { auto type = type::f32Ty(rewriter.getContext()); return rewriter.create(loc, type, rewriter.getF32FloatAttr(v)); } -Value createConstantF64(Location loc, PatternRewriter &rewriter, float v) { +Value createConstantF64(Location loc, OpBuilder &rewriter, float v) { auto type = type::f64Ty(rewriter.getContext()); return rewriter.create(loc, type, rewriter.getF64FloatAttr(v)); @@ -40,6 +40,96 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width, builder.getIntegerAttr(ty, value)); } +// A wrapper of LoadDSmemOp when vec = 1 +// (1) Get bitwidth from elemTy +// (2) Create LoadDSmemOp +// (3) Bitcast result from dataTy (u16/u32/u64) back to elemTy +Value createLoadDSmem(Location loc, PatternRewriter &rewriter, Value addr, + Value ctaId) { + assert(addr.getType().isa() && + "addr must be a pointer type"); + auto ptrTy = addr.getType().cast(); + assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); + auto elemTy = ptrTy.getElementType(); + unsigned bitwidth = elemTy.getIntOrFloatBitWidth(); + Value ret = + rewriter.create(loc, addr, ctaId, bitwidth); + return bitcast(ret, elemTy); +} + +// A wrapper of LoadDSmemOp when vec > 1 +// (1) Get bitwidth from elemTy +// (2) Create LoadDSmemOp and extract results from retStruct +// (3) Bitcast results from dataTy (u16/u32/u64) back to elemTy +SmallVector createLoadDSmem(Location loc, PatternRewriter &rewriter, + Value addr, Value ctaId, unsigned vec) { + assert(addr.getType().isa() && + "addr must be a pointer type"); + auto ptrTy = addr.getType().cast(); + assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); + auto elemTy = ptrTy.getElementType(); + unsigned bitwidth = elemTy.getIntOrFloatBitWidth(); + Value retStruct = rewriter.create( + loc, addr, ctaId, bitwidth, vec); + SmallVector retVals; + for (unsigned i = 0; i < vec; ++i) { + auto dataTy = rewriter.getIntegerType(bitwidth); + Value data = extract_val(dataTy, retStruct, i); + retVals.push_back(bitcast(data, elemTy)); + } + return retVals; +} + +// A wrapper of StoreDSmemOp when vec = 1 +// (1) Get bitwidth from elemTy +// (2) Bitcast value from elemTy to dataTy (u16/u32/u64) +// (3) Create StoreDSmemOp +void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, + Value ctaId, Value value, Value pred) { + assert(addr.getType().isa() && + "addr must be a pointer type"); + auto ptrTy = addr.getType().cast(); + assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); + auto elemTy = ptrTy.getElementType(); + unsigned bitwidth = elemTy.getIntOrFloatBitWidth(); + auto dataTy = rewriter.getIntegerType(bitwidth); + Value data = bitcast(value, dataTy); + rewriter.create(loc, addr, ctaId, data, pred); +} + +// A wrapper of StoreDSmemOp when vec = 1 and pred = 1 +void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, + Value ctaId, Value value) { + Value pred = int_val(/*width=*/1, 1); + createStoreDSmem(loc, rewriter, addr, ctaId, value, pred); +} + +// A wrapper of StoreDSmemOp when vec > 1 +// (1) Get bitwidth from elemTy +// (2) Bitcast values from elemTy to dataTy (u16/u32/u64) +// (3) Create StoreDSmemOp +void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, + Value ctaId, ArrayRef values, Value pred) { + assert(addr.getType().isa() && + "addr must be a pointer type"); + auto ptrTy = addr.getType().cast(); + assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); + auto elemTy = ptrTy.getElementType(); + unsigned bitwidth = elemTy.getIntOrFloatBitWidth(); + auto dataTy = rewriter.getIntegerType(bitwidth); + SmallVector data; + for (unsigned i = 0; i < values.size(); ++i) + data.push_back(bitcast(values[i], dataTy)); + rewriter.create(loc, addr, ctaId, data, pred); +} + +// A wrapper of StoreDSmemOp when vec > 1 and pred = 1 +void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, + Value ctaId, ArrayRef values) { + Value pred = int_val(/*width=*/1, 1); + createStoreDSmem(loc, rewriter, addr, ctaId, values, pred); +} + SharedMemoryObject getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct, ConversionPatternRewriter &rewriter) { @@ -253,6 +343,15 @@ Value shflUpSync(Location loc, ConversionPatternRewriter &rewriter, Value val, int i, Value laneId) { return commonShflSync(loc, rewriter, val, i, "up", "0x0", laneId); } +Value getSRegValue(OpBuilder &b, Location loc, const std::string &sRegStr) { + PTXBuilder builder; + auto &mov = builder.create("mov")->o("u32"); + auto *destOpr = builder.newOperand("=r"); + auto *sRegOpr = builder.newConstantOperand(sRegStr); + mov(destOpr, sRegOpr); + Value val = builder.launch(b, loc, b.getIntegerType(32), false); + return val; +} Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter, StringRef key, StringRef content) { @@ -289,4 +388,10 @@ Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter, } } // namespace LLVM + +bool isF8(Type eType) { + return eType.isFloat8E5M2FNUZ() or eType.isFloat8E4M3FNUZ() or + eType.isFloat8E5M2() or eType.isFloat8E5M2FNUZ(); +} + } // namespace mlir diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.h b/lib/Conversion/TritonGPUToLLVM/Utility.h index be47ce4e20c9..3cfdcd4444da 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.h +++ b/lib/Conversion/TritonGPUToLLVM/Utility.h @@ -16,6 +16,7 @@ #define trunc(...) rewriter.create(loc, __VA_ARGS__) #define sext(...) rewriter.create(loc, __VA_ARGS__) #define fpext(...) rewriter.create(loc, __VA_ARGS__) +#define trunc(...) rewriter.create(loc, __VA_ARGS__) #define udiv(...) rewriter.create(loc, __VA_ARGS__) #define urem(...) rewriter.create(loc, __VA_ARGS__) #define add(...) rewriter.create(loc, __VA_ARGS__) @@ -33,6 +34,7 @@ #define umin(...) rewriter.create(loc, __VA_ARGS__) #define fmin(...) rewriter.create(loc, __VA_ARGS__) #define shl(...) rewriter.create(loc, __VA_ARGS__) +#define lshr(...) rewriter.create(loc, __VA_ARGS__) #define and_(...) rewriter.create(loc, __VA_ARGS__) #define or_(...) rewriter.create(loc, __VA_ARGS__) #define xor_(...) rewriter.create(loc, __VA_ARGS__) @@ -51,6 +53,8 @@ rewriter.create(loc, __VA_ARGS__) #define load(...) rewriter.create(loc, __VA_ARGS__) #define store(val, ptr) rewriter.create(loc, val, ptr) +#define load_dsmem(...) LLVM::createLoadDSmem(loc, rewriter, __VA_ARGS__) +#define store_dsmem(...) LLVM::createStoreDSmem(loc, rewriter, __VA_ARGS__) #define fcmp_ogt(lhs, rhs) \ rewriter.create(loc, rewriter.getI1Type(), \ LLVM::FCmpPredicate::ogt, lhs, rhs) @@ -83,6 +87,15 @@ #define select(...) rewriter.create(loc, __VA_ARGS__) #define address_of(...) rewriter.create(loc, __VA_ARGS__) #define barrier() rewriter.create(loc) +#define barSync(rewriter, op, bar, numThreads) \ + do { \ + ::mlir::triton::PTXBuilder ptxBuilder; \ + auto &barSyncOp = *ptxBuilder.create<>("bar.sync"); \ + barSyncOp(ptxBuilder.newConstantOperand(bar), \ + ptxBuilder.newConstantOperand(numThreads)); \ + auto voidTy = void_ty(op->getContext()); \ + ptxBuilder.launch(rewriter, op->getLoc(), voidTy); \ + } while (0) #define undef(...) rewriter.create(loc, __VA_ARGS__) #define null(...) rewriter.create(loc, __VA_ARGS__) #define call(...) rewriter.create(loc, __VA_ARGS__) @@ -92,6 +105,8 @@ #define i64_ty rewriter.getIntegerType(64) #define i32_ty rewriter.getIntegerType(32) #define i16_ty rewriter.getIntegerType(16) +#define i32_ty rewriter.getIntegerType(32) +#define i64_ty rewriter.getIntegerType(64) #define ui32_ty rewriter.getIntegerType(32, false) #define f16_ty rewriter.getF16Type() #define bf16_ty rewriter.getBF16Type() @@ -182,13 +197,13 @@ T getLinearIndex(llvm::ArrayRef multiDimIndex, llvm::ArrayRef shape, namespace LLVM { using namespace mlir::triton; -Value createConstantI32(Location loc, PatternRewriter &rewriter, int32_t v); +Value createConstantI32(Location loc, OpBuilder &rewriter, int32_t v); /// Create a 32-bit float constant. -Value createConstantF32(Location loc, PatternRewriter &rewriter, float v); +Value createConstantF32(Location loc, OpBuilder &rewriter, float v); /// Create a 64-bit float constant. -Value createConstantF64(Location loc, PatternRewriter &rewriter, float v); +Value createConstantF64(Location loc, OpBuilder &rewriter, float v); /// Create an index type constant. Value createIndexConstant(OpBuilder &builder, Location loc, @@ -198,6 +213,28 @@ Value createIndexConstant(OpBuilder &builder, Location loc, Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width, int64_t value); +/// Usage of macro load_dsmem +/// (1) load_dsmem(addr, ctaId) +/// (2) load_dsmem(addr, ctaId, vec) +Value createLoadDSmem(Location loc, PatternRewriter &rewriter, Value addr, + Value ctaId); +SmallVector createLoadDSmem(Location loc, PatternRewriter &rewriter, + Value addr, Value ctaId, unsigned vec); + +/// Usage of macro store_dsmem +/// (1) store_dsmem(addr, ctaId, value, pred) +/// (2) store_dsmem(addr, ctaId, value) +/// (3) store_dsmem(addr, ctaId, values, pred) +/// (4) store_dsmem(addr, ctaId, values) +void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, + Value ctaId, Value value, Value pred); +void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, + Value ctaId, Value value); +void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, + Value ctaId, ArrayRef values, Value pred); +void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, + Value ctaId, ArrayRef values); + /// Helper function to get strides from a given shape and its order SmallVector getStridesFromShapeAndOrder(ArrayRef shape, ArrayRef order, @@ -296,10 +333,14 @@ Value shflSync(Location loc, ConversionPatternRewriter &rewriter, Value val, Value shflUpSync(Location loc, ConversionPatternRewriter &rewriter, Value val, int i, Value laneId); +Value getSRegValue(OpBuilder &b, Location loc, const std::string &sRegStr); Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter, StringRef key, StringRef content); } // namespace LLVM + +bool isF8(Type eType); + } // namespace mlir #endif diff --git a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp index 553fd22bf1df..fdd47f2de196 100644 --- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp @@ -104,6 +104,7 @@ struct CatOpConversion : public ConvertTritonGPUOpToLLVMPattern { using OpAdaptor = typename CatOp::Adaptor; explicit CatOpConversion(TritonGPUToLLVMTypeConverter &typeConverter, + PatternBenefit benefit = 1) : ConvertTritonGPUOpToLLVMPattern(typeConverter, benefit) {} @@ -138,6 +139,7 @@ struct CatOpConversion : public ConvertTritonGPUOpToLLVMPattern { struct ViewOpConversion : public ConvertTritonGPUOpToLLVMPattern { using OpAdaptor = typename ViewOp::Adaptor; explicit ViewOpConversion(TritonGPUToLLVMTypeConverter &typeConverter, + PatternBenefit benefit = 1) : ConvertTritonGPUOpToLLVMPattern(typeConverter, benefit) {} @@ -159,6 +161,7 @@ struct ExpandDimsOpConversion : public ConvertTritonGPUOpToLLVMPattern { using OpAdaptor = typename ExpandDimsOp::Adaptor; explicit ExpandDimsOpConversion(TritonGPUToLLVMTypeConverter &typeConverter, + PatternBenefit benefit = 1) : ConvertTritonGPUOpToLLVMPattern(typeConverter, benefit) {} @@ -221,7 +224,9 @@ struct TransOpConversion }; void populateViewOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter, - RewritePatternSet &patterns, + RewritePatternSet &patterns, int numWarps, + ModuleAxisInfoAnalysis &axisInfoAnalysis, + ModuleAllocation &allocation, PatternBenefit benefit) { patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); diff --git a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.h b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.h index 1a8aef39148f..431e8efece40 100644 --- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.h +++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.h @@ -7,7 +7,9 @@ using namespace mlir; using namespace mlir::triton; void populateViewOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter, - RewritePatternSet &patterns, + RewritePatternSet &patterns, int numWarps, + ModuleAxisInfoAnalysis &axisInfoAnalysis, + ModuleAllocation &allocation, PatternBenefit benefit); #endif diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index f150ed422796..b03d76ac4bdb 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -10,6 +10,7 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" +#include "triton/Target/PTX/TmaMetadata.h" #include "llvm/ADT/APSInt.h" #include @@ -240,10 +241,19 @@ struct TritonExpandDimsPattern retWarpsPerCTA.insert(retWarpsPerCTA.begin() + op.getAxis(), 1); SmallVector retOrder(retShape.size()); std::iota(retOrder.begin(), retOrder.end(), 0); + + auto argCTALayout = argEncoding.getCTALayout(); + auto retCTAsPerCGA = insertOne(argCTALayout.getCTAsPerCGA(), op.getAxis()); + auto retCTASplitNum = + insertOne(argCTALayout.getCTASplitNum(), op.getAxis()); + auto retCTAOrder = insertOrder(argCTALayout.getCTAOrder(), op.getAxis()); + auto retCTALayout = triton::gpu::CTALayoutAttr::get( + getContext(), retCTAsPerCGA, retCTASplitNum, retCTAOrder); + triton::gpu::BlockedEncodingAttr retEncoding = triton::gpu::BlockedEncodingAttr::get(getContext(), retSizePerThread, retThreadsPerWarp, retWarpsPerCTA, - retOrder); + retOrder, retCTALayout); // convert operand to slice of return type Attribute newArgEncoding = triton::gpu::SliceEncodingAttr::get( getContext(), op.getAxis(), retEncoding); @@ -257,6 +267,26 @@ struct TritonExpandDimsPattern adaptor.getAttributes()); return success(); } + +private: + template + SmallVector insertOne(ArrayRef vec, unsigned axis) const { + SmallVector res(vec.begin(), vec.end()); + res.insert(res.begin() + axis, 1); + return res; + } + + // Example: order = [ 0, 2, 1, 3], dim = 2 + // resOrder = [2, 0, 3, 1, 4] + SmallVector insertOrder(ArrayRef order, + unsigned axis) const { + SmallVector resOrder(order.begin(), order.end()); + for (unsigned i = 0; i < resOrder.size(); ++i) + if (resOrder[i] >= axis) + ++resOrder[i]; + resOrder.insert(resOrder.begin(), axis); + return resOrder; + } }; struct TritonDotPattern : public OpConversionPattern { @@ -270,6 +300,7 @@ struct TritonDotPattern : public OpConversionPattern { auto typeConverter = getTypeConverter(); int numWarps = typeConverter->getNumWarps(); int threadsPerWarp = typeConverter->getThreadsPerWarp(); + int numCTAs = typeConverter->getNumCTAs(); SmallVector retSizePerThread = {1, 1}; if (origShape[0] * origShape[1] / (numWarps * threadsPerWarp) >= 4) @@ -279,7 +310,7 @@ struct TritonDotPattern : public OpConversionPattern { SmallVector retOrder = {1, 0}; Attribute dEncoding = triton::gpu::BlockedEncodingAttr::get( getContext(), origShape, retSizePerThread, retOrder, numWarps, - threadsPerWarp); + threadsPerWarp, numCTAs); RankedTensorType retType = RankedTensorType::get(origShape, origType.getElementType(), dEncoding); // a & b must be of smem layout @@ -354,9 +385,9 @@ struct TritonCatPattern : public OpConversionPattern { newRetSizePerThread[retOrder[0]] *= newRetTotalElemsPerThread / retTotalElemsPerThread; triton::gpu::BlockedEncodingAttr newRetEncoding = - triton::gpu::BlockedEncodingAttr::get(getContext(), newRetSizePerThread, - retThreadsPerWarp, retWarpsPerCTA, - retOrder); + triton::gpu::BlockedEncodingAttr::get( + getContext(), newRetSizePerThread, retThreadsPerWarp, + retWarpsPerCTA, retOrder, retEncoding.getCTALayout()); auto newRetType = RankedTensorType::get(retShape, retType.getElementType(), newRetEncoding); addNamedAttrs(rewriter.replaceOpWithNewOp( @@ -386,8 +417,12 @@ struct TritonTransPattern : public OpConversionPattern { if (auto srcBlockedEncoding = srcEncoding.dyn_cast()) llvm::copy(srcBlockedEncoding.getOrder(), order.begin()); - srcEncoding = - triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, order); + // TODO(Qingyi): need to check whether the CTALayout of srcEncoding should + // be used here. For tests where numCTAs = 1, this is not a problem since + // all CTALayouts are the same. + auto CTALayout = triton::gpu::getCTALayout(srcEncoding); + srcEncoding = triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, + order, CTALayout); srcType = RankedTensorType::get(srcType.getShape(), srcType.getElementType(), srcEncoding); src = rewriter.create(src.getLoc(), srcType, @@ -463,24 +498,6 @@ struct TritonAtomicRMWPattern } }; -template -struct TritonExternElementwisePattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - using OpConversionPattern::typeConverter; - typedef typename OpConversionPattern::OpAdaptor OpAdaptor; - - LogicalResult - matchAndRewrite(T op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - addNamedAttrs(rewriter.replaceOpWithNewOp( - op, typeConverter->convertType(op.getType()), - adaptor.getArgs(), adaptor.getLibname(), - adaptor.getLibpath(), adaptor.getSymbol()), - adaptor.getAttributes()); - return success(); - } -}; - template struct TritonGenericPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -658,26 +675,26 @@ class TritonReturnOpPattern : public OpConversionPattern { }; void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, - RewritePatternSet &patterns) { + RewritePatternSet &patterns, unsigned numCTAs) { MLIRContext *context = patterns.getContext(); - patterns - .insert< // TODO: view should have custom pattern that views the layout - TritonGenericPattern, - TritonGenericPattern, - TritonGenericPattern, - TritonGenericPattern, - TritonGenericPattern, - TritonGenericPattern, TritonBroadcastPattern, - TritonGenericPattern, TritonCatPattern, - TritonReducePattern, TritonReduceReturnPattern, TritonScanPattern, - TritonScanReturnPattern, TritonTransPattern, TritonExpandDimsPattern, - TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern, - TritonStorePattern, - TritonExternElementwisePattern, - TritonExternElementwisePattern, - TritonPrintPattern, TritonAssertPattern, TritonAtomicRMWPattern, - TritonFuncOpPattern, TritonReturnOpPattern, TritonCallOpPattern>( - typeConverter, context); + patterns.insert< // TODO: view should have custom pattern that views the + // layout + TritonGenericPattern, + TritonGenericPattern, + TritonGenericPattern, + TritonGenericPattern, + TritonGenericPattern, + TritonGenericPattern, + TritonGenericPattern, + TritonGenericPattern, TritonBroadcastPattern, + TritonGenericPattern, TritonCatPattern, + TritonGenericPattern, TritonReducePattern, + TritonReduceReturnPattern, TritonScanPattern, TritonScanReturnPattern, + TritonTransPattern, TritonExpandDimsPattern, TritonMakeRangePattern, + TritonDotPattern, TritonLoadPattern, TritonStorePattern, + TritonGenericPattern, TritonPrintPattern, + TritonAssertPattern, TritonAtomicRMWPattern, TritonFuncOpPattern, + TritonReturnOpPattern, TritonCallOpPattern>(typeConverter, context); } // @@ -889,16 +906,20 @@ class ConvertTritonToTritonGPU public: ConvertTritonToTritonGPU() = default; // constructor with some parameters set explicitly. - ConvertTritonToTritonGPU(int numWarps, int threadsPerWarp) { + ConvertTritonToTritonGPU(int numWarps, int threadsPerWarp, int numCTAs, + int computeCapability) { this->numWarps = numWarps; this->threadsPerWarp = threadsPerWarp; + this->numCTAs = numCTAs; + this->computeCapability = computeCapability; } void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp mod = getOperation(); // type converter - TritonGPUTypeConverter typeConverter(context, numWarps, threadsPerWarp); + TritonGPUTypeConverter typeConverter(context, numWarps, threadsPerWarp, + numCTAs); TritonGPUConversionTarget target(*context, typeConverter); // rewrite patterns RewritePatternSet patterns(context); @@ -906,7 +927,7 @@ class ConvertTritonToTritonGPU populateStdPatternsAndLegality(typeConverter, patterns, target); populateArithPatternsAndLegality(typeConverter, patterns, target); populateMathPatternsAndLegality(typeConverter, patterns, target); - populateTritonPatterns(typeConverter, patterns); + populateTritonPatterns(typeConverter, patterns, numCTAs); // TODO: can we use // mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here? populateSCFPatterns(typeConverter, patterns); @@ -925,6 +946,13 @@ class ConvertTritonToTritonGPU AttrNumThreadsPerWarp, IntegerAttr::get(i32_ty, llvm::APInt(32, threadsPerWarp.getValue()))); + mod->setAttr(AttrNumCTAsName, + IntegerAttr::get(i32_ty, llvm::APInt(32, numCTAs.getValue()))); + + mod->setAttr(AttrComputeCapabilityName, + IntegerAttr::get( + i32_ty, llvm::APInt(32, computeCapability.getValue()))); + // update layouts // broadcast src => multicast, dst => broadcasted // if (failed(target.refineLayouts(mod, numWarps))) @@ -936,8 +964,11 @@ class ConvertTritonToTritonGPU std::unique_ptr> mlir::triton::createConvertTritonToTritonGPUPass(int numWarps, - int threadsPerWarp) { - return std::make_unique<::ConvertTritonToTritonGPU>(numWarps, threadsPerWarp); + int threadsPerWarp, + int numCTAs, + int computeCapability) { + return std::make_unique<::ConvertTritonToTritonGPU>( + numWarps, threadsPerWarp, numCTAs, computeCapability); } std::unique_ptr> diff --git a/lib/Dialect/CMakeLists.txt b/lib/Dialect/CMakeLists.txt index 27cb65ce5101..02d764601056 100644 --- a/lib/Dialect/CMakeLists.txt +++ b/lib/Dialect/CMakeLists.txt @@ -1,2 +1,4 @@ add_subdirectory(Triton) add_subdirectory(TritonGPU) +add_subdirectory(TritonNvidiaGPU) +add_subdirectory(NVGPU) diff --git a/lib/Dialect/NVGPU/CMakeLists.txt b/lib/Dialect/NVGPU/CMakeLists.txt new file mode 100644 index 000000000000..f33061b2d87c --- /dev/null +++ b/lib/Dialect/NVGPU/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/lib/Dialect/NVGPU/IR/CMakeLists.txt b/lib/Dialect/NVGPU/IR/CMakeLists.txt new file mode 100644 index 000000000000..4e9e1ada172c --- /dev/null +++ b/lib/Dialect/NVGPU/IR/CMakeLists.txt @@ -0,0 +1,9 @@ +add_mlir_dialect_library(NVGPUIR + Dialect.cpp + + DEPENDS + NVGPUTableGen + NVGPUAttrDefsIncGen + + LINK_LIBS PUBLIC +) diff --git a/lib/Dialect/NVGPU/IR/Dialect.cpp b/lib/Dialect/NVGPU/IR/Dialect.cpp new file mode 100644 index 000000000000..a3c11f87a62f --- /dev/null +++ b/lib/Dialect/NVGPU/IR/Dialect.cpp @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" + +// clang-format off +#include "triton/Dialect/NVGPU/IR/Dialect.h" +#include "triton/Dialect/NVGPU/IR/Dialect.cpp.inc" +// clang-format on + +using namespace mlir; +using namespace mlir::triton::nvgpu; + +void LoadDSmemOp::build(OpBuilder &builder, OperationState &state, + Type resultTy, Value addr, Value ctaId) { + unsigned vec, bitwidth; + if (auto structTy = resultTy.dyn_cast()) { + auto types = structTy.getBody(); + assert(types.size() > 0 && "Invalid result type of LoadDSmemOp"); + vec = types.size(); + for (unsigned i = 0; i < vec; ++i) + assert(types[0] == types[i]); + bitwidth = types[0].getIntOrFloatBitWidth(); + } else { + vec = 1; + bitwidth = resultTy.getIntOrFloatBitWidth(); + } + build(builder, state, resultTy, addr, ctaId, bitwidth, vec); +} + +void LoadDSmemOp::build(OpBuilder &builder, OperationState &state, Value addr, + Value ctaId, unsigned bitwidth, unsigned vec) { + Type resultTy = builder.getIntegerType(bitwidth); + if (vec > 1) { + SmallVector types(vec, resultTy); + resultTy = LLVM::LLVMStructType::getLiteral(builder.getContext(), types); + } + build(builder, state, resultTy, addr, ctaId, bitwidth, vec); +} + +void LoadDSmemOp::build(OpBuilder &builder, OperationState &state, Value addr, + Value ctaId, unsigned bitwidth) { + build(builder, state, addr, ctaId, bitwidth, /*vec*/ 1); +} + +void StoreDSmemOp::build(OpBuilder &builder, OperationState &state, Value addr, + Value ctaId, Value value, Value pred) { + SmallVector values = {value}; + build(builder, state, addr, ctaId, values, pred); +} + +unsigned StoreDSmemOp::getBitwidth() { + auto addrTy = getAddr().getType(); + assert(addrTy.isa() && "addr must be a pointer type"); + auto elemTy = addrTy.cast().getElementType(); + return elemTy.getIntOrFloatBitWidth(); +} + +unsigned StoreDSmemOp::getVec() { return getValues().size(); } + +static LogicalResult verify(mlir::triton::nvgpu::TMALoadTiledOp op) { + return success(); +} + +static LogicalResult verify(mlir::triton::nvgpu::TMALoadIm2colOp op) { + return success(); +} + +static LogicalResult verify(mlir::triton::nvgpu::WGMMAOp op) { + return success(); +} + +void mlir::triton::nvgpu::NVGPUDialect::initialize() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "triton/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc" + >(); + + addOperations< +#define GET_OP_LIST +#include "triton/Dialect/NVGPU/IR/Ops.cpp.inc" + >(); +} + +#define GET_OP_CLASSES +#include "triton/Dialect/NVGPU/IR/Ops.cpp.inc" +#include "triton/Dialect/NVGPU/IR/OpsEnums.cpp.inc" diff --git a/lib/Dialect/Triton/IR/Dialect.cpp b/lib/Dialect/Triton/IR/Dialect.cpp index 086279123722..0f234f905178 100644 --- a/lib/Dialect/Triton/IR/Dialect.cpp +++ b/lib/Dialect/Triton/IR/Dialect.cpp @@ -7,6 +7,7 @@ #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/Transforms/InliningUtils.h" diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index d988b372dfba..2d7bec31cf82 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -80,6 +80,16 @@ void LoadOp::print(OpAsmPrinter &printer) { printer.printStrippedAttrOrType(getResult().getType()); } +void LoadOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), getPtr(), + SideEffects::DefaultResource::get()); + if (getIsVolatile()) + effects.emplace_back(MemoryEffects::Write::get(), + SideEffects::DefaultResource::get()); +} + ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) { // Parse operands SmallVector allOperands; @@ -401,8 +411,10 @@ mlir::LogicalResult mlir::triton::DotOp::inferReturnTypes( LogicalResult mlir::triton::DotOp::verify() { auto aTy = getOperand(0).getType().cast(); auto bTy = getOperand(1).getType().cast(); - if (aTy.getElementType() != bTy.getElementType()) - return emitError("element types of operands A and B must match"); + if (aTy.getElementType().getIntOrFloatBitWidth() != + bTy.getElementType().getIntOrFloatBitWidth()) + return emitError( + "element types of operands A and B must have same bit width"); auto aEncoding = aTy.getEncoding(); auto bEncoding = bTy.getEncoding(); if (!aEncoding && !bEncoding) @@ -416,6 +428,16 @@ LogicalResult mlir::triton::DotOp::verify() { bEncoding); } +//-- MakeRangeOp -- +OpFoldResult MakeRangeOp::fold(FoldAdaptor adaptor) { + // make_range(start, start + 1) -> constant(start) + if (adaptor.getStart() + 1 == adaptor.getEnd()) { + auto shapedType = getType().cast(); + return SplatElementsAttr::get(shapedType, adaptor.getStartAttr()); + } + return {}; +} + //-- ReduceOp -- static mlir::LogicalResult inferReduceReturnShape(const RankedTensorType &argTy, const Type &retEltTy, @@ -870,5 +892,29 @@ LogicalResult triton::ReturnOp::verify() { return success(); } +// -- ElementwiseInlineAsmOp -- +void ElementwiseInlineAsmOp::getEffects( + SmallVectorImpl> + &effects) { + if (getPure()) + return; + effects.emplace_back(MemoryEffects::Write::get(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), + SideEffects::DefaultResource::get()); +} + +// -- ExternElementwiseOp -- +void ExternElementwiseOp::getEffects( + SmallVectorImpl> + &effects) { + if (getPure()) + return; + effects.emplace_back(MemoryEffects::Write::get(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), + SideEffects::DefaultResource::get()); +} + } // namespace triton } // namespace mlir diff --git a/lib/Dialect/Triton/IR/Traits.cpp b/lib/Dialect/Triton/IR/Traits.cpp index 47893fd2ced1..ecfce4a27fcc 100644 --- a/lib/Dialect/Triton/IR/Traits.cpp +++ b/lib/Dialect/Triton/IR/Traits.cpp @@ -7,15 +7,17 @@ using namespace mlir; static LogicalResult verifySameEncoding(Type typeA, Type typeB, bool allowTensorPointerType) { + // TODO(Keren): the allowTensorPointerType argument is a hack to allow. + // The type checking code is kind of a mess with the current design. auto getEncoding = [=](Type type) -> Attribute { - auto rankedType = type.dyn_cast(); - if (allowTensorPointerType) { - if (auto ptrType = type.dyn_cast()) - rankedType = ptrType.getPointeeType().dyn_cast(); - } else { + Attribute ret; + if (auto tensorType = dyn_cast(type)) { + ret = tensorType.getEncoding(); + } + if (!allowTensorPointerType) { assert(!triton::isTensorPointerType(type)); } - return rankedType ? rankedType.getEncoding() : Attribute(); + return ret; }; auto encodingA = getEncoding(typeA); auto encodingB = getEncoding(typeB); @@ -127,7 +129,16 @@ OpTrait::impl::verifySameLoadStoreOperandsAndResultShape(Operation *op) { bool OpTrait::impl::verifyLoadStorePointerAndValueType(Type valueType, Type ptrType) { if (triton::isTensorPointerType(ptrType)) { - return ptrType.cast().getPointeeType() == valueType; + // The encoding of tensor pointers is meaningless, we only check shapes and + // the type of elements + auto tensorAType = ptrType.cast() + .getPointeeType() + .cast(); + if (!isa(valueType)) + return false; + auto tensorBType = valueType.cast(); + return tensorAType.getShape() == tensorBType.getShape() && + tensorAType.getElementType() == tensorBType.getElementType(); } else if (auto rankedType = ptrType.dyn_cast()) { if (auto elementPtrType = dyn_cast(rankedType.getElementType())) { diff --git a/lib/Dialect/Triton/IR/Types.cpp b/lib/Dialect/Triton/IR/Types.cpp index 104482083a27..aae8d862e3de 100644 --- a/lib/Dialect/Triton/IR/Types.cpp +++ b/lib/Dialect/Triton/IR/Types.cpp @@ -27,15 +27,20 @@ Type PointerType::parse(AsmParser &parser) { if (parser.parseType(pointeeType)) return Type(); + int addressSpace = 1; + if (succeeded(parser.parseOptionalComma())) { + if (parser.parseInteger(addressSpace)) + return Type(); + } + if (parser.parseGreater()) return Type(); - // TODO: also print address space? - return PointerType::get(pointeeType, 1); + return PointerType::get(pointeeType, addressSpace); } void PointerType::print(AsmPrinter &printer) const { - printer << "<" << getPointeeType() << ">"; + printer << "<" << getPointeeType() << ", " << getAddressSpace() << ">"; } namespace mlir { @@ -99,6 +104,10 @@ bool isTensorPointerType(Type type) { return false; } +bool isTensorOrTensorPointerType(Type type) { + return type.isa() || isTensorPointerType(type); +} + Type getElementTypeOfTensorPointerType(Type type) { if (auto ptrType = type.dyn_cast()) if (auto tensorTy = ptrType.getPointeeType().dyn_cast()) diff --git a/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp b/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp index b57a3d1fe7a6..44eb9e0b1a8d 100644 --- a/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp +++ b/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp @@ -11,6 +11,8 @@ using namespace mlir; #define GEN_PASS_CLASSES #include "triton/Dialect/Triton/Transforms/Passes.h.inc" +namespace { + /// An additional struct to record the meta information of operations /// with tensor pointers struct RewritedInfo { @@ -186,6 +188,8 @@ struct RewritedInfo { } }; +} // namespace + class RewriteTensorPointerPass : public TritonRewriteTensorPointerBase { private: diff --git a/lib/Dialect/TritonGPU/IR/CMakeLists.txt b/lib/Dialect/TritonGPU/IR/CMakeLists.txt index 20f6f9851e1b..8477f4dccd20 100644 --- a/lib/Dialect/TritonGPU/IR/CMakeLists.txt +++ b/lib/Dialect/TritonGPU/IR/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(TritonGPUIR Dialect.cpp Traits.cpp + Types.cpp DEPENDS TritonGPUTableGen diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index bc55e1661e3f..5ff25a0a5eca 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -16,6 +16,21 @@ using namespace mlir::triton::gpu; namespace mlir { namespace triton { +static Type getI1SameShapeFromTensorOrTensorPtr(Type type) { + auto i1Type = IntegerType::get(type.getContext(), 1); + if (auto tensorType = type.dyn_cast()) { + return RankedTensorType::get(tensorType.getShape(), i1Type, + tensorType.getEncoding()); + } else if (auto ptrType = type.dyn_cast()) { + Type pointeeType = ptrType.getPointeeType(); + if (auto tensorType = pointeeType.dyn_cast()) { + return RankedTensorType::get(tensorType.getShape(), i1Type, + tensorType.getEncoding()); + } + } + return Type(); +} + namespace gpu { // TODO: Inheritance of layout attributes @@ -84,12 +99,23 @@ SmallVector getThreadsPerWarp(Attribute layout) { return {4, 8}; if (mmaLayout.isAmpere()) return {8, 4}; + if (mmaLayout.isHopper()) + return {8, 4}; } if (auto mfmaLayout = layout.dyn_cast()) { + unsigned rows, cols; + if (mfmaLayout.getNonKDim() == 32) { + cols = 2; + rows = 32; + } else { + cols = 4; + rows = 16; + } + if (mfmaLayout.getIsTransposed()) { - return {32, 2}; + return {rows, cols}; } else { - return {2, 32}; + return {cols, rows}; } } if (auto sliceLayout = layout.dyn_cast()) { @@ -203,14 +229,28 @@ SmallVector getSizePerThread(Attribute layout) { return {2, 2}; } else if (mmaLayout.isVolta()) { return {1, 2}; + } else if (mmaLayout.isHopper()) { + auto instrShape = mmaLayout.getInstrShape(); + // TODO(thomas): what are those magic numbers? + return SmallVector{instrShape[0] * 4 / 32, instrShape[1] / 4}; } else { llvm_unreachable("Unexpected mma version"); } } else if (auto mfmaLayout = layout.dyn_cast()) { + unsigned rows, cols; + if (mfmaLayout.getNonKDim() == 32) { + rows = 16; + cols = 1; + } else if (mfmaLayout.getNonKDim() == 16) { + rows = 4; + cols = 1; + } else + llvm_unreachable("Unexpected mfma non-k dim"); + if (mfmaLayout.getIsTransposed()) { - return {1, 16}; + return {cols, rows}; } else { - return {16, 1}; + return {rows, cols}; } } else if (auto dotLayout = layout.dyn_cast()) { auto parentLayout = dotLayout.getParent(); @@ -218,7 +258,6 @@ SmallVector getSizePerThread(Attribute layout) { if (auto parentMmaLayout = parentLayout.dyn_cast()) { assert(parentMmaLayout.isAmpere() && "mmaLayout version = 1 is not implemented yet"); - auto parentShapePerCTA = getShapePerCTA(parentLayout); auto opIdx = dotLayout.getOpIdx(); if (opIdx == 0) { return {2, 4}; @@ -229,7 +268,6 @@ SmallVector getSizePerThread(Attribute layout) { return {}; } } else if (parentLayout.isa()) { - auto parentShapePerCTA = getShapePerCTA(parentLayout); auto opIdx = dotLayout.getOpIdx(); if (opIdx == 0) { return {4, 1}; @@ -252,7 +290,7 @@ SmallVector getSizePerThread(Attribute layout) { SmallVector getContigPerThread(Attribute layout) { if (auto mmaLayout = layout.dyn_cast()) { - assert(mmaLayout.isVolta() || mmaLayout.isAmpere()); + assert(mmaLayout.isVolta() || mmaLayout.isAmpere() || mmaLayout.isHopper()); return {1, 2}; } else if (layout.isa()) { return {1, 1}; @@ -301,17 +339,22 @@ SmallVector getThreadsPerCTA(Attribute layout) { } else assert(0 && "Unimplemented usage of MmaEncodingAttr"); } else if (auto mfmaLayout = layout.dyn_cast()) { - threads = {32 * mfmaLayout.getWarpsPerCTA()[0], - 2 * mfmaLayout.getWarpsPerCTA()[1]}; + if (mfmaLayout.getNonKDim() == 32) { + threads = {32 * mfmaLayout.getWarpsPerCTA()[0], + 2 * mfmaLayout.getWarpsPerCTA()[1]}; + } else { + threads = {16 * mfmaLayout.getWarpsPerCTA()[0], + 4 * mfmaLayout.getWarpsPerCTA()[1]}; + } } else { - assert(0 && "Unimplemented usage of getShapePerCTA"); + assert(0 && "Unimplemented usage of getThreadsPerCTA"); } return threads; } -SmallVector getShapePerCTA(Attribute layout, - ArrayRef tensorShape) { +SmallVector getShapePerCTATile(Attribute layout, + ArrayRef tensorShape) { SmallVector shape; if (auto blockedLayout = layout.dyn_cast()) { for (unsigned d = 0, n = blockedLayout.getOrder().size(); d < n; ++d) @@ -319,13 +362,8 @@ SmallVector getShapePerCTA(Attribute layout, blockedLayout.getThreadsPerWarp()[d] * blockedLayout.getWarpsPerCTA()[d]); } else if (auto sliceLayout = layout.dyn_cast()) { - unsigned dim = sliceLayout.getDim(); - auto parent = sliceLayout.getParent(); - for (unsigned d = 0, n = getOrder(parent).size(); d < n; ++d) { - if (d == dim) - continue; - shape.push_back(getShapePerCTA(parent, tensorShape)[d]); - } + shape = getShapePerCTATile(sliceLayout.getParent(), tensorShape); + shape.erase(shape.begin() + sliceLayout.getDim()); } else if (auto mmaLayout = layout.dyn_cast()) { if (mmaLayout.isAmpere()) return {16 * mmaLayout.getWarpsPerCTA()[0], @@ -338,29 +376,36 @@ SmallVector getShapePerCTA(Attribute layout, return {static_cast(tensorShape[0]), static_cast(tensorShape[1])}; } + if (mmaLayout.isHopper()) { + auto instrShape = mmaLayout.getInstrShape(); + return {16 * mmaLayout.getWarpsPerCTA()[0], + instrShape[1] * mmaLayout.getWarpsPerCTA()[1]}; + } assert(0 && "Unexpected MMA layout version found"); } else if (auto mfmaLayout = layout.dyn_cast()) { - return {32 * mfmaLayout.getWarpsPerCTA()[0], - 32 * mfmaLayout.getWarpsPerCTA()[1]}; + auto nonKDim = mfmaLayout.getNonKDim(); + return {nonKDim * mfmaLayout.getWarpsPerCTA()[0], + nonKDim * mfmaLayout.getWarpsPerCTA()[1]}; } else if (auto dotLayout = layout.dyn_cast()) { auto parentLayout = dotLayout.getParent(); assert(parentLayout && "DotOperandEncodingAttr must have a parent"); if (auto parentMmaLayout = parentLayout.dyn_cast()) { assert(parentMmaLayout.isAmpere() && "mmaLayout version = 1 is not implemented yet"); - auto parentShapePerCTA = getShapePerCTA(parentLayout, tensorShape); + auto parentShapePerCTATile = + getShapePerCTATile(parentLayout, tensorShape); auto opIdx = dotLayout.getOpIdx(); if (opIdx == 0) { - return {parentShapePerCTA[0], 16}; + return {parentShapePerCTATile[0], 16}; } else if (opIdx == 1) { - return {16, parentShapePerCTA[1]}; + return {16, parentShapePerCTATile[1]}; } else { assert(0 && "DotOperandEncodingAttr opIdx must be 0 or 1"); } } else if (auto parentMfmaLayout = parentLayout.dyn_cast()) { - auto parentShapePerCTA = getShapePerCTA(parentLayout, tensorShape); + auto parentShapePerCTA = getShapePerCTATile(parentLayout, tensorShape); auto opIdx = dotLayout.getOpIdx(); if (opIdx == 0) { @@ -375,11 +420,32 @@ SmallVector getShapePerCTA(Attribute layout, "supported yet"); } } else { - assert(0 && "Unimplemented usage of getShapePerCTA"); + assert(0 && "Unimplemented usage of getShapePerCTATile"); } return shape; } +namespace { + +/* Utility function used by getOrder and getCTAOrder of SliceEncodingAttr. + * Erase dim and decrease all values larger than dim by 1. + * Example: order = [0, 2, 4, 3, 1], dim = 2 + * resOrder = [0, 3, 2, 1] + */ +SmallVector eraseOrder(ArrayRef order, unsigned dim) { + unsigned rank = order.size(); + assert(dim < rank && "Invalid dim to erase"); + SmallVector resOrder; + for (unsigned i : order) + if (i < dim) + resOrder.push_back(i); + else if (i > dim) + resOrder.push_back(i - 1); + return resOrder; +} + +} // namespace + SmallVector getOrder(Attribute layout) { if (auto blockedLayout = layout.dyn_cast()) { return SmallVector(blockedLayout.getOrder().begin(), @@ -408,8 +474,199 @@ SmallVector getOrder(Attribute layout) { sharedLayout.getOrder().end()); } else { assert(0 && "Unimplemented usage of getOrder"); - return {}; } + return {}; +}; + +CTALayoutAttr getCTALayout(Attribute layout) { + if (auto blockedLayout = layout.dyn_cast()) + return blockedLayout.getCTALayout(); + else if (auto sliceLayout = layout.dyn_cast()) + return CTALayoutAttr::get(layout.getContext(), getCTAsPerCGA(sliceLayout), + getCTASplitNum(sliceLayout), + getCTAOrder(sliceLayout)); + else if (auto mmaLayout = layout.dyn_cast()) + return mmaLayout.getCTALayout(); + else if (auto dotLayout = layout.dyn_cast()) + return CTALayoutAttr::get(layout.getContext(), getCTAsPerCGA(dotLayout), + getCTASplitNum(dotLayout), + getCTAOrder(dotLayout)); + else if (auto sharedLayout = layout.dyn_cast()) + return sharedLayout.getCTALayout(); + else + assert(0 && "Unimplemented usage of getCTALayout"); + return {}; +} + +SmallVector getCTAsPerCGA(Attribute layout) { + ArrayRef ref; + if (auto blockedLayout = layout.dyn_cast()) + ref = blockedLayout.getCTALayout().getCTAsPerCGA(); + else if (auto sliceLayout = layout.dyn_cast()) { + auto parentCTAsPerCGA = getCTAsPerCGA(sliceLayout.getParent()); + if (parentCTAsPerCGA[sliceLayout.getDim()] == 1) { + parentCTAsPerCGA.erase(parentCTAsPerCGA.begin() + sliceLayout.getDim()); + return parentCTAsPerCGA; + } + /* For getCTAsPerCGA of a slice layout, we have two choices: + * (1) Return CTAsPerCGA of its parent. This is not a perfect solution + * because the rank of the returned CTAsPerCGA does not match the rank of + * tensorShape. + * (2) Get CTAsPerCGA of its parent and erase the sliced dim. This is not a + * perfect solution because the product of the returned CTAsPerCGA might not + * match numCTAs. + * To avoid introducing inconsistencies to the shape and + * layout system, the usage of directly getting CTAsPerCGA of a slice layout + * in which the sliced dim is not 1 is banned. You should always consider + * slice layout as a special case and use getCTAsPerCGA(layout.getParent()) + * in the branch where layout is an instance of SliceEncodingAttr. This is + * inconvenient but safe. + */ + assert(0 && "getCTAsPerCGA for SliceEncodingAttr is not well-defined"); + } else if (auto mmaLayout = layout.dyn_cast()) + ref = mmaLayout.getCTALayout().getCTAsPerCGA(); +#ifdef USE_ROCM + else if (auto mfmaLayout = layout.dyn_cast()) + ref = mfmaLayout.getCTALayout().getCTAsPerCGA(); +#endif + else if (auto dotLayout = layout.dyn_cast()) + return getCTAsPerCGA(dotLayout.getParent()); + else if (auto sharedLayout = layout.dyn_cast()) + ref = sharedLayout.getCTALayout().getCTAsPerCGA(); + else + assert(0 && "Unimplemented usage of getCTAsPerCGA"); + return SmallVector(ref.begin(), ref.end()); +} + +SmallVector getCTASplitNum(Attribute layout) { + SmallVector res; + + if (auto blockedLayout = layout.dyn_cast()) { + res.assign(blockedLayout.getCTALayout().getCTASplitNum().begin(), + blockedLayout.getCTALayout().getCTASplitNum().end()); + } else if (auto sliceLayout = layout.dyn_cast()) { + res = getCTASplitNum(sliceLayout.getParent()); + res.erase(res.begin() + sliceLayout.getDim()); + } else if (auto mmaLayout = layout.dyn_cast()) { + res.assign(mmaLayout.getCTALayout().getCTASplitNum().begin(), + mmaLayout.getCTALayout().getCTASplitNum().end()); +#ifdef USE_ROCM + } else if (auto mfmaLayout = layout.dyn_cast()) { + res.assign(mfmaLayout.getCTALayout().getCTASplitNum().begin(), + mfmaLayout.getCTALayout().getCTASplitNum().end()); +#endif + } else if (auto dotLayout = layout.dyn_cast()) { + res = getCTASplitNum(dotLayout.getParent()); + assert(res.size() == 2 && "Invalid dotLayout"); + + // Do not split CTA in K dimension + dotLayout.getOpIdx() == 0 ? res[1] = 1 : res[0] = 1; + } else if (auto sharedLayout = layout.dyn_cast()) { + res.assign(sharedLayout.getCTALayout().getCTASplitNum().begin(), + sharedLayout.getCTALayout().getCTASplitNum().end()); + } else { + assert(false && "Unimplemented usage of getCTASplitNum"); + } + + return res; +} + +SmallVector getCTAOrder(Attribute layout) { + ArrayRef ref; + if (auto blockedLayout = layout.dyn_cast()) { + ref = blockedLayout.getCTALayout().getCTAOrder(); + } else if (auto sliceLayout = layout.dyn_cast()) { + auto parentCTAOrder = getCTAOrder(sliceLayout.getParent()); + return eraseOrder(parentCTAOrder, sliceLayout.getDim()); + } else if (auto mmaLayout = layout.dyn_cast()) { + ref = mmaLayout.getCTALayout().getCTAOrder(); +#ifdef USE_ROCM + } else if (auto mfmaLayout = layout.dyn_cast()) { + ref = mfmaLayout.getCTALayout().getCTAOrder(); +#endif + } else if (auto dotLayout = layout.dyn_cast()) { + return getCTAOrder(dotLayout.getParent()); + } else if (auto sharedLayout = layout.dyn_cast()) { + ref = sharedLayout.getCTALayout().getCTAOrder(); + } else { + assert(0 && "Unimplemented usage of getCTAOrder"); + } + return SmallVector(ref.begin(), ref.end()); +} + +SmallVector getShapePerCTA(ArrayRef CTASplitNum, + ArrayRef shape) { + unsigned rank = shape.size(); + SmallVector shapePerCTA(rank); + for (unsigned i = 0; i < rank; ++i) { + // This wrapping rule must be consistent with emitCTAOffsetForLayout + unsigned splitNum = std::min(shape[i], CTASplitNum[i]); + shapePerCTA[i] = shape[i] / splitNum; + } + return shapePerCTA; +} + +SmallVector getShapePerCTA(Attribute layout, ArrayRef shape) { + if (auto sharedLayout = layout.dyn_cast()) { + // Special logic for pipeline pass, where shape is 3D and CTALayout is 2D. + // The first dim of shape is numStages. This is a work around, otherwise too + // many places would have to be modified in pipeline pass. Maybe we need to + // refactor this logic in the future. + auto CTASplitNum = sharedLayout.getCTALayout().getCTASplitNum(); + if (shape.size() == CTASplitNum.size() + 1) { + auto res = getShapePerCTA(CTASplitNum, shape.drop_front()); + res.insert(res.begin(), shape.front()); + return res; + } + } + return getShapePerCTA(getCTASplitNum(layout), shape); +} + +SmallVector getShapePerCTA(Type type) { + auto tensorType = type.cast(); + return getShapePerCTA(tensorType.getEncoding(), tensorType.getShape()); +} + +unsigned getNumWarpsPerCTA(Attribute layout) { + ArrayRef warpsPerCTA; + if (auto blockedLayout = layout.dyn_cast()) + warpsPerCTA = blockedLayout.getWarpsPerCTA(); + else if (auto sliceLayout = layout.dyn_cast()) + return getNumWarpsPerCTA(sliceLayout.getParent()); + else if (auto mmaLayout = layout.dyn_cast()) + warpsPerCTA = mmaLayout.getWarpsPerCTA(); +#ifdef USE_ROCM + else if (auto mfmaLayout = layout.dyn_cast()) + warpsPerCTA = mfmaLayout.getWarpsPerCTA(); +#endif + else if (auto dotLayout = layout.dyn_cast()) + return getNumWarpsPerCTA(dotLayout.getParent()); + else if (auto sharedLayout = layout.dyn_cast()) + assert(0 && "Cannot get numWarps from SharedEncodingAttr"); + else + assert(0 && "Unimplemented usage of getNumWarpsPerCTA"); + return product(warpsPerCTA); +} + +unsigned getNumCTAs(Attribute layout) { + ArrayRef CTAsPerCGA; + if (auto blockedLayout = layout.dyn_cast()) + CTAsPerCGA = blockedLayout.getCTALayout().getCTAsPerCGA(); + else if (auto sliceLayout = layout.dyn_cast()) + return getNumCTAs(sliceLayout.getParent()); + else if (auto mmaLayout = layout.dyn_cast()) + CTAsPerCGA = mmaLayout.getCTALayout().getCTAsPerCGA(); +#ifdef USE_ROCM + else if (auto mfmaLayout = layout.dyn_cast()) + CTAsPerCGA = mfmaLayout.getCTALayout().getCTAsPerCGA(); +#endif + else if (auto dotLayout = layout.dyn_cast()) + return getNumCTAs(dotLayout.getParent()); + else if (auto sharedLayout = layout.dyn_cast()) + CTAsPerCGA = sharedLayout.getCTALayout().getCTAsPerCGA(); + else + assert(0 && "Unimplemented usage of getNumCTAs"); + return product(CTAsPerCGA); } bool isaDistributedLayout(Attribute layout) { @@ -426,7 +683,7 @@ bool isSharedEncoding(Value value) { return false; } -bool isExpensiveCat(CatOp cat, Attribute &targetEncoding) { +bool isExpensiveCat(CatOp cat, Attribute targetEncoding) { // If the new elements per thread is less than the old one, we will need to do // convert encoding that goes through shared memory anyway. So we consider it // as expensive. @@ -479,7 +736,7 @@ static LogicalResult parseBoolAttrValue(AsmParser &parser, Attribute attr, bool &value, StringRef desc) { auto boolAttr = attr.dyn_cast(); if (!boolAttr) { - parser.emitError(parser.getNameLoc(), "expected bool type in ") << desc; + parser.emitError(parser.getNameLoc(), "expected an bool type in ") << desc; return failure(); } value = boolAttr.getValue(); @@ -489,7 +746,7 @@ static LogicalResult parseBoolAttrValue(AsmParser &parser, Attribute attr, // parse an array of integers static LogicalResult parseIntArrayAttr(AsmParser &parser, const NamedAttribute &attr, - SmallVector &res, + SmallVector &res, StringRef desc) { auto arrayAttr = attr.getValue().dyn_cast(); if (!arrayAttr) { @@ -531,12 +788,13 @@ BlockedEncodingAttr::getElemsPerThread(ArrayRef shape, auto sizePerThread = getSizePerThread(); auto warpsPerCTA = getWarpsPerCTA(); auto threadsPerWarp = getThreadsPerWarp(); + auto shapePerCTA = getShapePerCTA(*this, shape); assert(rank == sizePerThread.size() && "unexpected rank in BlockedEncodingAttr::getElemsPerThread"); SmallVector elemsPerThread(rank); for (size_t i = 0; i < rank; ++i) { unsigned t = sizePerThread[i] * threadsPerWarp[i] * warpsPerCTA[i]; - elemsPerThread[i] = ceil(shape[i], t) * sizePerThread[i]; + elemsPerThread[i] = ceil(shapePerCTA[i], t) * sizePerThread[i]; } return elemsPerThread; } @@ -585,14 +843,20 @@ MfmaEncodingAttr::getElemsPerThread(ArrayRef shape, Type eltTy) const { assert(rank == 2 && "Unexpected rank of mfma layout"); SmallVector elemsPerThread(rank); + auto nonKDim = getNonKDim(); + auto elemsPerThreadPerTile = (nonKDim == 16 ? 4 : 16); if (getIsTransposed()) { - unsigned elemsCol = ceil(shape[1], 32 * getWarpsPerCTA()[1]) * 16; - unsigned elemsRow = ceil(shape[0], 32 * getWarpsPerCTA()[0]); + unsigned elemsCol = + ceil(shape[1], nonKDim * getWarpsPerCTA()[1]) * + elemsPerThreadPerTile; + unsigned elemsRow = ceil(shape[0], nonKDim * getWarpsPerCTA()[0]); elemsPerThread[0] = elemsRow; elemsPerThread[1] = elemsCol; } else { - unsigned elemsCol = ceil(shape[1], 32 * getWarpsPerCTA()[1]); - unsigned elemsRow = ceil(shape[0], 32 * getWarpsPerCTA()[0]) * 16; + unsigned elemsCol = ceil(shape[1], nonKDim * getWarpsPerCTA()[1]); + unsigned elemsRow = + ceil(shape[0], nonKDim * getWarpsPerCTA()[0]) * + elemsPerThreadPerTile; elemsPerThread[0] = elemsRow; elemsPerThread[1] = elemsCol; } @@ -603,7 +867,10 @@ SmallVector MmaEncodingAttr::getElemsPerThread(ArrayRef shape, Type eltTy) const { size_t rank = shape.size(); assert(rank == 2 && "Unexpected rank of mma layout"); - assert((isVolta() || isAmpere()) && "Only version 1 and 2 is supported"); + assert((isVolta() || isAmpere() || isHopper()) && + "For MmaEncodingAttr only version 1~3 is supported"); + + auto shapePerCTA = getShapePerCTA(getCTALayout().getCTASplitNum(), shape); SmallVector elemsPerThread(rank); if (isVolta()) { @@ -617,15 +884,24 @@ MmaEncodingAttr::getElemsPerThread(ArrayRef shape, Type eltTy) const { unsigned spwN = fpw[1] * 4 * repN; unsigned wptM = getWarpsPerCTA()[0]; unsigned wptN = getWarpsPerCTA()[1]; - unsigned resM = repM * std::max(1, shape[0] / (spwM * wptM)); - unsigned resN = 2 * repN * std::max(1, shape[1] / (spwN * wptN)); + unsigned resM = repM * std::max(1, shapePerCTA[0] / (spwM * wptM)); + unsigned resN = 2 * repN * std::max(1, shapePerCTA[1] / (spwN * wptN)); elemsPerThread[0] = resM; elemsPerThread[1] = resN; } else if (isAmpere()) { - unsigned elemsRow = ceil(shape[0], 16 * getWarpsPerCTA()[0]) * 2; - unsigned elemsCol = ceil(shape[1], 8 * getWarpsPerCTA()[1]) * 2; + unsigned elemsRow = + ceil(shapePerCTA[0], 16 * getWarpsPerCTA()[0]) * 2; + unsigned elemsCol = + ceil(shapePerCTA[1], 8 * getWarpsPerCTA()[1]) * 2; elemsPerThread[0] = elemsRow; elemsPerThread[1] = elemsCol; + } else if (isHopper()) { + auto wpt = getWarpsPerCTA(); + auto instrMNK = getInstrShape(); + int repM = ceil(shapePerCTA[0], instrMNK[0] * wpt[0]); + int repN = ceil(shapePerCTA[1], instrMNK[1] * wpt[1]); + elemsPerThread[0] = 2 * repM; + elemsPerThread[1] = (instrMNK[1] / 4) * repN; } else { llvm_unreachable("Unexpected mma version"); } @@ -638,6 +914,37 @@ unsigned MfmaEncodingAttr::getTotalElemsPerThread(ArrayRef shape, return product(getElemsPerThread(shape, eltTy)); } +unsigned +MmaEncodingAttr::getElemsPerThreadOfOperand(int opIdx, + ArrayRef shape) const { + size_t rank = shape.size(); + assert(rank == 2 && "Unexpected rank of mma layout"); + auto shapePerCTA = getShapePerCTA(*this, shape); + int res = 0; + if (isVolta()) { + llvm_unreachable( + "getElemsPerThreadOfOperand() not supported for version 1"); + } else if (isAmpere()) { + llvm_unreachable( + "getElemsPerThreadOfOperand() not supported for version 2"); + } else if (isHopper()) { + auto wpt = getWarpsPerCTA(); + auto instrMNK = getInstrShape(); + if (opIdx == 0) { + int repM = ceil(shapePerCTA[0], instrMNK[0] * wpt[0]); + int repK = ceil(shapePerCTA[1], instrMNK[2]); + return 8 * repM * repK; + + } else if (opIdx == 1) { + int repK = ceil(shapePerCTA[0], instrMNK[2]); + int repN = ceil(shapePerCTA[1], instrMNK[1] * wpt[1]); + // benzh@ here need more check + return 4 * std::max(instrMNK[1] / 32, 1) * repK * repN; + } + } + return res; +} + unsigned MmaEncodingAttr::getTotalElemsPerThread(ArrayRef shape, Type eltTy) const { return product(getElemsPerThread(shape, eltTy)); @@ -677,7 +984,9 @@ SmallVector DotOperandEncodingAttr::getMFMAElemsPerInstr() const { auto mfmaEncoding = getParent().cast(); int64_t nonKDim = mfmaEncoding.getNonKDim(); - int64_t kDim = getKWidth(); + assert(nonKDim == 32 || nonKDim == 16); + int64_t kWidth = getKWidth(); + int64_t kDim = kWidth * (nonKDim == 32 ? 2 : 4); if (getOpIdx() == 0) return {nonKDim, kDim}; else @@ -718,12 +1027,13 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef shape, auto rep = getMFMARep(shape, eltTy); return rep[0] * rep[1]; } + auto shapePerCTA = getShapePerCTA(*this, shape); if (auto mmaParent = getParent().dyn_cast()) { int warpsPerCTAM = mmaParent.getWarpsPerCTA()[0]; int warpsPerCTAN = mmaParent.getWarpsPerCTA()[1]; // A100 if (mmaParent.isAmpere()) { - auto rep = getMMAv2Rep(shape, eltTy.getIntOrFloatBitWidth()); + auto rep = getMMAv2Rep(shapePerCTA, eltTy.getIntOrFloatBitWidth()); if (getOpIdx() == 0) return 4 * rep[0] * rep[1]; if (getOpIdx() == 1) @@ -791,12 +1101,12 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef shape, } } if (auto blockedLayout = getParent().dyn_cast()) { - auto shapePerCTA = getShapePerCTA(blockedLayout); + auto shapePerCTATile = getShapePerCTATile(blockedLayout); auto order = blockedLayout.getOrder(); auto sizePerThread = getSizePerThread(blockedLayout); - int K = getOpIdx() == 0 ? shape[1] : shape[0]; - int otherDim = getOpIdx() == 1 ? shape[1] : shape[0]; + int K = getOpIdx() == 0 ? shapePerCTA[1] : shapePerCTA[0]; + int otherDim = getOpIdx() == 1 ? shapePerCTA[1] : shapePerCTA[0]; bool isM = getOpIdx() == 0; @@ -806,13 +1116,13 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef shape, order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]]; int sizePerThreadMN = isM ? mSizePerThread : nSizePerThread; - int mShapePerCTA = - order[0] == 1 ? shapePerCTA[order[1]] : shapePerCTA[order[0]]; - int nShapePerCTA = - order[0] == 0 ? shapePerCTA[order[1]] : shapePerCTA[order[0]]; - int shapePerCTAMN = isM ? mShapePerCTA : nShapePerCTA; + int mShapePerCTATile = + order[0] == 1 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; + int nShapePerCTATile = + order[0] == 0 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; + int shapePerCTAMNTile = isM ? mShapePerCTATile : nShapePerCTATile; - return K * std::max(otherDim / shapePerCTAMN, 1) * sizePerThreadMN; + return K * std::max(otherDim / shapePerCTAMNTile, 1) * sizePerThreadMN; } llvm_unreachable("unknown dot operand parent layout"); return 0; @@ -832,10 +1142,13 @@ Attribute BlockedEncodingAttr::parse(AsmParser &parser, Type type) { if (parser.parseGreater().failed()) return {}; - SmallVector sizePerThread; - SmallVector threadsPerWarp; - SmallVector warpsPerCTA; - SmallVector order; + SmallVector sizePerThread; + SmallVector threadsPerWarp; + SmallVector warpsPerCTA; + SmallVector order; + SmallVector CTAsPerCGA; + SmallVector CTASplitNum; + SmallVector CTAOrder; for (const NamedAttribute &attr : dict) { if (attr.getName() == "sizePerThread") { @@ -856,6 +1169,15 @@ Attribute BlockedEncodingAttr::parse(AsmParser &parser, Type type) { } else if (attr.getName() == "order") { if (parseIntArrayAttr(parser, attr, order, "order").failed()) return {}; + } else if (attr.getName() == "CTAsPerCGA") { + if (parseIntArrayAttr(parser, attr, CTAsPerCGA, "CTAsPerCGA").failed()) + return {}; + } else if (attr.getName() == "CTASplitNum") { + if (parseIntArrayAttr(parser, attr, CTASplitNum, "CTASplitNum").failed()) + return {}; + } else if (attr.getName() == "CTAOrder") { + if (parseIntArrayAttr(parser, attr, CTAOrder, "CTAOrder").failed()) + return {}; } else { parser.emitError(parser.getNameLoc(), "unexpected key: ") << attr.getName().strref(); @@ -863,9 +1185,12 @@ Attribute BlockedEncodingAttr::parse(AsmParser &parser, Type type) { } } - auto ret = parser.getChecked( - parser.getContext(), sizePerThread, threadsPerWarp, warpsPerCTA, order); - return ret; + auto CTALayout = CTALayoutAttr::get(parser.getContext(), CTAsPerCGA, + CTASplitNum, CTAOrder); + + return parser.getChecked(parser.getContext(), + sizePerThread, threadsPerWarp, + warpsPerCTA, order, CTALayout); } void BlockedEncodingAttr::print(mlir::AsmPrinter &printer) const { @@ -874,6 +1199,9 @@ void BlockedEncodingAttr::print(mlir::AsmPrinter &printer) const { << ", threadsPerWarp = [" << getThreadsPerWarp() << "]" << ", warpsPerCTA = [" << getWarpsPerCTA() << "]" << ", order = [" << getOrder() << "]" + << ", CTAsPerCGA = [" << getCTALayout().getCTAsPerCGA() << "]" + << ", CTASplitNum = [" << getCTALayout().getCTASplitNum() << "]" + << ", CTAOrder = [" << getCTALayout().getCTAOrder() << "]" << "}>"; } @@ -892,7 +1220,11 @@ Attribute MmaEncodingAttr::parse(AsmParser &parser, Type type) { unsigned versionMajor = 0; unsigned versionMinor = 0; - SmallVector warpsPerCTA; + SmallVector warpsPerCTA; + SmallVector CTAsPerCGA; + SmallVector CTASplitNum; + SmallVector CTAOrder; + SmallVector instrShape; for (const NamedAttribute &attr : dict) { if (attr.getName() == "versionMajor") { @@ -907,17 +1239,42 @@ Attribute MmaEncodingAttr::parse(AsmParser &parser, Type type) { if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed()) return {}; } + if (attr.getName() == "CTAsPerCGA") { + if (parseIntArrayAttr(parser, attr, CTAsPerCGA, "CTAsPerCGA").failed()) + return {}; + } + if (attr.getName() == "CTASplitNum") { + if (parseIntArrayAttr(parser, attr, CTASplitNum, "CTASplitNum").failed()) + return {}; + } + if (attr.getName() == "CTAOrder") { + if (parseIntArrayAttr(parser, attr, CTAOrder, "CTAOrder").failed()) + return {}; + } + if (attr.getName() == "instrShape") { + if (parseIntArrayAttr(parser, attr, instrShape, "instrShape").failed()) { + return {}; + } + } } + auto CTALayout = CTALayoutAttr::get(parser.getContext(), CTAsPerCGA, + CTASplitNum, CTAOrder); + return parser.getChecked(parser.getContext(), versionMajor, - versionMinor, warpsPerCTA); + versionMinor, warpsPerCTA, + CTALayout, instrShape); } void MmaEncodingAttr::print(AsmPrinter &printer) const { printer << "<{" << "versionMajor = " << getVersionMajor() << ", " << "versionMinor = " << getVersionMinor() << ", " - << "warpsPerCTA = [" << getWarpsPerCTA() << "]" + << "warpsPerCTA = [" << getWarpsPerCTA() << "], " + << "CTAsPerCGA = [" << getCTALayout().getCTAsPerCGA() << "], " + << "CTASplitNum = [" << getCTALayout().getCTASplitNum() << "], " + << "CTAOrder = [" << getCTALayout().getCTAOrder() << "], " + << "instrShape = [" << getInstrShape() << "]" << "}>"; } @@ -935,8 +1292,11 @@ Attribute MfmaEncodingAttr::parse(AsmParser &parser, Type type) { return {}; unsigned nonKDim = 0; - SmallVector warpsPerCTA; + SmallVector warpsPerCTA; bool isTransposed; + SmallVector CTAsPerCGA; + SmallVector CTASplitNum; + SmallVector CTAOrder; for (const NamedAttribute &attr : dict) { if (attr.getName() == "nonKDim") { @@ -950,17 +1310,35 @@ Attribute MfmaEncodingAttr::parse(AsmParser &parser, Type type) { if (parseBool(parser, attr, isTransposed, "isTransposed").failed()) return {}; } + if (attr.getName() == "CTAsPerCGA") { + if (parseIntArrayAttr(parser, attr, CTAsPerCGA, "CTAsPerCGA").failed()) + return {}; + } + if (attr.getName() == "CTASplitNum") { + if (parseIntArrayAttr(parser, attr, CTASplitNum, "CTASplitNum").failed()) + return {}; + } + if (attr.getName() == "CTAOrder") { + if (parseIntArrayAttr(parser, attr, CTAOrder, "CTAOrder").failed()) + return {}; + } } - return parser.getChecked(parser.getContext(), nonKDim, - warpsPerCTA, isTransposed); + auto CTALayout = CTALayoutAttr::get(parser.getContext(), CTAsPerCGA, + CTASplitNum, CTAOrder); + + return parser.getChecked( + parser.getContext(), nonKDim, warpsPerCTA, isTransposed, CTALayout); } void MfmaEncodingAttr::print(AsmPrinter &printer) const { printer << "<{" << "nonKDim = " << getNonKDim() << ", " - << "warpsPerCTA = [" << getWarpsPerCTA() << "]" - << ", isTransposed = " << getIsTransposed() << "}>"; + << "warpsPerCTA = [" << getWarpsPerCTA() << "], " + << "isTransposed = " << getIsTransposed() << ", " + << "CTAsPerCGA = [" << getCTALayout().getCTAsPerCGA() << "], " + << "CTASplitNum = [" << getCTALayout().getCTASplitNum() << "], " + << "CTAOrder = [" << getCTALayout().getCTAOrder() << "]}>"; } //===----------------------------------------------------------------------===// @@ -1003,7 +1381,11 @@ Attribute SharedEncodingAttr::parse(AsmParser &parser, Type type) { unsigned vec = 0; unsigned perPhase = 0; unsigned maxPhase = 0; - SmallVector order; + SmallVector order; + SmallVector CTAsPerCGA; + SmallVector CTASplitNum; + SmallVector CTAOrder; + bool hasLeadingOffset = false; for (const NamedAttribute &attr : dict) { if (attr.getName() == "vec") { @@ -1018,6 +1400,19 @@ Attribute SharedEncodingAttr::parse(AsmParser &parser, Type type) { } else if (attr.getName() == "order") { if (parseIntArrayAttr(parser, attr, order, "order").failed()) return {}; + } else if (attr.getName() == "CTAsPerCGA") { + if (parseIntArrayAttr(parser, attr, CTAsPerCGA, "CTAsPerCGA").failed()) + return {}; + } else if (attr.getName() == "CTASplitNum") { + if (parseIntArrayAttr(parser, attr, CTASplitNum, "CTASplitNum").failed()) + return {}; + } else if (attr.getName() == "CTAOrder") { + if (parseIntArrayAttr(parser, attr, CTAOrder, "CTAOrder").failed()) + return {}; + } else if (attr.getName() == "hasLeadingOffset") { + if (parseBool(parser, attr, hasLeadingOffset, "hasLeadingOffset") + .failed()) + return {}; } else { parser.emitError(parser.getNameLoc(), "unexpected key: ") << attr.getName().strref(); @@ -1025,16 +1420,24 @@ Attribute SharedEncodingAttr::parse(AsmParser &parser, Type type) { } } + auto CTALayout = CTALayoutAttr::get(parser.getContext(), CTAsPerCGA, + CTASplitNum, CTAOrder); + return parser.getChecked(parser.getContext(), vec, - perPhase, maxPhase, order); + perPhase, maxPhase, order, + CTALayout, hasLeadingOffset); } void SharedEncodingAttr::print(AsmPrinter &printer) const { printer << "<{" - << "vec = " << getVec() << ", perPhase = " << getPerPhase() - << ", maxPhase = " << getMaxPhase() << ", order = [" << getOrder() - << "]" - << "}>"; + << "vec = " << getVec() << ", " + << "perPhase = " << getPerPhase() << ", " + << "maxPhase = " << getMaxPhase() << ", " + << "order = [" << getOrder() << "], " + << "CTAsPerCGA = [" << getCTALayout().getCTAsPerCGA() << "], " + << "CTASplitNum = [" << getCTALayout().getCTASplitNum() << "], " + << "CTAOrder = [" << getCTALayout().getCTAOrder() << "], " + << "hasLeadingOffset = " << getHasLeadingOffset() << "}>"; } //===----------------------------------------------------------------------===// @@ -1043,8 +1446,14 @@ void SharedEncodingAttr::print(AsmPrinter &printer) const { bool MmaEncodingAttr::isVolta() const { return getVersionMajor() == 1; } +bool MmaEncodingAttr::isTuring() const { + return getVersionMajor() == 2 && getVersionMinor() == 1; +} + bool MmaEncodingAttr::isAmpere() const { return getVersionMajor() == 2; } +bool MmaEncodingAttr::isHopper() const { return getVersionMajor() == 3; } + // Get [isARow, isBRow, isAVec4, isBVec4, id] from versionMinor std::tuple MmaEncodingAttr::decodeVoltaLayoutStates() const { @@ -1095,7 +1504,7 @@ void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const { auto mmaParent = getParent().dyn_cast(); printer << "<{" << "opIdx = " << getOpIdx() << ", parent = " << getParent(); - if (mmaParent && mmaParent.isAmpere()) + if ((mmaParent && mmaParent.isAmpere()) || getParent().isa()) printer << ", kWidth = " << getKWidth(); printer << "}>"; } @@ -1153,11 +1562,11 @@ int DotOperandEncodingAttr::getMMAv1NumOuter(ArrayRef shape) const { } //===----------------------------------------------------------------------===// -// InsertSliceAsyncOp +// InsertSliceOp / InsertSliceAsyncOp //===----------------------------------------------------------------------===// -ParseResult InsertSliceAsyncOp::parse(OpAsmParser &parser, - OperationState &result) { +template +ParseResult parseInsertSliceOp(OpAsmParser &parser, OperationState &result) { SmallVector allOperands; Type srcType, dstType; SMLoc allOperandLoc = parser.getCurrentLocation(); @@ -1176,7 +1585,8 @@ ParseResult InsertSliceAsyncOp::parse(OpAsmParser &parser, int hasMask = 0, hasOther = 0; if (allOperands.size() >= 4) { - operandTypes.push_back(triton::getI1SameShape(srcType)); // mask + operandTypes.push_back( + triton::getI1SameShapeFromTensorOrTensorPtr(srcType)); // mask hasMask = 1; } if (allOperands.size() >= 5) { @@ -1189,24 +1599,43 @@ ParseResult InsertSliceAsyncOp::parse(OpAsmParser &parser, return failure(); // Deduce operand_segment_sizes from the number of the operands. - auto operand_segment_sizesAttrName = - InsertSliceAsyncOp::getOperandSegmentSizesAttrName(result.name); + auto operandSegmentSizesAttrName = + OpT::getOperandSegmentSizesAttrName(result.name); result.addAttribute( - operand_segment_sizesAttrName, + operandSegmentSizesAttrName, parser.getBuilder().getDenseI32ArrayAttr({1, 1, 1, hasMask, hasOther})); return success(); } -void InsertSliceAsyncOp::print(OpAsmPrinter &printer) { +template +void printInsertSliceOp(OpAsmPrinter &printer, OpT insertSliceOp) { printer << " "; - printer << getOperation()->getOperands(); + printer << insertSliceOp.getOperation()->getOperands(); // "operand_segment_sizes" can be deduced, so we don't print it. - printer.printOptionalAttrDict(getOperation()->getAttrs(), - {getOperandSegmentSizesAttrName()}); + printer.printOptionalAttrDict( + insertSliceOp->getAttrs(), + {insertSliceOp.getOperandSegmentSizesAttrName()}); printer << " : "; - printer.printStrippedAttrOrType(getSrc().getType()); + printer.printStrippedAttrOrType(insertSliceOp.getSrc().getType()); printer << " -> "; - printer.printStrippedAttrOrType(getResult().getType()); + printer.printStrippedAttrOrType(insertSliceOp.getDst().getType()); +} + +ParseResult InsertSliceOp::parse(OpAsmParser &parser, OperationState &result) { + return parseInsertSliceOp(parser, result); +} + +void InsertSliceOp::print(OpAsmPrinter &printer) { + printInsertSliceOp(printer, *this); +} + +ParseResult InsertSliceAsyncOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseInsertSliceOp(parser, result); +} + +void InsertSliceAsyncOp::print(OpAsmPrinter &printer) { + printInsertSliceOp(printer, *this); } //===----------------------------------------------------------------------===// @@ -1221,6 +1650,9 @@ class TritonGPUOpAsmInterface : public OpAsmDialectInterface { if (auto mmaAttr = attr.dyn_cast()) { os << "mma"; return AliasResult::FinalAlias; + } else if (attr.isa()) { + os << "mfma"; + return AliasResult::FinalAlias; } else if (auto sharedAttr = attr.dyn_cast()) { os << "shared"; return AliasResult::FinalAlias; @@ -1256,9 +1688,12 @@ struct TritonGPUInferLayoutInterface SmallVector retOrder(sharedEncoding.getOrder().begin(), sharedEncoding.getOrder().end()); std::reverse(retOrder.begin(), retOrder.end()); + // TODO(Qingyi): Need to check whether CTAOrder should also be reversed. + // This is not a problem for tests where numCTAs = 1. resultEncoding = SharedEncodingAttr::get( getDialect()->getContext(), sharedEncoding.getVec(), - sharedEncoding.getPerPhase(), sharedEncoding.getMaxPhase(), retOrder); + sharedEncoding.getPerPhase(), sharedEncoding.getMaxPhase(), retOrder, + sharedEncoding.getCTALayout(), sharedEncoding.getHasLeadingOffset()); return mlir::success(); } @@ -1281,7 +1716,14 @@ struct TritonGPUInferLayoutInterface inferDotOpEncoding(Attribute operandEncoding, unsigned opIdx, Attribute retEncoding, std::optional location) const override { - if (auto dotOpEnc = operandEncoding.dyn_cast()) { + auto mmaRetEncoding = retEncoding.dyn_cast(); + if (mmaRetEncoding && mmaRetEncoding.isHopper()) { + // TODO: support gmma when A/B does not reside in shared memory + if (!operandEncoding.isa()) + return emitOptionalError( + location, "unexpected operand layout for MmaEncodingAttr v3"); + } else if (auto dotOpEnc = + operandEncoding.dyn_cast()) { if (opIdx != dotOpEnc.getOpIdx()) return emitOptionalError(location, "Wrong opIdx"); if (retEncoding != dotOpEnc.getParent()) @@ -1324,6 +1766,20 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op, (srcType.getEncoding().isa() || srcType.getEncoding().isa())) return mlir::failure(); + // for hopper MMAv3 + if (!op.use_empty()) { + bool hasDotUser = false; + for (Operation *dot : op.getResult().getUsers()) + if (isa(dot)) + hasDotUser = true; + + if (hasDotUser) { + if (dstType.getEncoding().isa() && + srcType.getEncoding().isa()) + return mlir::failure(); + } + } + // convert to the same layout -- we can delete if (op->getResultTypes() == op->getOperandTypes()) { rewriter.replaceOp(op, op->getOperands()); @@ -1404,10 +1860,10 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op, auto newArg = rewriter.create( op->getLoc(), newType, extract_slice.getSource()); rewriter.replaceOpWithNewOp( - op, resType, newArg.getResult(), extract_slice.offsets(), - extract_slice.sizes(), extract_slice.strides(), - extract_slice.static_offsets(), extract_slice.static_sizes(), - extract_slice.static_strides()); + op, resType, newArg.getResult(), extract_slice.getOffsets(), + extract_slice.getSizes(), extract_slice.getStrides(), + extract_slice.getStaticOffsets(), extract_slice.getStaticSizes(), + extract_slice.getStaticStrides()); return mlir::success(); } @@ -1481,6 +1937,8 @@ void ExtractSliceOp::build(OpBuilder &b, OperationState &result, //===----------------------------------------------------------------------===// void TritonGPUDialect::initialize() { + registerTypes(); + addAttributes< #define GET_ATTRDEF_LIST #include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc" @@ -1488,6 +1946,7 @@ void TritonGPUDialect::initialize() { addOperations< #define GET_OP_LIST #include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc" +#include "triton/Dialect/TritonGPU/IR/OpsEnums.cpp.inc" >(); addInterfaces(); addInterfaces(); diff --git a/lib/Dialect/TritonGPU/IR/Types.cpp b/lib/Dialect/TritonGPU/IR/Types.cpp new file mode 100644 index 000000000000..77f673cc2766 --- /dev/null +++ b/lib/Dialect/TritonGPU/IR/Types.cpp @@ -0,0 +1,38 @@ +#include "triton/Dialect/TritonGPU/IR/Types.h" +#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc` +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc` + +using namespace mlir; +using namespace mlir::triton::gpu; + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/TritonGPU/IR/Types.cpp.inc" + +Type TokenType::parse(AsmParser &parser) { + if (parser.parseLess()) + return Type(); + + int type = 1; + if (parser.parseInteger(type)) + return Type(); + + if (parser.parseGreater()) + return Type(); + + return TokenType::get(parser.getContext(), type); +} + +void TokenType::print(AsmPrinter &printer) const { + printer << "<" << getType() << ">"; +} + +//===----------------------------------------------------------------------===// +// Triton Dialect +//===----------------------------------------------------------------------===// +void ::mlir::triton::gpu::TritonGPUDialect::registerTypes() { + addTypes< +#define GET_TYPEDEF_LIST +#include "triton/Dialect/TritonGPU/IR/Types.cpp.inc" + >(); +} diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index 77aab6018d63..dcab1d44c3b6 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -4,60 +4,62 @@ #include "triton/Analysis/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/Support/Debug.h" #include using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; namespace { -using triton::DotOp; -using triton::gpu::BlockedEncodingAttr; -using triton::gpu::ConvertLayoutOp; -using triton::gpu::DotOperandEncodingAttr; -using triton::gpu::MmaEncodingAttr; -using triton::gpu::SliceEncodingAttr; - -int computeCapabilityToMMAVersion(int computeCapability) { - if (computeCapability < 70) { - return 0; - } else if (computeCapability < 80) { - return 1; +using tt::DotOp; +using ttg::BlockedEncodingAttr; +using ttg::ConvertLayoutOp; +using ttg::DotOperandEncodingAttr; +using ttg::MmaEncodingAttr; +using ttg::SliceEncodingAttr; + +// higher mma version is prefered, will fallback to lower version if not +// supported +static int getMMAVersionSafe(int computeCapability, tt::DotOp op) { + int baseVersion = 0; + if (computeCapability < 75) { + baseVersion = 1; } else if (computeCapability < 90) { - return 2; + baseVersion = 2; } else if (computeCapability < 100) { - // FIXME: temporarily add this to pass unis tests - return 2; + baseVersion = 3; } else { - assert(false && "computeCapability > 100 not supported"); - return 3; + assert(false && "computeCapability not supported"); } -} -SmallVector mmaVersionToShapePerWarp(int version) { - if (version == 1) - return {16, 16}; - else if (version == 2) - return {16, 8}; - else { - assert(false && "version not supported"); - return {0, 0}; + for (; baseVersion >= 1; baseVersion--) { + if (supportMMA(op, baseVersion)) { + return baseVersion; + } } + + return 0; } -SmallVector warpsPerTileV2(triton::DotOp dotOp, - const ArrayRef shape, - int numWarps) { +SmallVector +warpsPerTileV2(tt::DotOp dotOp, const ArrayRef shape, int numWarps) { auto filter = [&dotOp](Operation *op) { return op->getParentRegion() == dotOp->getParentRegion(); }; auto slices = mlir::getSlice(dotOp, {filter}); for (Operation *op : slices) - if (isa(op) && (op != dotOp)) + if (isa(op) && (op != dotOp)) return {(unsigned)numWarps, 1}; SmallVector ret = {1, 1}; SmallVector shapePerWarp = {16, 8}; - bool changed = false; + // TODO (@daadaada): double-check. + // original logic in + // https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252 + // seems buggy for shape = [32, 16] ? do { - changed = false; if (ret[0] * ret[1] >= numWarps) break; if (shape[0] / shapePerWarp[0] / ret[0] >= @@ -74,7 +76,7 @@ SmallVector warpsPerTileV2(triton::DotOp dotOp, } #ifdef USE_ROCM -SmallVector warpsPerTileMI200(triton::DotOp dotOp, +SmallVector warpsPerTileMI200(tt::DotOp dotOp, const ArrayRef shape, int numWarps) { // TODO: needs to be updated with appropriate shapePerWarp etc. @@ -83,7 +85,7 @@ SmallVector warpsPerTileMI200(triton::DotOp dotOp, }; auto slices = mlir::getSlice(dotOp, filter); for (Operation *op : slices) - if (isa(op) && (op != dotOp)) + if (isa(op) && (op != dotOp)) return {(unsigned)numWarps, 1}; SmallVector tensorShape = {shape[0], shape[1]}; @@ -101,7 +103,7 @@ SmallVector warpsPerTileMI200(triton::DotOp dotOp, ret[0] *= 2; } else ret[1] *= 2; - } else { + } else { ret[1] *= 2; } } while (true); @@ -113,41 +115,87 @@ SmallVector warpsPerTileMI200(triton::DotOp dotOp, return ret; } +SmallVector +warpsPerTileV3(tt::DotOp dotOp, const ArrayRef shape, int numWarps, + const SmallVector &instrShape) { + SetVector slices; + mlir::getForwardSlice(dotOp.getResult(), &slices); + if (llvm::find_if(slices, [](Operation *op) { return isa(op); }) != + slices.end()) + return {(unsigned)numWarps, 1}; + + // For MMAv3, the smallest indivisible unit of warp shape is (4, 1). + SmallVector ret = {4, 1}; + SmallVector shapePerWarp = {16, instrShape[1]}; + do { + if (ret[0] * ret[1] >= numWarps) + break; + if (shape[0] > shapePerWarp[0] * ret[0]) { + ret[0] *= 2; + } else { + ret[1] *= 2; + } + } while (true); + return ret; +} + class BlockedToMFMA : public mlir::RewritePattern { int mfmaVersion; public: BlockedToMFMA(mlir::MLIRContext *context, int mfmaVersion) - : mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context), mfmaVersion(mfmaVersion) {} + : mlir::RewritePattern(tt::DotOp::getOperationName(), 2, context), mfmaVersion(mfmaVersion) {} - bool isChainDot(triton::DotOp &dotOp) const { + bool isChainDot(tt::DotOp &dotOp) const { auto filter = [&dotOp](Operation *op) { return op->getParentRegion() == dotOp->getParentRegion(); }; auto slices = mlir::getSlice(dotOp, filter); for (Operation *op : slices) { - if (isa(op) && (op != dotOp)) + if (isa(op) && (op != dotOp)) return true; } return false; } - std::pair chooseMfmaDimensions(triton::DotOp dot, int mfmaVersion) const { - int64_t nonKDim = 32; + /// @brief Choose MFMA instruction parameters + /// @param dot target dot operation + /// @param mfmaVersion + /// @param nonKDim + /// @return pair {nonKDim, kDim} sizes of one MFMA instruction arguments + std::pair chooseMfmaDimensions(tt::DotOp dot, + int mfmaVersion, + int64_t nonKDim) const { + // number of matrix elements along k dim per one MFMA intruction int64_t kDim = -1; auto opType = dot.getA().getType().cast(); auto elemType = opType.getElementType(); - if (elemType.isF32()) - kDim = 2; - if (elemType.isF16()) - kDim = 8; - if (elemType.isBF16()) { - if (mfmaVersion == 1) - kDim = 4; - if (mfmaVersion == 2) + if (nonKDim == 32) { + if (elemType.isF32()) + kDim = 2; + if (elemType.isF16()) + kDim = 8; + if (elemType.isBF16()) { + if (mfmaVersion == 1) + kDim = 4; + if (mfmaVersion == 2) + kDim = 8; + } + if (elemType.isInteger(8)) kDim = 8; + } else { + if (elemType.isF32()) + kDim = 4; + if (elemType.isF16()) + kDim = 16; + if (elemType.isBF16()) { + if (mfmaVersion == 1) + kDim = 8; + if (mfmaVersion == 2) + kDim = 16; + } + if (elemType.isInteger(8)) + kDim = 16; } - if (elemType.isInteger(8)) - kDim = 8; assert(kDim != -1); return {nonKDim, kDim}; } @@ -155,20 +203,32 @@ class BlockedToMFMA : public mlir::RewritePattern { mlir::LogicalResult matchAndRewrite(mlir::Operation *op, mlir::PatternRewriter &rewriter) const override { - auto dotOp = cast(op); + auto dotOp = cast(op); auto oldRetType = dotOp.getResult().getType().cast(); if (!oldRetType.getEncoding() || - !oldRetType.getEncoding().isa()) + !oldRetType.getEncoding().isa()) return failure(); - if (!supportMFMA(dotOp)) + // TODO replace with nonKDim with some heuristic in chooseMfmaDimensions + // function + int64_t externalNonKDim = 32; + + const char *mfmaType = std::getenv("MFMA_TYPE"); + if (mfmaType) { + externalNonKDim = std::stol(mfmaType); + assert(externalNonKDim == 32 || externalNonKDim == 16); + } + + if (!supportMFMA(dotOp, externalNonKDim)) return failure(); + auto CTALayout = ttg::getCTALayout(oldRetType.getEncoding()); + // get MFMA encoding for the given number of warps auto retShape = oldRetType.getShape(); auto mod = op->getParentOfType(); - int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); + int numWarps = ttg::TritonGPUDialect::getNumWarps(mod); // operands Value a = dotOp.getA(); @@ -177,46 +237,57 @@ class BlockedToMFMA : public mlir::RewritePattern { auto oldBType = b.getType().cast(); auto ctx = oldAType.getContext(); - triton::gpu::MfmaEncodingAttr mfmaEnc; + ttg::MfmaEncodingAttr mfmaEnc; - auto [nonKDim, kDim] = chooseMfmaDimensions(dotOp, mfmaVersion); + auto [nonKDim, kDim] = + chooseMfmaDimensions(dotOp, mfmaVersion, externalNonKDim); auto warpsPerTile = warpsPerTileMI200(dotOp, retShape, numWarps); bool isTransposed = isChainDot(dotOp); - mfmaEnc = triton::gpu::MfmaEncodingAttr::get( - oldRetType.getContext(), nonKDim, warpsPerTile, isTransposed); + mfmaEnc = ttg::MfmaEncodingAttr::get(oldRetType.getContext(), nonKDim, + warpsPerTile, isTransposed, CTALayout); auto newRetType = RankedTensorType::get(retShape, oldRetType.getElementType(), mfmaEnc); // convert accumulator auto oldAcc = dotOp.getOperand(2); - auto newAcc = rewriter.create( + auto newAcc = rewriter.create( oldAcc.getLoc(), newRetType, oldAcc); auto oldAOrder = oldAType.getEncoding() - .cast() + .cast() .getParent() - .cast() + .cast() .getOrder(); auto oldBOrder = oldBType.getEncoding() - .cast() + .cast() .getParent() - .cast() + .cast() .getOrder(); + // kWidth is a number of consecutive elements per one instruction per one thread + auto kWidth = kDim; + // in mfma 32x32 case argument matrix groups elements in 2 groups + // in mfma 16x16 case argument matrix groups elements in 4 groups + if (nonKDim == 32) { + kWidth /= 2; + } else { + assert(nonKDim == 16); + kWidth /= 4; + } auto newAType = RankedTensorType::get( oldAType.getShape(), oldAType.getElementType(), - triton::gpu::DotOperandEncodingAttr::get(ctx, 0, mfmaEnc, kDim)); + ttg::DotOperandEncodingAttr::get(ctx, 0, mfmaEnc, kWidth)); auto newBType = RankedTensorType::get( oldBType.getShape(), oldBType.getElementType(), - triton::gpu::DotOperandEncodingAttr::get(ctx, 1, mfmaEnc, kDim)); - a = rewriter.create(a.getLoc(), newAType, a); - b = rewriter.create(b.getLoc(), newBType, b); - auto newDot = rewriter.create( + ttg::DotOperandEncodingAttr::get(ctx, 1, mfmaEnc, kWidth)); + a = rewriter.create(a.getLoc(), newAType, a); + b = rewriter.create(b.getLoc(), newBType, b); + auto newDot = rewriter.create( dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.getAllowTF32()); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, oldRetType, newDot.getResult()); return success(); } @@ -226,11 +297,12 @@ class BlockedToMFMA : public mlir::RewritePattern { class BlockedToMMA : public mlir::RewritePattern { int computeCapability; mutable int mmaV1Counter{}; // used to generate ID for MMAv1 encoding + mutable llvm::SmallVector> dotOpSetVector; + mutable llvm::SmallVector mmaV3InstrNs; static bool bwdFilter(Operation *op) { return op->getNumOperands() == 1 && - (isa(op) || + (isa(op) || op->getDialect()->getTypeID() == mlir::TypeID::get()); } @@ -252,31 +324,112 @@ class BlockedToMMA : public mlir::RewritePattern { public: BlockedToMMA(mlir::MLIRContext *context, int computeCapability) - : mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context), + : mlir::RewritePattern(tt::DotOp::getOperationName(), 2, context), computeCapability(computeCapability) {} + static SmallVector + getWarpsPerTile(tt::DotOp dotOp, const ArrayRef shape, int version, + int numWarps, const SmallVector &instrShape) { + switch (version) { + case 2: + return warpsPerTileV2(dotOp, shape, numWarps); + case 3: + return warpsPerTileV3(dotOp, shape, numWarps, instrShape); + default: + assert(false && "not supported version"); + return {0, 0}; + } + } + + unsigned getMmaV3InstrN(tt::DotOp dotOp, unsigned currN) const { + auto type = dotOp.getResult().getType().cast(); + if (type.getEncoding().isa()) + return currN; + for (size_t i = 0; i < dotOpSetVector.size(); ++i) { + if (dotOpSetVector[i].count(dotOp.getOperation()) > 0) + return mmaV3InstrNs[i]; + } + + SetVector slices; + mlir::getForwardSlice(dotOp.getResult(), &slices); + mlir::getBackwardSlice(dotOp.getOperation(), &slices); + unsigned N = currN; + llvm::SetVector dotOpSet; + for (Operation *iter : slices) { + if (auto nextDotOp = dyn_cast(iter)) { + auto type = nextDotOp.getResult().getType().cast(); + auto AType = nextDotOp.getOperand(0).getType().cast(); + auto shapePerCTA = ttg::getShapePerCTA(type); + auto instrShape = mmaVersionToInstrShape(3, shapePerCTA, AType); + dotOpSet.insert(iter); + if (instrShape[1] < N) + N = instrShape[1]; + } + } + mmaV3InstrNs.push_back(N); + dotOpSetVector.push_back(dotOpSet); + return N; + } + + static Value getMMAv3Operand(Value v, mlir::PatternRewriter &rewriter, + int opIdx) { + auto cvtOp = dyn_cast_or_null(v.getDefiningOp()); + auto arg = cvtOp.getSrc(); + auto argType = arg.getType().cast(); + auto eltType = argType.getElementType(); + assert(argType.getEncoding() && "unexpected tensor type"); + auto newOrder = ttg::getOrder(argType.getEncoding()); + + // MMAv3 with transpose only supports f16 and bf16 data type + // fallback to MMAv3 without transpose for other data types + if (!eltType.isF16() && !eltType.isBF16()) { + if (opIdx == 1) { + newOrder = {0, 1}; + } else { + newOrder = {1, 0}; + } + } + + auto CTALayout = ttg::getCTALayout(argType.getEncoding()); + auto newLayout = ttg::SharedEncodingAttr::get( + argType.getContext(), argType.getShape(), newOrder, CTALayout, + argType.getElementType()); + auto newType = RankedTensorType::get(argType.getShape(), + argType.getElementType(), newLayout); + + return rewriter.create(arg.getLoc(), newType, arg); + } + mlir::LogicalResult matchAndRewrite(mlir::Operation *op, mlir::PatternRewriter &rewriter) const override { if (computeCapability < 70) return failure(); - auto dotOp = cast(op); + auto dotOp = cast(op); auto ctx = op->getContext(); // TODO: Check data-types and SM compatibility auto oldRetType = dotOp.getResult().getType().cast(); if (!oldRetType.getEncoding() || - oldRetType.getEncoding().isa()) + oldRetType.getEncoding().isa()) return failure(); - // for FMA, should retain the blocked layout. - int versionMajor = computeCapabilityToMMAVersion(computeCapability); - if (!supportMMA(dotOp, versionMajor)) - return failure(); + auto AType = dotOp.getOperand(0).getType().cast(); + auto BType = dotOp.getOperand(1).getType().cast(); // get MMA encoding for the given number of warps - auto retShape = oldRetType.getShape(); + auto retShapePerCTA = ttg::getShapePerCTA(oldRetType); auto mod = op->getParentOfType(); - int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); + int numWarps = ttg::TritonGPUDialect::getNumWarps(mod); + auto CTALayout = ttg::getCTALayout(oldRetType.getEncoding()); + + int versionMajor = getMMAVersionSafe(computeCapability, dotOp); + if (!versionMajor) + return failure(); + + auto instrShape = + mmaVersionToInstrShape(versionMajor, retShapePerCTA, AType); + if (versionMajor == 3) + instrShape[1] = getMmaV3InstrN(dotOp, instrShape[1]); // operands Value a = dotOp.getA(); @@ -284,7 +437,7 @@ class BlockedToMMA : public mlir::RewritePattern { auto oldAType = a.getType().cast(); auto oldBType = b.getType().cast(); - triton::gpu::MmaEncodingAttr mmaEnc; + ttg::MmaEncodingAttr mmaEnc; if (versionMajor == 1) { SetVector aBwdSlices, bBwdSlices; auto isCvt = [](Operation *op) { return isa(op); }; @@ -313,46 +466,55 @@ class BlockedToMMA : public mlir::RewritePattern { if (bOp) isBRow = getCvtArgOrder(bOp)[0] == 1; - mmaEnc = triton::gpu::MmaEncodingAttr::get( - oldRetType.getContext(), versionMajor, numWarps, oldAType.getShape(), - oldBType.getShape(), retShape, isARow, isBRow, mmaV1Counter++); - } else if (versionMajor == 2) { - auto warpsPerTile = warpsPerTileV2(dotOp, retShape, numWarps); - mmaEnc = triton::gpu::MmaEncodingAttr::get( - oldRetType.getContext(), versionMajor, 0 /*versionMinor*/, - warpsPerTile); - } else { - llvm_unreachable("Mma layout only supports versionMajor in {1, 2}"); + mmaEnc = ttg::MmaEncodingAttr::get( + oldRetType.getContext(), versionMajor, numWarps, CTALayout, + instrShape, oldAType.getShape(), oldBType.getShape(), retShapePerCTA, + isARow, isBRow, mmaV1Counter++); + } else if (versionMajor == 2 || versionMajor == 3) { + int versionMinor = computeCapability == 75 ? 1 : 0; + auto warpsPerTile = getWarpsPerTile(dotOp, retShapePerCTA, versionMajor, + numWarps, instrShape); + mmaEnc = ttg::MmaEncodingAttr::get(oldRetType.getContext(), versionMajor, + versionMinor, warpsPerTile, CTALayout, + instrShape); } - auto newRetType = - RankedTensorType::get(retShape, oldRetType.getElementType(), mmaEnc); + auto newRetType = RankedTensorType::get( + oldRetType.getShape(), oldRetType.getElementType(), mmaEnc); // convert accumulator auto oldAcc = dotOp.getOperand(2); - auto newAcc = rewriter.create( - oldAcc.getLoc(), newRetType, oldAcc); - // convert operands - int minBitwidth = std::min(computeOrigBitWidth(a), computeOrigBitWidth(b)); - Type minType = IntegerType::get(ctx, minBitwidth); - // convert A operand - auto newAEncoding = triton::gpu::DotOperandEncodingAttr::get( - oldAType.getContext(), 0, newRetType.getEncoding(), - minBitwidth > 0 ? minType : oldAType.getElementType()); - auto newAType = RankedTensorType::get( - oldAType.getShape(), oldAType.getElementType(), newAEncoding); - a = rewriter.create(a.getLoc(), newAType, a); - // convert B operand - auto newBEncoding = triton::gpu::DotOperandEncodingAttr::get( - oldBType.getContext(), 1, newRetType.getEncoding(), - minBitwidth > 0 ? minType : oldBType.getElementType()); - auto newBType = RankedTensorType::get( - oldBType.getShape(), oldBType.getElementType(), newBEncoding); - b = rewriter.create(b.getLoc(), newBType, b); + auto newAcc = rewriter.create(oldAcc.getLoc(), + newRetType, oldAcc); + + if (versionMajor == 3) { + a = getMMAv3Operand(a, rewriter, 0); + b = getMMAv3Operand(b, rewriter, 1); + } else { + + // convert operands + int minBitwidth = + std::min(computeOrigBitWidth(a), computeOrigBitWidth(b)); + Type minType = IntegerType::get(ctx, minBitwidth); + // convert A operand + auto newAEncoding = ttg::DotOperandEncodingAttr::get( + oldAType.getContext(), 0, newRetType.getEncoding(), + minBitwidth > 0 ? minType : oldAType.getElementType()); + auto newAType = RankedTensorType::get( + oldAType.getShape(), oldAType.getElementType(), newAEncoding); + a = rewriter.create(a.getLoc(), newAType, a); + // convert B operand + auto newBEncoding = ttg::DotOperandEncodingAttr::get( + oldBType.getContext(), 1, newRetType.getEncoding(), + minBitwidth > 0 ? minType : oldBType.getElementType()); + auto newBType = RankedTensorType::get( + oldBType.getShape(), oldBType.getElementType(), newBEncoding); + b = rewriter.create(b.getLoc(), newBType, b); + } // convert dot instruction - auto newDot = rewriter.create( - dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.getAllowTF32()); + auto newDot = rewriter.create(dotOp.getLoc(), newRetType, a, b, + newAcc, dotOp.getAllowTF32()); - rewriter.replaceOpWithNewOp( - op, oldRetType, newDot.getResult()); + rewriter.replaceOpWithNewOp(op, oldRetType, + newDot.getResult()); return success(); } }; diff --git a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt index d7187e2216ab..8a2342d3aca6 100644 --- a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt +++ b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_dialect_library(TritonGPUTransforms Coalesce.cpp DecomposeConversions.cpp OptimizeDotOperands.cpp + OptimizeEpilogue.cpp Pipeline.cpp Prefetch.cpp RemoveLayoutConversions.cpp @@ -20,4 +21,6 @@ add_mlir_dialect_library(TritonGPUTransforms TritonAnalysis TritonIR TritonGPUIR + TritonNvidiaGPUIR + MLIRTransformUtils ) diff --git a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp index 0a2cc1b466af..5ebc88083fca 100644 --- a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp @@ -3,6 +3,8 @@ #include "triton/Analysis/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "llvm/Support/Debug.h" +#include #include using namespace mlir; @@ -24,20 +26,64 @@ typedef DenseMap> LayoutMap; struct CoalescePass : public TritonGPUCoalesceBase { Attribute getCoalescedEncoding(ModuleAxisInfoAnalysis &axisInfoAnalysis, Value ptr, int numWarps, int threadsPerWarp) { - auto origType = ptr.getType().cast(); - // Get the shape of the tensor. - size_t rank = origType.getRank(); + auto refType = ptr.getType(); + if (refType.isa()) + refType = refType.cast().getPointeeType(); + auto refTensorType = refType.cast(); + + // TODO(Keren): integrate it into AxisInfoAnalysis + // Get axis info + auto queryAxisInfo = [&](const Value &val) -> AxisInfo { + auto valType = val.getType(); + // Tensor pointer + // TODO(Chenggang): encoding for tensor pointers is meaningless, remove + // these later while merging into the GitHub main + if (auto ptrType = valType.dyn_cast()) { + auto tensorTy = ptrType.getPointeeType().dyn_cast(); + assert(tensorTy); + auto makeTensorPtr = getMakeTensorPtrOp(val); + auto order = makeTensorPtr.getOrder(); + auto tileShape = triton::gpu::getShapePerCTA(tensorTy); + size_t rank = order.size(); + auto elemSizeInBytes = + tensorTy.getElementType().getIntOrFloatBitWidth() / 8; + SmallVector contiguity(rank, 1); + SmallVector divisibility(rank, 1); + SmallVector constancy(rank, 1); + // The contiguity in `order[0]` is `tileShape[order[0]]` + // The divisibility in `order[0]` is 16 + // TODO[goostavz]: confirm the legality of it + contiguity[order[0]] = tileShape[order[0]]; + divisibility[order[0]] = 16 * 8 / elemSizeInBytes; + return AxisInfo(contiguity, divisibility, constancy); + } + // Normal cases + assert(valType.isa()); + return *axisInfoAnalysis.getAxisInfo(val); + }; + // Get the contiguity order of `ptr` - auto order = argSort(axisInfoAnalysis.getAxisInfo(ptr)->getContiguity()); + SmallVector order; + if (auto ptrType = ptr.getType().dyn_cast()) { + // Tensor pointer + auto makeTensorPtr = getMakeTensorPtrOp(ptr); + std::copy(makeTensorPtr.getOrder().begin(), + makeTensorPtr.getOrder().end(), std::back_inserter(order)); + } else { + // Normal cases + order = argSort(queryAxisInfo(ptr).getContiguity()); + } + // The desired divisibility is the maximum divisibility // among all dependent pointers who have the same order as - // `ptr` + // `ptr`. + // We only do it for normal tensors of pointers, not tensor pointers. SetVector withSameOrder; withSameOrder.insert(ptr); - if (ptr.getDefiningOp()) + if (refType.isa() && ptr.getDefiningOp()) { for (Operation *op : mlir::multiRootGetSlice(ptr.getDefiningOp())) { for (Value val : op->getResults()) { - if (val.getType() != origType) + if (val.getType() != refTensorType) continue; auto currOrder = argSort(axisInfoAnalysis.getAxisInfo(val)->getContiguity()); @@ -45,32 +91,43 @@ struct CoalescePass : public TritonGPUCoalesceBase { withSameOrder.insert(val); } } - int numElems = product(origType.getShape()); + } + + auto shapePerCTA = triton::gpu::getShapePerCTA(refTensorType); + int numElems = product(shapePerCTA); int numThreads = numWarps * threadsPerWarp; int numElemsPerThread = std::max(numElems / numThreads, 1); + + // For tensor of pointers, the element to access is the pointee type; + // while for tensor pointer type (`refType` is directly the final shape), + // the element to access is itself. + auto typeForMem = refTensorType.getElementType().isa() + ? refTensorType.getElementType() + .cast() + .getPointeeType() + : refTensorType.getElementType(); + // Thread tile size depends on memory alignment - SmallVector sizePerThread(rank, 1); - unsigned elemNumBits = triton::getPointeeBitWidth(origType); + SmallVector sizePerThread(refTensorType.getRank(), 1); + unsigned elemNumBits = typeForMem.getIntOrFloatBitWidth(); unsigned elemNumBytes = std::max(elemNumBits / 8, 1u); unsigned perThread = 1; for (Value val : withSameOrder) { - unsigned maxMultipleBytes = - axisInfoAnalysis.getAxisInfo(val)->getDivisibility(order[0]); + auto valInfo = queryAxisInfo(val); + unsigned maxMultipleBytes = valInfo.getDivisibility(order[0]); unsigned maxMultiple = std::max(maxMultipleBytes / elemNumBytes, 1u); unsigned maxContig = - axisInfoAnalysis.getAxisInfo(val)->getContiguity(order[0]); + std::min(valInfo.getContiguity(order[0]), shapePerCTA[order[0]]); unsigned alignment = std::min(maxMultiple, maxContig); unsigned currPerThread = std::min(alignment, 128 / elemNumBits); perThread = std::max(perThread, currPerThread); } sizePerThread[order[0]] = std::min(perThread, numElemsPerThread); - SmallVector dims(rank); - std::iota(dims.begin(), dims.end(), 0); - // create encoding - Attribute encoding = triton::gpu::BlockedEncodingAttr::get( - &getContext(), origType.getShape(), sizePerThread, order, numWarps, - threadsPerWarp); - return encoding; + + auto CTALayout = triton::gpu::getCTALayout(refTensorType.getEncoding()); + return triton::gpu::BlockedEncodingAttr::get( + &getContext(), refTensorType.getShape(), sizePerThread, order, numWarps, + threadsPerWarp, CTALayout); } std::function @@ -78,40 +135,47 @@ struct CoalescePass : public TritonGPUCoalesceBase { int numWarps, int threadsPerWarp) { Attribute encoding = getCoalescedEncoding(axisInfoAnalysis, ptr, numWarps, threadsPerWarp); - return [encoding](Type _type) { - RankedTensorType type = _type.cast(); - return RankedTensorType::get(type.getShape(), type.getElementType(), - encoding); + return [encoding](Type type) { + RankedTensorType tensorType = type.cast(); + return RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); }; } template void coalesceOp(LayoutMap &layoutMap, Operation *op, Value ptr, OpBuilder builder) { - RankedTensorType ty = ptr.getType().template dyn_cast(); - if (!ty) + if (!layoutMap.count(ptr)) return; + + // Convert operands + // For load/store with tensor pointers, we don't have to change the + // operands' type, we do this by changing the outputs' type of + // `make_tensor_ptr` auto convertType = layoutMap.lookup(ptr); - // convert operands SmallVector newArgs; - for (auto v : op->getOperands()) { - auto vTy = v.getType().dyn_cast(); - if (vTy && !vTy.getEncoding().isa()) + for (auto operand : op->getOperands()) { + auto tensorType = operand.getType().dyn_cast(); + if (tensorType && + !tensorType.getEncoding().isa()) newArgs.push_back(builder.create( - op->getLoc(), convertType(v.getType()), v)); + op->getLoc(), convertType(tensorType), operand)); else - newArgs.push_back(v); + newArgs.push_back(operand); } - // convert output types + + // Convert output types SmallVector newTypes; for (auto t : op->getResultTypes()) { - bool is_async = std::is_same::value; - newTypes.push_back(is_async ? t : convertType(t)); + bool isAsync = std::is_same::value; + newTypes.push_back(isAsync ? t : convertType(t)); } - // construct new op with the new encoding + + // Construct new op with the new encoding Operation *newOp = builder.create(op->getLoc(), newTypes, newArgs, op->getAttrs()); - // cast the results back to the original layout + + // Cast the results back to the original layout for (size_t i = 0; i < op->getNumResults(); i++) { Value newResult = newOp->getResult(i); if (newTypes[i] != op->getResultTypes()[i]) { @@ -123,6 +187,25 @@ struct CoalescePass : public TritonGPUCoalesceBase { op->erase(); } + void coalesceMakeTensorPtrOpResult(LayoutMap &layoutMap, Operation *op, + Value ptr, OpBuilder builder) { + if (!layoutMap.count(ptr)) + return; + + // Convert result type + auto convertType = layoutMap.lookup(ptr); + auto ptrType = ptr.getType().cast(); + auto resultTensorType = convertType(ptrType.getPointeeType()); + auto newResultType = + PointerType::get(resultTensorType, ptrType.getAddressSpace()); + + // Build new operation and replace + Operation *newOp = builder.create( + op->getLoc(), newResultType, op->getOperands(), op->getAttrs()); + op->getResult(0).replaceAllUsesWith(newOp->getResult(0)); + op->erase(); + } + void runOnOperation() override { // Run axis info analysis ModuleOp moduleOp = getOperation(); @@ -145,8 +228,13 @@ struct CoalescePass : public TritonGPUCoalesceBase { ptr = op.getPtr(); if (!ptr) return; - RankedTensorType ty = ptr.getType().template dyn_cast(); - if (!ty || !ty.getElementType().isa()) + // We only convert `tensor>` or `tt.ptr>` load/store + bool isPtrTensor = false, isTensorPointer = false; + if (auto tensorType = ptr.getType().dyn_cast()) + isPtrTensor = tensorType.getElementType().isa(); + if (auto ptrType = ptr.getType().dyn_cast()) + isTensorPointer = ptrType.getPointeeType().isa(); + if (!isPtrTensor && !isTensorPointer) return; auto mod = curr->getParentOfType(); int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); diff --git a/lib/Dialect/TritonGPU/Transforms/DecomposeConversions.cpp b/lib/Dialect/TritonGPU/Transforms/DecomposeConversions.cpp index 2d293635072e..a26db5f7e1b6 100644 --- a/lib/Dialect/TritonGPU/Transforms/DecomposeConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/DecomposeConversions.cpp @@ -62,7 +62,9 @@ class TritonGPUDecomposeConversionsPass dstType.getShape(), dstType.getElementType(), triton::gpu::SharedEncodingAttr::get( mod.getContext(), dstDotOp, srcType.getShape(), - triton::gpu::getOrder(srcEncoding), srcType.getElementType())); + triton::gpu::getOrder(srcEncoding), + triton::gpu::getCTALayout(srcEncoding), + srcType.getElementType())); auto tmp = builder.create( cvtOp.getLoc(), tmpType, cvtOp.getOperand()); auto newConvert = builder.create( diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 34bb732e8d5f..43f411bcae7d 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -56,10 +56,13 @@ class ConvertTransConvert : public mlir::RewritePattern { if (!ZEncoding) return mlir::failure(); // new X encoding + // TODO(Qingyi): need to check whether the CTALayout of XEncoding should be + // used here. For tests where numCTAs = 1, this is not a problem since all + // CTALayouts are the same. auto newXOrder = triton::gpu::getOrder(argEncoding); auto newXEncoding = triton::gpu::SharedEncodingAttr::get( getContext(), ZEncoding, XType.getShape(), newXOrder, - XType.getElementType()); + XEncoding.getCTALayout(), XType.getElementType()); auto newXType = RankedTensorType::get(XType.getShape(), XType.getElementType(), newXEncoding); if (XEncoding == newXEncoding) @@ -118,6 +121,9 @@ class MoveOpAfterLayoutConversion : public mlir::RewritePattern { cvtArgOp->getDialect()->getTypeID() != mlir::TypeID::get()) return mlir::failure(); + // not handled in elementwise lowering. + if (isa(cvtArgOp)) + return mlir::failure(); // only considers conversions to dot operand if (!cvtTy.getEncoding().isa()) return mlir::failure(); @@ -143,6 +149,83 @@ class MoveOpAfterLayoutConversion : public mlir::RewritePattern { } }; +// convert(trans(convert(arg))) +// x = convert_layout arg: #distributed -> #shared_x +// y = trans x: #shared_x -> #shared_y +// z = convert_layout y: #shared_y -> #shared_z +class FuseTransHopper : public mlir::RewritePattern { + +public: + FuseTransHopper(mlir::MLIRContext *context) + : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), + 1, context) {} + + LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + if (!op->hasOneUse()) + return mlir::failure(); + auto dstOp = cast(op); + auto tmpOp = + dyn_cast_or_null(dstOp.getSrc().getDefiningOp()); + if (!tmpOp) + return mlir::failure(); + auto srcOp = dyn_cast_or_null( + tmpOp.getSrc().getDefiningOp()); + if (!srcOp) + return mlir::failure(); + auto arg = srcOp.getSrc(); + auto X = tmpOp.getSrc(); + // types + auto argType = arg.getType().cast(); + auto XType = X.getType().cast(); + auto ZType = dstOp.getResult().getType().cast(); + // encodings + auto argEncoding = argType.getEncoding(); + auto XEncoding = + XType.getEncoding().cast(); + auto ZEncoding = + ZType.getEncoding().dyn_cast(); + if (!ZEncoding) + return mlir::failure(); + // new X encoding + auto newXOrder = triton::gpu::getOrder(argEncoding); + + auto dotOp = *op->getUsers().begin(); + if (isa(dotOp)) { + auto dotTy = dotOp->getResult(0).getType().cast(); + auto dotEncoding = + dotTy.getEncoding().dyn_cast(); + auto eltType = XType.getElementType(); + if (!dotEncoding || dotEncoding.getVersionMajor() != 3) + return mlir::failure(); + // MMAv3 with transpose only supports f16 and bf16 data type + // fallback to MMAv3 without transpose for other data types + if (!eltType.isF16() && !eltType.isBF16()) { + if (dstOp.getResult() == dotOp->getOperand(0)) { + newXOrder = {0, 1}; + } else if (dstOp.getResult() == dotOp->getOperand(1)) { + newXOrder = {1, 0}; + } + } + } + + // TODO(Qingyi): need to check whether the CTALayout of XEncoding should be + // used here. For tests where numCTAs = 1, this is not a problem since all + // CTALayouts are the same. + auto newXEncoding = triton::gpu::SharedEncodingAttr::get( + getContext(), XType.getShape(), newXOrder, XEncoding.getCTALayout(), + XType.getElementType()); + auto newXType = RankedTensorType::get(XType.getShape(), + XType.getElementType(), newXEncoding); + + auto newX = rewriter.create(srcOp.getLoc(), + newXType, arg); + rewriter.replaceOpWithNewOp(dstOp, newX); + return mlir::success(); + } +}; + } // namespace #define GEN_PASS_CLASSES @@ -165,10 +248,9 @@ class TritonGPUOptimizeDotOperandsPass mlir::RewritePatternSet patterns(context); patterns.add(context); patterns.add(context); + patterns.add(context); if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) signalPassFailure(); - if (fixupLoops(m).failed()) - signalPassFailure(); } }; diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeEpilogue.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeEpilogue.cpp new file mode 100644 index 000000000000..d3c030d3dba0 --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeEpilogue.cpp @@ -0,0 +1,135 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +using namespace mlir; + +namespace { + +// convert(val) : mma -> blocked +// tt.store(ptr, val, mask, ...) : blocked +// ==> +// convert(ptr) : blocked -> mma +// convert(mask) : blocked -> mma +// tt.store(ptr, val, mask, ...) : mma +// +// Store with mma layout directly +class BypassEpilogueSMEM : public mlir::RewritePattern { + +public: + explicit BypassEpilogueSMEM(mlir::MLIRContext *context) + : mlir::RewritePattern(triton::StoreOp::getOperationName(), 1, context) {} + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + + auto stOp = dyn_cast(op); + if (!stOp) + return mlir::failure(); + Value ptr = stOp.getPtr(); + Value val = stOp.getValue(); + Value mask = stOp.getMask(); + auto ptrType = ptr.getType().dyn_cast(); + auto valType = val.getType().dyn_cast(); + if (!ptrType || !valType || + !ptrType.getEncoding().isa() || + !valType.getEncoding().isa()) + return mlir::failure(); + + auto cvtOp = dyn_cast(val.getDefiningOp()); + if (!cvtOp) + return mlir::failure(); + + if (!cvtOp.getSrc() + .getType() + .cast() + .getEncoding() + .isa()) + return mlir::failure(); + + if (!cvtOp.getResult().hasOneUse()) + return mlir::failure(); + + auto newEncoding = + cvtOp.getOperand().getType().cast().getEncoding(); + + auto newVal = cvtOp.getOperand(); + + auto newPtrType = RankedTensorType::get( + ptrType.getShape(), ptrType.getElementType(), newEncoding); + Value newPtr = rewriter.create( + ptr.getLoc(), newPtrType, ptr); + + Value newMask = mask; + if (mask) { + auto maskType = mask.getType().dyn_cast(); + auto newMaskType = RankedTensorType::get( + maskType.getShape(), maskType.getElementType(), newEncoding); + newMask = rewriter.create( + mask.getLoc(), newMaskType, mask); + } + + rewriter.replaceOpWithNewOp( + stOp, newPtr, newVal, newMask, stOp.getCache(), stOp.getEvict()); + return mlir::success(); + } +}; + +} // namespace + +#define GEN_PASS_CLASSES +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +class TritonGPUOptimizeEpiloguePass + : public TritonGPUOptimizeEpilogueBase { + +public: + TritonGPUOptimizeEpiloguePass() = default; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + + mlir::RewritePatternSet patterns(context); + + patterns.add(context); + + if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) { + signalPassFailure(); + } + } +}; + +std::unique_ptr mlir::createTritonGPUOptimizeEpiloguePass() { + return std::make_unique(); +} diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp index b337b8a7afd4..db5513d92cfb 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @@ -8,7 +8,10 @@ #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Tools/Sys/GetEnv.hpp" #include "llvm/ADT/MapVector.h" +#include "llvm/Support/Debug.h" //===----------------------------------------------------------------------===// // This file implements software pipelining for loops. The implementation here @@ -78,7 +81,10 @@ using llvm::MapVector; using namespace mlir; -namespace ttg = triton::gpu; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +/// FIXME(Keren): The pipeline pass shouldn't be aware of nvidia_gpu dialect +namespace ttng = mlir::triton::nvidia_gpu; #define GEN_PASS_CLASSES #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" @@ -104,6 +110,19 @@ void addNamedAttrs(Operation *op, DictionaryAttr dictAttrs) { } } +struct ConsumerReleaseInfo { + Value iterVar; + Value stageVar; + Value phaseVar; + Value nextIVVar; + Value stepVar; + Value upperBoundVar; + ttg::CTALayoutAttr CTALayout; + DenseMap consumerStageMap; +}; +typedef DenseMap + ConsumerReleaseMap; + class LoopPipeliner { /// Cache of ForOp and YieldOp related to this pipeliner. scf::ForOp forOp; @@ -122,10 +141,35 @@ class LoopPipeliner { /// load => after extract DenseMap loadsExtract; + /// XXX(Keren): The following are h100 only and disabled + /// load => full barrier arrive + DenseMap loadsBarrierArvOp; + /// load => mbarriers + DenseMap loadsFullBarriers; + DenseMap loadsEmptyBarriers; + /// load => null value or previous load which can share barrier with + DenseMap loadsCanShareBarriers; + /// Maintains the information to emit consumer_release mbarrier_arrive + ConsumerReleaseMap &consumerReleaseMap; + bool hasHopperDot = false; + // XXX(Keren): why the variable name is hopper dot and why do we need this + // check? + void checkHopperDots(SetVector &ops); + // XXX(Keren): it looks more like an optimization to be, not sure if it should + // exist in the base pipeliner + void checkOpShareBarriers(SetVector &ops); + int numLoadsRequireAsyncWait = 0; + int numLoadsRequireMBarrier = 0; + /// Iterator values + Value nextIV; Value pipelineIterIdx; + Value curWaitIdx; + + // Only needed when numLoadsRequireMBarrier > 0 Value loopIterIdx; - Value nextIV; + Value curPhase; + Value curEmptyPhase; /// Yield values SmallVector nextBuffers; @@ -138,9 +182,16 @@ class LoopPipeliner { int numStages; /// Arg indicies - size_t bufferIdx, loadIdx, depArgsBeginIdx, ivIndex; + size_t bufferIdx, loadIdx, depArgsBeginIdx, ivIdx; DenseMap depArgsIdx; + /// XXX(Keren): The mode parameter is hacky, should be refactored + // false: legacy mode as a temporary solution for backward compatibility + // true: new mode for hopper + bool mode; + int numWarps; + int numCTAs; + /// value (in loop) => value at stage N DenseMap> valueMapping; /// loop iter arg => value @@ -204,12 +255,11 @@ class LoopPipeliner { /// Get the load mask for `loadOp`, given the mapped mask `mappedMask` (if /// exists) and the current iteration's `loopCond`. - Value getLoadMask(triton::LoadOp loadOp, Value mappedMask, Value loopCond, + Value getLoadMask(tt::LoadOp loadOp, Value mappedMask, Value loopCond, OpBuilder &builder); /// Return an empty buffer of size - ttg::AllocTensorOp allocateEmptyBuffer(triton::LoadOp loadOp, - OpBuilder &builder); + ttg::AllocTensorOp allocateEmptyBuffer(tt::LoadOp loadOp, OpBuilder &builder); /// Collect all args of the new loop SmallVector collectNewLoopArgs(); @@ -220,15 +270,25 @@ class LoopPipeliner { /// Prefetch the next iteration for `newForOp` void prefetchNextIteration(scf::ForOp newForOp, OpBuilder &builder); + /// Check if curIdx is out of bound and wrap value around if necessary + Value getBoundedIterationValue(OpBuilder &builder, Value curIdx, + Value upperBoundIdx, Value curValue, + Value initValue); + /// Assemble `newForOp`'s yield op void finalizeYield(scf::ForOp newForOp, OpBuilder &builder); public: - LoopPipeliner(scf::ForOp forOp, int numStages) - : forOp(forOp), numStages(numStages) { + LoopPipeliner(scf::ForOp forOp, int numStages, int numWarps, int numCTAs, + bool mode, ConsumerReleaseMap &consumerReleaseMap) + : forOp(forOp), numStages(numStages), numWarps(numWarps), + numCTAs(numCTAs), mode(mode), consumerReleaseMap(consumerReleaseMap) { + // cache yieldOp yieldOp = cast(forOp.getBody()->getTerminator()); } + LoopPipeliner() = delete; + /// Collect loads to pipeline. Return success if we can pipeline this loop LogicalResult initialize(); @@ -252,27 +312,30 @@ LogicalResult LoopPipeliner::collectOps(SetVector &ops) { // We cannot use forOp.walk(...) here because we only want to visit the // operations in the loop body block. Nested blocks are handled separately. for (Operation &op : forOp) - if (auto loadOp = dyn_cast(&op)) { - auto ptr = loadOp.getPtr(); - unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr); - - if (auto mask = loadOp.getMask()) - vec = std::min(vec, axisInfoAnalysis.getMaskAlignment(mask)); - - auto tensorTy = ptr.getType().dyn_cast(); - if (!tensorTy || tensorTy.getRank() < 2) - continue; - auto ty = tensorTy.getElementType() - .cast() - .getPointeeType(); - unsigned width = vec * ty.getIntOrFloatBitWidth(); - // We do not pipeline all loads for the following reasons: - // 1. On nvidia GPUs, cp.async's cp-size can only be 4, 8 and 16. - // 2. It's likely that pipling small loads won't offer much performance - // improvement and may even hurt performance by increasing register - // pressure. - if (width >= 32) + if (auto loadOp = dyn_cast(&op)) { + if (isLoadFromTensorPtr(loadOp)) { ops.insert(loadOp); + } else { + auto ptr = loadOp.getPtr(); + unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr); + if (auto mask = loadOp.getMask()) + vec = + std::min(vec, axisInfoAnalysis.getMaskAlignment(mask)); + + auto tensorTy = ptr.getType().dyn_cast(); + if (!tensorTy || tensorTy.getRank() < 2) + continue; + auto ty = + tensorTy.getElementType().cast().getPointeeType(); + unsigned width = vec * ty.getIntOrFloatBitWidth(); + // We do not pipeline all loads for the following reasons: + // 1. On nvidia GPUs, cp.async's cp-size can only be 4, 8 and 16. + // 2. It's likely that pipling small loads won't offer much performance + // improvement and may even hurt performance by increasing register + // pressure. + if (width >= 32) + ops.insert(loadOp); + } } if (ops.empty()) @@ -324,23 +387,25 @@ LogicalResult LoopPipeliner::checkOpUses(SetVector &ops) { collectDeps(ops, opDeps); for (Operation *op : ops) { - if (auto loadOp = dyn_cast(op)) { + if (auto loadOp = dyn_cast(op)) { // Don't pipeline valid loads that depend on other valid loads // (Because if a valid load depends on another valid load, this load needs // to wait on the other load in the prologue, which is against the point // of the pipeline pass) bool isCandidate = true; for (Operation *other : ops) - if (isa(other)) + if (isa(other)) if (opDeps[op].contains(other->getResult(0))) { isCandidate = false; break; } // We only pipeline loads that have one covert_layout (to dot_op) use // TODO: lift this constraint in the future - if (isCandidate && loadOp.getResult().hasOneUse()) { + if (isCandidate && loadOp.getResult().hasOneUse() && + !isLoadFromTensorPtr(loadOp)) { isCandidate = false; Operation *use = *loadOp.getResult().getUsers().begin(); + Operation *preUse = nullptr; // Advance to the first conversion as long as the use resides in shared // memory and it has a single use itself @@ -351,10 +416,11 @@ LogicalResult LoopPipeliner::checkOpUses(SetVector &ops) { use->getResult(0).getType().dyn_cast(); if (!tensorType.getEncoding().isa()) break; + preUse = use; use = *use->getResult(0).getUsers().begin(); } - if (auto convertLayout = llvm::dyn_cast(use)) + if (auto convertLayout = llvm::dyn_cast(use)) { if (auto tensorType = convertLayout.getResult() .getType() .dyn_cast()) @@ -363,13 +429,43 @@ LogicalResult LoopPipeliner::checkOpUses(SetVector &ops) { isCandidate = true; loadsMapping[loadOp] = convertLayout; } + } else if (preUse && isa(use)) { + isCandidate = false; + // for MMAv3 whose dot take SharedEncoding as operands directly + Operation *post = *loadOp.getResult().getUsers().begin(); + auto newOrder = post->getResult(0) + .getType() + .cast() + .getEncoding() + .cast() + .getOrder(); + auto ty = loadOp.getType().cast(); + auto oldOrder = ttg::getOrder(ty.getEncoding()); + // The operand of MMAv3 is in SharedEncoding and it's order should not + // be changed after FuseTranspositions Pass. So we only pipeline the + // load if the order of the loaded BlockedEncoding is the same as the + // order of the SharedEncoding it is converted to. + // TODO: remove this constraint once the LoadOp supports transpose + // fusion + if (newOrder[0] == oldOrder[0] || newOrder[1] == oldOrder[1]) { + isCandidate = true; + loadsMapping[loadOp] = preUse->getResult(0); + } + } + } else if (isCandidate && mode && isLoadFromTensorPtr(loadOp)) { + loadsMapping[loadOp] = loadOp.getResult(); } else isCandidate = false; if (!isCandidate) invalidOps.insert(loadOp); - else + else { validLoads.insert(loadOp); + if (!isLoadFromTensorPtr(loadOp)) + numLoadsRequireAsyncWait++; + else + numLoadsRequireMBarrier++; + } } } @@ -382,6 +478,67 @@ LogicalResult LoopPipeliner::checkOpUses(SetVector &ops) { return success(); } +void LoopPipeliner::checkHopperDots(SetVector &ops) { + // dots to be pipelined + SetVector dots; + for (Operation &op : forOp) { + if (auto dotOp = dyn_cast(&op)) { + auto resTy = dotOp.getResult().getType().dyn_cast(); + if (auto resEnc = resTy.getEncoding().dyn_cast()) { + if (resEnc && resEnc.isHopper()) { + // Don't pipeline valid dots that depend on ops other than scf.yield + // and scf.for + auto dot = dotOp.getResult(); + bool valid = true; + + // all users of dot should be scf.yield + if (!dot.hasOneUse()) + valid = false; + if (!isa(*dot.getUsers().begin())) + valid = false; + + // C should be a block argument + auto CArg = dotOp.getOperand(2).dyn_cast(); + if (!CArg || !CArg.hasOneUse()) + valid = false; + + if (valid) + dots.insert(dotOp); + } + } + } + } + + hasHopperDot = true; +} + +void LoopPipeliner::checkOpShareBarriers(SetVector &ops) { + // Check if loads can share barriers + auto canShare = [&](Value load0, Value load1) -> bool { + if (!load0.hasOneUse() || !load1.hasOneUse()) + return false; + auto use0 = *load0.getUsers().begin(); + auto use1 = *load1.getUsers().begin(); + if (!use0->hasOneUse() || !use1->hasOneUse()) + return false; + if (*use0->getUsers().begin() != *use1->getUsers().begin()) + return false; + return true; + }; + // XXX(Keren): the logic here is pretty weird and might be incomplete + for (Value loadOp : validLoads) { + Value depLoad; + for (auto oldPair : loadsCanShareBarriers) { + Value oldLoad = oldPair.first; + if (canShare(loadOp, oldLoad)) { + depLoad = oldLoad; + break; + } + } + loadsCanShareBarriers[loadOp] = depLoad; + } +} + void LoopPipeliner::checkOpDeps(SetVector &ops) { SetVector nonImmediateDepArgs; SetVector nonImmediateOps; @@ -413,9 +570,8 @@ void LoopPipeliner::checkOpDeps(SetVector &ops) { } } - // XXX: We could remove the following constraints if we can rematerialize in - // the loop. - // Check if immediateDepArgs and nonImmediateDepArgs are disjoint. + // We could remove the following constraints if we can rematerialize in the + // loop. Check if immediateDepArgs and nonImmediateDepArgs are disjoint. for (auto &[arg, stages] : immediateArgStages) { assert(stages.size() == 1 && "Triton doesn't support an argument provides values for " @@ -485,18 +641,28 @@ void LoopPipeliner::createBufferTypes() { for (auto loadCvt : loadsMapping) { auto loadOp = loadCvt.first; Value cvt = loadCvt.second; - auto dotOpEnc = cvt.getType() - .cast() - .getEncoding() - .cast(); auto ty = loadOp.getType().cast(); SmallVector bufferShape(ty.getShape().begin(), ty.getShape().end()); bufferShape.insert(bufferShape.begin(), numStages); - unsigned bitWidth = ty.getElementType().getIntOrFloatBitWidth(); - auto sharedEnc = - ttg::SharedEncodingAttr::get(ty.getContext(), dotOpEnc, ty.getShape(), - ttg::getOrder(ty.getEncoding()), bitWidth); + auto CTALayout = ttg::getCTALayout(ty.getEncoding()); + Attribute sharedEnc; + if (auto dotOpEnc = cvt.getType() + .cast() + .getEncoding() + .dyn_cast()) { + // MMAv1 and MMAv2 + unsigned bitWidth = ty.getElementType().getIntOrFloatBitWidth(); + sharedEnc = ttg::SharedEncodingAttr::get( + ty.getContext(), dotOpEnc, ty.getShape(), + ttg::getOrder(ty.getEncoding()), CTALayout, bitWidth); + } else { + // MMAv3 + sharedEnc = ttg::SharedEncodingAttr::get(ty.getContext(), ty.getShape(), + ttg::getOrder(ty.getEncoding()), + CTALayout, ty.getElementType()); + } + // FIXME(Keren): block ptr not handled loadsBufferType[loadOp] = RankedTensorType::get(bufferShape, ty.getElementType(), sharedEnc); } @@ -525,7 +691,7 @@ int LoopPipeliner::getValueDefStage(Value v, int stage) { return stage; } -ttg::AllocTensorOp LoopPipeliner::allocateEmptyBuffer(triton::LoadOp loadOp, +ttg::AllocTensorOp LoopPipeliner::allocateEmptyBuffer(tt::LoadOp loadOp, OpBuilder &builder) { // Allocate a buffer for each pipelined tensor // shape: e.g. (numStages==4), <32x64xbf16> -> <4x32x64xbf16> @@ -546,6 +712,11 @@ LogicalResult LoopPipeliner::initialize() { if (checkOpUses(ops).failed()) return failure(); + // XXX(Keren): hopper specific, should be cleaned up + checkHopperDots(ops); + + checkOpShareBarriers(ops); + checkOpDeps(ops); createBufferTypes(); @@ -555,21 +726,21 @@ LogicalResult LoopPipeliner::initialize() { return success(); } -Value LoopPipeliner::getLoadMask(triton::LoadOp loadOp, Value mappedMask, +Value LoopPipeliner::getLoadMask(tt::LoadOp loadOp, Value mappedMask, Value loopCond, OpBuilder &builder) { - Type maskType = triton::getI1SameShape(loadOp.getType()); + Type maskType = tt::getI1SameShape(loadOp.getType()); Value mask = loadOp.getMask(); Value newMask; if (mask) { Value cond = loopCond; if (isa(maskType)) { - cond = builder.create(mask.getLoc(), maskType, loopCond); + cond = builder.create(mask.getLoc(), maskType, loopCond); } newMask = builder.create(mask.getLoc(), mappedMask, cond); } else { if (isa(maskType)) { - newMask = builder.create(loopCond.getLoc(), maskType, - loopCond); + newMask = + builder.create(loopCond.getLoc(), maskType, loopCond); } else { newMask = loopCond; } @@ -585,7 +756,53 @@ void LoopPipeliner::emitPrologue() { setValueMapping(arg, operand.get(), 0); } - // Emit prologue from [0, numStage-1) + // Alloc a vector of MBarriers in size numStages for each load to be pipelined + bool isMcast = false; + for (Value loadOp : validLoads) { + auto load = cast(loadOp.getDefiningOp()); + if (isLoadFromTensorPtr(load)) { + auto loadTy = loadOp.getType().cast(); + auto CTALayout = ttg::CTALayoutAttr::get( + load.getContext(), + /*CTAsPerCGA*/ {static_cast(numCTAs)}, + /*CTASplitNum*/ {1}, + /*CTAOrder*/ {0}); + auto sharedEncoding = ttg::SharedEncodingAttr::get( + load.getContext(), 1, 1, 1, {0}, CTALayout, false); + auto mBarriersTy = RankedTensorType::get( + {numStages}, builder.getIntegerType(64), sharedEncoding); + + if (!loadsCanShareBarriers[loadOp]) { + Value fullBarriers = builder.create( + load.getLoc(), mBarriersTy, 1); + loadsFullBarriers[loadOp] = fullBarriers; + } + auto layout = loadTy.getEncoding(); + auto CTASplitNum = ttg::getCTASplitNum(layout); + auto CTAsPerCGA = ttg::getCTAsPerCGA(layout); + if (CTASplitNum != CTAsPerCGA) { + isMcast = true; + // FIXME: numConsumerThreads could be 32 as well instead of 128 + // incase the consumer is not GMMA + unsigned arriveCnt = ttg::getNumWarpsPerCTA(layout); + if (hasHopperDot) + arriveCnt /= 4; + arriveCnt *= + product(CTAsPerCGA) / product(CTASplitNum); + + Value emptyBarriers = builder.create( + load.getLoc(), mBarriersTy, arriveCnt); + loadsEmptyBarriers[loadOp] = emptyBarriers; + } + } + } + + if (isMcast) { + builder.create(forOp.getLoc(), /*relaxed*/ 1); + builder.create(forOp.getLoc()); + } + + // prologue from [0, numStage-1) Value iv = forOp.getLowerBound(); pipelineIterIdx = builder.create(iv.getLoc(), 0, 32); for (int stage = 0; stage < numStages - 1; ++stage) { @@ -600,33 +817,99 @@ void LoopPipeliner::emitPrologue() { for (Operation *op : orderedDeps) { Operation *newOp = nullptr; if (validLoads.contains(op->getResult(0))) { - auto load = cast(op); + auto load = cast(op); // Allocate empty buffer if (stage == 0) { loadsBuffer[load] = allocateEmptyBuffer(load, builder); loadStageBuffer[load] = {loadsBuffer[load]}; } // load => copy async - if (auto loadOp = llvm::dyn_cast(op)) { + if (auto loadOp = llvm::dyn_cast(op)) { Value newMask = getLoadMask(loadOp, lookupOrDefault(loadOp.getMask(), stage), loopCond, builder); - newOp = builder.create( - op->getLoc(), loadsBuffer[loadOp].getType(), - lookupOrDefault(loadOp.getPtr(), stage), - loadStageBuffer[loadOp][stage], pipelineIterIdx, newMask, - lookupOrDefault(loadOp.getOther(), stage), loadOp.getCache(), - loadOp.getEvict(), loadOp.getIsVolatile(), /*axis*/ 0); - builder.create(op->getLoc()); + + if (mode && isLoadFromTensorPtr(loadOp)) { + auto loc = op->getLoc(); + auto mBarTy = tt::PointerType::get(builder.getIntegerType(64), 3); + Value stageVal = + builder.create(loc, stage, 32); + // producer_acquire + if (loadsEmptyBarriers.count(loadOp)) { + Value emptyBarrier = builder.create( + loc, mBarTy, loadsEmptyBarriers[loadOp], stageVal); + auto trueVal = + builder.create(loc, 1, /*bitWidth*/ 1); + builder.create(loc, emptyBarrier, trueVal); + } + + // producer_commit + Value fullBarrier; + if (!loadsCanShareBarriers[loadOp]) { + fullBarrier = builder.create( + loc, mBarTy, loadsFullBarriers[loadOp], stageVal); + loadsExtract[loadOp] = fullBarrier; + } else { + // Reuse the barrier from previouse load. + fullBarrier = loadsExtract[loadsCanShareBarriers[loadOp]]; + } + + auto loadTy = loadOp.getType().dyn_cast(); + assert(loadTy); + auto CTASplitNum = ttg::getCTASplitNum(loadTy.getEncoding()); + auto shapePerSlice = + ttg::getShapePerCTA(CTASplitNum, loadTy.getShape()); + unsigned elems = + std::accumulate(shapePerSlice.begin(), shapePerSlice.end(), 1, + std::multiplies{}); + elems *= (loadTy.getElementType().getIntOrFloatBitWidth() / 8); + + if (!loadsCanShareBarriers[loadOp]) { + Value _0 = builder.create(loc, 0, 32); + Value threadId = builder.create(loc); + Value pred = builder.create( + loc, arith::CmpIPredicate::eq, threadId, _0); + pred = builder.create(loc, pred, loopCond); + Operation *barrierArvOp = builder.create( + loc, fullBarrier, pred, + /*remoteCtaId*/ nullptr, /*trackAsyncOp*/ false, elems); + loadsBarrierArvOp[loadOp] = barrierArvOp; + } else { + // Increase the transcnt for barrier of previouse load by the + // bytes of current load. + Operation *barrierArvOp = + loadsBarrierArvOp[loadsCanShareBarriers[loadOp]]; + unsigned base_elems = + barrierArvOp->getAttr("txCount").cast().getInt(); + barrierArvOp->setAttr("txCount", + IntegerAttr::get(builder.getIntegerType(32), + base_elems + elems)); + } + newOp = builder.create( + loc, loadsBuffer[loadOp].getType(), + lookupOrDefault(loadOp.getPtr(), stage), + loadStageBuffer[loadOp][stage], pipelineIterIdx, fullBarrier, + newMask, lookupOrDefault(loadOp.getOther(), stage), + loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile(), + /*axis*/ 0); + } else { + newOp = builder.create( + op->getLoc(), loadsBuffer[loadOp].getType(), + lookupOrDefault(loadOp.getPtr(), stage), + loadStageBuffer[loadOp][stage], pipelineIterIdx, newMask, + lookupOrDefault(loadOp.getOther(), stage), loadOp.getCache(), + loadOp.getEvict(), loadOp.getIsVolatile(), /*axis*/ 0); + builder.create(op->getLoc()); + } loadStageBuffer[loadOp].push_back(newOp->getResult(0)); } else llvm_unreachable("This should be LoadOp"); } else { - if (auto loadOp = dyn_cast(op)) { + if (auto loadOp = dyn_cast(op)) { Value newMask = getLoadMask(loadOp, lookupOrDefault(loadOp.getMask(), stage), loopCond, builder); - newOp = builder.create( + newOp = builder.create( loadOp.getLoc(), loadOp.getResult().getType(), lookupOrDefault(loadOp.getPtr(), stage), newMask, lookupOrDefault(loadOp.getOther(), stage), @@ -667,9 +950,9 @@ void LoopPipeliner::emitPrologue() { } // for (int stage = 0; stage < numStages - 1; ++stage) // async.wait & extract_slice - builder.create(validLoads.front().getLoc(), - validLoads.size() * (numStages - 2)); - loopIterIdx = builder.create(iv.getLoc(), 0, 32); + if (numLoadsRequireAsyncWait > 0) + builder.create(validLoads.front().getLoc(), + validLoads.size() * (numStages - 2)); for (Value loadOp : validLoads) { auto bufferType = loadStageBuffer[loadOp][numStages - 1] .getType() @@ -688,19 +971,20 @@ void LoopPipeliner::emitPrologue() { SmallVector{int_attr(1), int_attr(1), int_attr(1)}); loadsExtract[loadOp] = extractSlice; } - // Bump up loopIterIdx, this is used for getting the correct slice for the - // `next` iteration - loopIterIdx = builder.create( - loopIterIdx.getLoc(), loopIterIdx, - builder.create(loopIterIdx.getLoc(), 1, 32)); + curWaitIdx = builder.create(iv.getLoc(), 0, 32); + loopIterIdx = builder.create(iv.getLoc(), 0, 32); + curPhase = builder.create(iv.getLoc(), 0, 1); + curEmptyPhase = builder.create(iv.getLoc(), 1, 1); } void LoopPipeliner::emitEpilogue() { // If there's any outstanding async copies, we need to wait for them. - OpBuilder builder(forOp); - OpBuilder::InsertionGuard g(builder); - builder.setInsertionPointAfter(forOp); - builder.create(forOp.getLoc(), 0); + if (numLoadsRequireAsyncWait > 0) { + OpBuilder builder(forOp); + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointAfter(forOp); + builder.create(forOp.getLoc(), 0); + } } SmallVector LoopPipeliner::collectNewLoopArgs() { @@ -714,6 +998,9 @@ SmallVector LoopPipeliner::collectNewLoopArgs() { // (iv at stage numStages - 2) // (pipeline iteration index) // (loop iteration index) + // (wait index) + // (phase index) + // (empty phase index) // We need this to update operands for yield // original block arg => new arg's idx @@ -739,10 +1026,16 @@ SmallVector LoopPipeliner::collectNewLoopArgs() { newLoopArgs.push_back(valueMapping[depArg][numStages - 1]); } - ivIndex = newLoopArgs.size(); + ivIdx = newLoopArgs.size(); newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages - 2]); newLoopArgs.push_back(pipelineIterIdx); - newLoopArgs.push_back(loopIterIdx); + newLoopArgs.push_back(curWaitIdx); + if (numLoadsRequireMBarrier > 0) { + newLoopArgs.push_back(loopIterIdx); + newLoopArgs.push_back(curPhase); + newLoopArgs.push_back(curEmptyPhase); + } + return newLoopArgs; } @@ -759,34 +1052,140 @@ scf::ForOp LoopPipeliner::cloneForOp(ArrayRef newLoopArgs, mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); + // Loop iteration args + Value upperBound = newForOp.getUpperBound(); + Value step = newForOp.getStep(); + Value curIV = newForOp.getRegionIterArgs()[ivIdx]; + pipelineIterIdx = newForOp.getRegionIterArgs()[ivIdx + 1]; + curWaitIdx = newForOp.getRegionIterArgs()[ivIdx + 2]; + if (numLoadsRequireMBarrier > 0) { + loopIterIdx = newForOp.getRegionIterArgs()[ivIdx + 3]; + curPhase = newForOp.getRegionIterArgs()[ivIdx + 4]; + curEmptyPhase = newForOp.getRegionIterArgs()[ivIdx + 5]; + } + // Clone the loop body, replace original args with args of the new ForOp. - // We want to find cvt ops that match the following pattern: - // %0 = load %ptr - // %1 (dotOperand) = cvt %0 + SmallVector loadsFromTensorPtr; for (Operation &op : forOp.getBody()->without_terminator()) { - if (auto cvtOp = dyn_cast(op)) { + if (auto cvtOp = dyn_cast(op)) { auto result = op.getResult(0); auto cvtDstTy = result.getType().cast(); - if (cvtDstTy.getEncoding().isa()) { - auto it = - std::find(validLoads.begin(), validLoads.end(), op.getOperand(0)); - if (it != validLoads.end()) { + auto it = + std::find(validLoads.begin(), validLoads.end(), op.getOperand(0)); + if (it != validLoads.end()) { + auto loadArgIdx = std::distance(validLoads.begin(), it); + if (cvtDstTy.getEncoding().isa()) { + // We want to find cvt ops that match the following pattern: + // %0 = load %ptr + // %1 (dotOperand) = cvt %0 // We replace the use new load use with a convert layout - auto loadArgIdx = std::distance(validLoads.begin(), it); auto cvt = builder.create( result.getLoc(), cvtDstTy, newForOp.getRegionIterArgs()[loadIdx + loadArgIdx]); mapping.map(result, cvt.getResult()); continue; + } else if (cvtDstTy.getEncoding().isa()) { + // We want to find cvt ops that match the following pattern: + // %0 = load %ptr + // %1 (sharedEncoding) = cvt %0 + // We replace the use new load use with insert_slice_async's result + mapping.map(result, + newForOp.getRegionIterArgs()[loadIdx + loadArgIdx]); + continue; + } + } + } else if (auto loadOp = dyn_cast(op)) { + if (isLoadFromTensorPtr(loadOp)) { + // XXX(Keren): The comparison operator using std::find on tensor ptr + // doesn't work as expected + auto operand = loadOp.getPtr(); + auto tensorTy = + operand.getType().cast().getPointeeType(); + auto loadArgIdx = 0; + for (auto validLoad : validLoads) { + auto defOp = cast(validLoad.getDefiningOp()); + if (isLoadFromTensorPtr(defOp)) { + auto validOperand = defOp.getOperand(0); + auto validTensorTy = + validOperand.getType().cast().getPointeeType(); + if (tensorTy == validTensorTy) + break; + } + loadArgIdx++; } + // consumer_wait, emitted before the first consumer + auto firstConsumer = getFirstUser(loadOp); + mapping.map(loadOp, newForOp.getRegionIterArgs()[loadIdx + loadArgIdx]); + + // If current load can reuse barriers shared by previous load, then we + // do nothing. + if (!loadsCanShareBarriers[loadOp]) { + // emit mbarrier wait before the first consumer of the loaD + OpBuilder mBarBuilder(firstConsumer); + auto mBarTy = tt::PointerType::get(builder.getIntegerType(64), 3); + Value fullBarrier = mBarBuilder.create( + loadOp.getLoc(), mBarTy, loadsFullBarriers[loadOp], curWaitIdx); + mBarBuilder.create(loadOp.getLoc(), fullBarrier, + curPhase); + } + + loadsFromTensorPtr.push_back(loadOp); + continue; } } cloneWithInferType(builder, &op, mapping); } + for (Value load : loadsFromTensorPtr) { + // consumer_relase, emitted after the last consumer + // 'the last consumer' might be updated in the following Phase_1 since + // some of the consumers might be pipelined. Thus we maintain this + // information in 'consumerReleaseMap' and move the position of + // consumer_release barrier in a seperate Phase_2 in case necessary. + if (loadsEmptyBarriers.count(load)) { + auto users = mapping.lookup(load).getUsers(); + DenseMap consumerStageMap; + for (Operation *user : users) { + // All the stage is initialized to zero before Phase_1, + // since no consumers has been pipelined yet. + consumerStageMap[user] = 0; + } + auto CTALayout = ttg::getCTALayout( + load.getType().cast().getEncoding()); + ConsumerReleaseInfo info{ + loopIterIdx, pipelineIterIdx, curEmptyPhase, curIV, + step, upperBound, CTALayout, consumerStageMap}; + consumerReleaseMap[loadsEmptyBarriers[load]] = info; + } + } + + // Remove redundant conversions + // e.g., %145 = triton_gpu.convert_layout %arg15 : (tensor<128x64xf16, + // #shared1>) -> tensor<128x64xf16, #shared1> + for (Operation &op : newForOp.getBody()->without_terminator()) { + if (auto convert_layout = dyn_cast(op)) { + auto result = op.getResult(0); + auto cvtDstTy = result.getType(); + auto operand = convert_layout.getOperand(); + auto tensorTy = operand.getType(); + if (cvtDstTy == tensorTy) + result.replaceAllUsesWith(operand); + } + } + return newForOp; } +Value LoopPipeliner::getBoundedIterationValue(OpBuilder &builder, Value curIdx, + Value upperBoundIdx, + Value curValue, Value initValue) { + Value cond = builder.create( + curIdx.getLoc(), arith::CmpIPredicate::uge, curIdx, upperBoundIdx); + Value selectValue = builder.create( + curIdx.getLoc(), cond, initValue, curValue); + return selectValue; +} + void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp, OpBuilder &builder) { // Map the dep args of the next iteration to the dep args of the current @@ -798,22 +1197,36 @@ void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp, ++argIdx; } + // Update loop iteration args + Value curIV = newForOp.getRegionIterArgs()[ivIdx]; + pipelineIterIdx = newForOp.getRegionIterArgs()[ivIdx + 1]; + curWaitIdx = newForOp.getRegionIterArgs()[ivIdx + 2]; + if (numLoadsRequireMBarrier > 0) { + loopIterIdx = newForOp.getRegionIterArgs()[ivIdx + 3]; + curPhase = newForOp.getRegionIterArgs()[ivIdx + 4]; + curEmptyPhase = newForOp.getRegionIterArgs()[ivIdx + 5]; + } + // Special handling for iv & loop condition - Value curIV = newForOp.getRegionIterArgs()[ivIndex]; - nextIV = builder.create(newForOp.getInductionVar().getLoc(), - curIV, newForOp.getStep()); - Value nextLoopCond = - builder.create(nextIV.getLoc(), arith::CmpIPredicate::slt, - nextIV, newForOp.getUpperBound()); - - pipelineIterIdx = newForOp.getRegionIterArgs()[ivIndex + 1]; - Value insertSliceIndex = builder.create( - nextIV.getLoc(), pipelineIterIdx, - builder.create(nextIV.getLoc(), numStages, 32)); - loopIterIdx = newForOp.getRegionIterArgs()[ivIndex + 2]; - Value extractSliceIndex = builder.create( - nextIV.getLoc(), loopIterIdx, - builder.create(nextIV.getLoc(), numStages, 32)); + auto idxLoc = curIV.getLoc(); + nextIV = builder.create(idxLoc, curIV, newForOp.getStep()); + Value nextLoopCond = builder.create( + idxLoc, arith::CmpIPredicate::slt, nextIV, newForOp.getUpperBound()); + + // Constants + Value _0 = builder.create(idxLoc, 0, 32); + Value _1 = builder.create(idxLoc, 1, 32); + Value numStagesVal = + builder.create(idxLoc, numStages, 32); + + // nextWaitIdx + Value waitIdxPlusOne = builder.create(idxLoc, curWaitIdx, _1); + Value nextWaitIdx = getBoundedIterationValue( + builder, waitIdxPlusOne, numStagesVal, waitIdxPlusOne, _0); + + // Indices of InsertSliceAsyncOp and ExtractSliceOp + Value insertSliceIndex = pipelineIterIdx; + Value extractSliceIndex = nextWaitIdx; // Prefetch load deps // If a load-dependent instruction that uses a block argument, we @@ -841,11 +1254,11 @@ void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp, else curMapping.map(forOp.getInductionVar(), nextIV); Operation *nextOp; - if (auto loadOp = dyn_cast(op)) { + if (auto loadOp = dyn_cast(op)) { auto newMask = getLoadMask(loadOp, curMapping.lookupOrDefault(loadOp.getMask()), nextLoopCond, builder); - nextOp = builder.create( + nextOp = builder.create( loadOp.getLoc(), loadOp.getResult().getType(), curMapping.lookupOrDefault(loadOp.getPtr()), newMask, curMapping.lookupOrDefault(loadOp.getOther()), @@ -870,7 +1283,7 @@ void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp, Operation *nextOp = nullptr; // Update loading mask if (validLoads.contains(op->getResult(0))) { - auto loadOp = llvm::cast(op); + auto loadOp = llvm::cast(op); auto mask = loadOp.getMask(); auto newMask = getLoadMask(loadOp, nextMapping.lookupOrDefault(loadOp.getMask()), @@ -879,20 +1292,87 @@ void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp, // If mask is defined outside the loop, don't update the map more than // once if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask))) - nextMapping.map(loadOp.getMask(), newMask); - newMask = nextMapping.lookupOrDefault(mask); + nextMapping.map(mask, newMask); + newMask = nextMapping.lookupOrDefault(loadOp.getMask()); + } + Value insertedVal; + if (mode && isLoadFromTensorPtr(loadOp)) { + auto loc = op->getLoc(); + auto mBarTy = tt::PointerType::get(builder.getIntegerType(64), 3); + + // producer_acquire + if (loadsEmptyBarriers.count(loadOp)) { + auto ifOp = builder.create(loc, ArrayRef{}, + nextLoopCond, false); + builder.setInsertionPointToStart(ifOp.thenBlock()); + Value emptyBarrier = builder.create( + loc, mBarTy, loadsEmptyBarriers[loadOp], insertSliceIndex); + builder.create(loc, emptyBarrier, + curEmptyPhase); + builder.setInsertionPointAfter(ifOp); + } + + // producer_commit + Value fullBarrier; + if (!loadsCanShareBarriers[loadOp]) { + fullBarrier = builder.create( + loc, mBarTy, loadsFullBarriers[loadOp], insertSliceIndex); + loadsExtract[loadOp] = fullBarrier; + } else { + // Reuse the barrier from previouse load. + fullBarrier = loadsExtract[loadsCanShareBarriers[loadOp]]; + } + + auto loadTy = loadOp.getType().dyn_cast(); + assert(loadTy); + auto CTASplitNum = ttg::getCTASplitNum(loadTy.getEncoding()); + auto shapePerSlice = + ttg::getShapePerCTA(CTASplitNum, loadTy.getShape()); + unsigned elems = std::accumulate( + shapePerSlice.begin(), shapePerSlice.end(), 1, std::multiplies{}); + elems *= (loadTy.getElementType().getIntOrFloatBitWidth() / 8); + if (!loadsCanShareBarriers[loadOp]) { + Value _0 = builder.create(loc, 0, 32); + Value threadId = builder.create(loc); + Value pred = builder.create( + loc, arith::CmpIPredicate::eq, threadId, _0); + pred = builder.create(loc, pred, nextLoopCond); + Operation *barrierArvOp = builder.create( + loc, fullBarrier, pred, + /*remoteCtaId*/ nullptr, + /*trackAsyncOp*/ false, elems); + loadsBarrierArvOp[loadOp] = barrierArvOp; + } else { + // Increase the transcnt for barrier of previouse load by the bytes of + // current load. + Operation *barrierArvOp = + loadsBarrierArvOp[loadsCanShareBarriers[loadOp]]; + unsigned base_elems = + barrierArvOp->getAttr("txCount").cast().getInt(); + barrierArvOp->setAttr( + "txCount", + IntegerAttr::get(builder.getIntegerType(32), base_elems + elems)); + } + insertedVal = builder.create( + loc, loadsBuffer[loadOp].getType(), + nextMapping.lookupOrDefault(loadOp.getPtr()), + newForOp.getRegionIterArgs()[bufferIdx + nextBuffers.size()], + insertSliceIndex, fullBarrier, newMask, + nextMapping.lookupOrDefault(loadOp.getOther()), loadOp.getCache(), + loadOp.getEvict(), loadOp.getIsVolatile(), /*axis*/ 0); + } else { + insertedVal = builder.create( + op->getLoc(), loadsBuffer[loadOp].getType(), + nextMapping.lookupOrDefault(loadOp.getPtr()), + newForOp.getRegionIterArgs()[bufferIdx + nextBuffers.size()], + insertSliceIndex, newMask, + nextMapping.lookupOrDefault(loadOp.getOther()), loadOp.getCache(), + loadOp.getEvict(), loadOp.getIsVolatile(), /*axis*/ 0); + builder.create(op->getLoc()); } - Value insertAsyncOp = builder.create( - op->getLoc(), loadsBuffer[loadOp].getType(), - nextMapping.lookupOrDefault(loadOp.getPtr()), - newForOp.getRegionIterArgs()[bufferIdx + nextBuffers.size()], - insertSliceIndex, newMask, - nextMapping.lookupOrDefault(loadOp.getOther()), loadOp.getCache(), - loadOp.getEvict(), loadOp.getIsVolatile(), /*axis*/ 0); - builder.create(op->getLoc()); - nextBuffers.push_back(insertAsyncOp); + nextBuffers.push_back(insertedVal); // Extract slice - auto bufferType = insertAsyncOp.getType().cast(); + auto bufferType = insertedVal.getType().cast(); auto bufferShape = bufferType.getShape(); auto sliceType = loadsMapping[loadOp].getType().cast(); sliceType = RankedTensorType::get({bufferShape[1], bufferShape[2]}, @@ -900,7 +1380,7 @@ void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp, loadsBufferType[loadOp].getEncoding()); nextOp = builder.create( - op->getLoc(), sliceType, insertAsyncOp, + op->getLoc(), sliceType, insertedVal, SmallVector{extractSliceIndex, int_attr(0), int_attr(0)}, SmallVector{int_attr(1), @@ -923,20 +1403,43 @@ void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp, newForOp.getRegionIterArgs()[depArgsIdx[arg]]); // async.wait & extract_slice - Operation *asyncWait = builder.create( - validLoads[0].getLoc(), validLoads.size() * (numStages - 2)); - for (auto it = extractSlices.rbegin(); it != extractSlices.rend(); ++it) { - // move extract_slice after asyncWait - it->getDefiningOp()->moveAfter(asyncWait); - } - - // Bump iteration count - pipelineIterIdx = builder.create( - nextIV.getLoc(), pipelineIterIdx, - builder.create(nextIV.getLoc(), 1, 32)); - loopIterIdx = builder.create( - nextIV.getLoc(), loopIterIdx, - builder.create(nextIV.getLoc(), 1, 32)); + if (numLoadsRequireAsyncWait > 0) { + Operation *asyncWait = builder.create( + validLoads[0].getLoc(), validLoads.size() * (numStages - 2)); + for (auto it = extractSlices.rbegin(); it != extractSlices.rend(); ++it) { + // move extract_slice after asyncWait + it->getDefiningOp()->moveAfter(asyncWait); + } + } + + // Bump pipelineIterIdx + Value pipelineIterIdxPlusOne = + builder.create(idxLoc, pipelineIterIdx, _1); + pipelineIterIdx = + getBoundedIterationValue(builder, pipelineIterIdxPlusOne, numStagesVal, + pipelineIterIdxPlusOne, _0); + + // Bump curWaitIdx + curWaitIdx = nextWaitIdx; + + if (numLoadsRequireMBarrier > 0) { + // Bump loopIterIdx + loopIterIdx = builder.create(idxLoc, loopIterIdx, _1); + + Value _1_1b = builder.create(idxLoc, 1, 1); + + // Flip curPhase + Value nextPhase = builder.create(idxLoc, curPhase, _1_1b); + curPhase = getBoundedIterationValue(builder, waitIdxPlusOne, numStagesVal, + curPhase, nextPhase); + + // Flip curEmptyPhase + Value nextEmptyPhase = + builder.create(idxLoc, curEmptyPhase, _1_1b); + curEmptyPhase = + getBoundedIterationValue(builder, pipelineIterIdxPlusOne, numStagesVal, + curEmptyPhase, nextEmptyPhase); + } } void LoopPipeliner::finalizeYield(scf::ForOp newForOp, OpBuilder &builder) { @@ -948,14 +1451,21 @@ void LoopPipeliner::finalizeYield(scf::ForOp newForOp, OpBuilder &builder) { for (Value nextSlice : extractSlices) yieldValues.push_back(nextSlice); - for (size_t i = depArgsBeginIdx; i < ivIndex; ++i) { + for (size_t i = depArgsBeginIdx; i < ivIdx; ++i) { auto arg = newForOp.getRegionIterArgs()[i]; assert(depArgsMapping.count(arg) && "Missing loop-carried value"); yieldValues.push_back(depArgsMapping[arg]); } + + // Loop iteration args yieldValues.push_back(nextIV); yieldValues.push_back(pipelineIterIdx); - yieldValues.push_back(loopIterIdx); + yieldValues.push_back(curWaitIdx); + if (numLoadsRequireMBarrier > 0) { + yieldValues.push_back(loopIterIdx); + yieldValues.push_back(curPhase); + yieldValues.push_back(curEmptyPhase); + } builder.setInsertionPointToEnd(newForOp.getBody()); builder.create(yieldOp->getLoc(), yieldValues); @@ -973,14 +1483,26 @@ scf::ForOp LoopPipeliner::createNewForOp() { // ref: mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp struct PipelinePass : public TritonGPUPipelineBase { PipelinePass() = default; - PipelinePass(int numStages) { this->numStages = numStages; } + PipelinePass(int numStages, int numWarps, int numCTAs, + int computeCapability) { + this->numStages = numStages; + this->numWarps = numWarps; + this->numCTAs = numCTAs; + this->computeCapability = computeCapability; + } void runOnOperation() override { - int numStages = this->numStages; - - if (numStages <= 1) + // TODO[goostavz]: mode = 0 is temporary for backward compatible, will be + // deprecated after the refactor of pipeline fully gets done + // TODO[goostavz]: When mode = 1, the mask of prefetch insert_slice in the + // prologue is currently not properly provided. Need some second thought on + // the mask definition of InsertSliceOp when the src is ptr + bool mode = + computeCapability >= 90 && ::triton::tools::getBoolEnv("ENABLE_TMA"); + if (this->numStages <= 1) return; + // phase 0: pipeline loads in loops // Pre-processing // we make sure element-wise ops are done *after* the conversion // to dot operands @@ -991,26 +1513,283 @@ struct PipelinePass : public TritonGPUPipelineBase { // auto didPreprocess = // applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + llvm::SmallVector newForOps; // Do the pipelining getOperation()->walk([&](scf::ForOp forOp) -> void { - LoopPipeliner pipeliner(forOp, numStages); - + LoopPipeliner pipeliner(forOp, this->numStages, this->numWarps, + this->numCTAs, mode, consumerReleaseMap); if (pipeliner.initialize().failed()) return; pipeliner.emitPrologue(); scf::ForOp newForOp = pipeliner.createNewForOp(); pipeliner.emitEpilogue(); + newForOps.push_back(newForOp); // Replace the original loop for (unsigned i = 0; i < forOp->getNumResults(); ++i) forOp->getResult(i).replaceAllUsesWith(newForOp->getResult(i)); forOp->erase(); }); + + // phase 1: pipeline dots in loops + // A tt.dot suitable for GMMA will be converted to ttg.dot_async. And a + // ttg.DotWaitOp will synchronize it lagging just one iteration, which is + // a hueristic rule. + for (auto forOp : newForOps) + asyncLaunchDots(forOp); + + // phase 2: emit consumer_release (empty barrier arrive) logics in case of + // TMA multicast. + // For each load ops, it is emitted after its last consumer, if the consumer + // is another async op, find its associated sync op. Each async load will be + // emitted with a consumer_release action. The merge of redundant mbarriers + // will be processed in the consequent OptimizeBarriers pass. + for (const auto &item : consumerReleaseMap) + emitConsumerRelease(item.first, item.second, numStages); } + +private: + Value getRemoteCTAId(OpBuilder &b, Location loc, ttg::CTALayoutAttr CTALayout, + Value remoteCTAIdIdx) const; + void updateConsumerReleaseInfo(Operation *oldOp, Operation *newOp, int stage); + void asyncLaunchDots(scf::ForOp forOp); + void emitConsumerRelease(Value mbarTensor, const ConsumerReleaseInfo &info, + int numStages); + + ConsumerReleaseMap consumerReleaseMap; }; + +void PipelinePass::updateConsumerReleaseInfo(Operation *oldOp, Operation *newOp, + int stage) { + for (auto &item : consumerReleaseMap) { + auto &m = item.second.consumerStageMap; + if (m.count(oldOp)) { + m.erase(oldOp); + m[newOp] = stage; + } + + for (Value operand : oldOp->getOperands()) { + Operation *op = operand.getDefiningOp(); + if (op && isa(op)) { + auto cvt = cast(op); + auto src = cvt.getSrc(); + auto srcEncoding = src.getType().cast().getEncoding(); + auto dstEncoding = + cvt.getResult().getType().cast().getEncoding(); + if (srcEncoding == dstEncoding && m.count(op)) { + m.erase(op); + m[newOp] = stage; + } + } + } + } +} + +void PipelinePass::asyncLaunchDots(scf::ForOp forOp) { + Block *loop = forOp.getBody(); + + /// XXX(Keren): Clean up the following duplicate code with checkDotOp + /// dots to be pipelined + SetVector dots; + for (Operation &op : *loop) { + if (auto dotOp = dyn_cast(&op)) { + auto resTy = dotOp.getResult().getType().dyn_cast(); + if (auto resEnc = resTy.getEncoding().dyn_cast()) { + if (resEnc && resEnc.isHopper()) { + // Don't pipeline valid dots that depend on ops other than scf.yield + // and scf.for + auto dot = dotOp.getResult(); + bool valid = true; + + // all users of dot should be scf.yield + if (!dot.hasOneUse()) + valid = false; + if (!isa(*dot.getUsers().begin())) + valid = false; + + // C should be a block argument + auto CArg = dotOp.getOperand(2).dyn_cast(); + if (!CArg || !CArg.hasOneUse()) + valid = false; + + if (valid) + dots.insert(dotOp); + } + } + } + } + + // Early stop: no need to continue if there is no valid dot in the loop. + if (dots.empty()) + return; + + OpBuilder builder(forOp); + + // 0. insert dot_wait after the last dot in the loop + Value dot = dots.back(); + auto loc = dot.getLoc(); + builder.setInsertionPointAfter(dot.getDefiningOp()); + auto dotWait = builder.create(loc, dots.size()); + + // 1. replace Dot with DotAsync + for (size_t idx = 0; idx < dots.size(); ++idx) { + Value dot = dots[idx]; + auto dotOp = cast(dot.getDefiningOp()); + builder.setInsertionPoint(dot.getDefiningOp()); + auto dotAsync = builder.create( + loc, dotOp.getA(), dotOp.getB(), dotOp.getC(), dotOp.getAllowTF32()); + dot.replaceAllUsesWith(dotAsync.getResult()); + updateConsumerReleaseInfo(dot.getDefiningOp(), dotWait, /*stage=*/1); + dot.getDefiningOp()->erase(); + } + + // 2. If there's any outstanding DotAsyncOps, we need to wait for them. + builder.setInsertionPointAfter(forOp); + Value loopNotEmpty = builder.create( + loc, arith::CmpIPredicate::slt, forOp.getLowerBound(), + forOp.getUpperBound()); + // TODO[goostavz]: it's a workaround to put the DotWaitOp in an IfOp for + // a bug in ptxas which mistakenly analysis the control flow and turn the GMMA + // into synchronuous implementation for safety. + // Remove this If once the bug is fixed. + auto ifOp = builder.create(loc, ArrayRef{}, loopNotEmpty, + /*hasElse*/ false); + builder.setInsertionPointToStart(ifOp.thenBlock()); + builder.create(forOp.getLoc(), 0); +} + +Value PipelinePass::getRemoteCTAId(OpBuilder &b, Location loc, + ttg::CTALayoutAttr CTALayout, + Value remoteCTAIdIdx) const { + auto CTAsPerCGA = CTALayout.getCTAsPerCGA(); + auto CTAOrder = CTALayout.getCTAOrder(); + auto CTASplitNum = CTALayout.getCTASplitNum(); + + // Short path when bcastMask is a constant + bool isConstMcastMask = true; + for (unsigned s : CTASplitNum) { + if (s > 1) { + isConstMcastMask = false; + break; + } + } + if (isConstMcastMask) + return remoteCTAIdIdx; + + Value linearCTAId = b.create(loc); + SmallVector multiDimCTAId = + delinearize(b, loc, linearCTAId, CTAsPerCGA, CTAOrder); + auto rank = CTAOrder.size(); + int bcastDim = -1; + for (size_t i = 0; i < rank; ++i) { + if (CTAsPerCGA[i] != CTASplitNum[i]) { + assert(bcastDim < 0 && "bcast in multiple dims is not expected"); + bcastDim = i; + } + } + multiDimCTAId[bcastDim] = remoteCTAIdIdx; + return linearize(b, loc, multiDimCTAId, CTAsPerCGA, CTAOrder); +} + +void PipelinePass::emitConsumerRelease(Value mbarTensor, + const ConsumerReleaseInfo &info, + int numStages) { + Value iterVar = info.iterVar; + Value stage = info.stageVar; + Value phase = info.phaseVar; + Value nextIV = info.nextIVVar; + Value step = info.stepVar; + Value upperBound = info.upperBoundVar; + + const auto &consumerStageMap = info.consumerStageMap; + // find the the last consumer among all the consumers with the largest stage. + SmallVector consumersWithLargestStage; + int maxStage = 0; + for (const auto &it : consumerStageMap) { + if (it.second > maxStage) { + consumersWithLargestStage.clear(); + consumersWithLargestStage.push_back(it.first); + maxStage = it.second; + } else if (it.second == maxStage) { + consumersWithLargestStage.push_back(it.first); + } + } + assert(consumersWithLargestStage.size() > 0); + DenseMap operationId; + consumersWithLargestStage[0]->getBlock()->walk( + [&](Operation *op) { operationId[op] = operationId.size(); }); + size_t maxId = 0; + Operation *lastUserWithLargestStage; + for (Operation *op : consumersWithLargestStage) { + assert(operationId.find(op) != operationId.end()); + size_t userId = operationId[op]; + if (userId > maxId) { + maxId = userId; + lastUserWithLargestStage = op; + } + } + + OpBuilder b(&getContext()); + b.setInsertionPointAfter(lastUserWithLargestStage); + auto loc = lastUserWithLargestStage->getLoc(); + auto maxStageVal = b.create(loc, maxStage, 32); + + // pred = (iterVar >= maxStage) && + // (threadId % (numConsumerThreads / numRemoteCTAs) == 0); + + // [benzh] maybe we can simplify the logics here + auto cmpOp = arith::CmpIPredicate::sge; + if (maxStage == 0) + cmpOp = arith::CmpIPredicate::sgt; + Value pred = b.create(loc, cmpOp, iterVar, maxStageVal); + + Value threadId = b.create(loc); + auto CTAsPerCGA = info.CTALayout.getCTAsPerCGA(); + auto CTASplitNum = info.CTALayout.getCTASplitNum(); + auto numRemoteCTAs = std::accumulate(CTAsPerCGA.begin(), CTAsPerCGA.end(), 1, + std::multiplies{}) / + std::accumulate(CTASplitNum.begin(), CTASplitNum.end(), + 1, std::multiplies{}); + auto numConsumerThreads = + isa(lastUserWithLargestStage) ? 128 : 32; + Value _0 = b.create(loc, 0, 32); + Value numArrives = b.create( + loc, numConsumerThreads / numRemoteCTAs, 32); + pred = b.create( + loc, pred, + b.create( + loc, arith::CmpIPredicate::eq, + b.create(loc, threadId, numArrives), _0)); + // remoteCtaIdIdx = (threadId % numConsumerThreads) / (numConsumerThreads / + // numRemoteCTAs); + Value remoteCTAIdIdx = b.create( + loc, + b.create( + loc, threadId, + b.create(loc, numConsumerThreads, 32)), + numArrives); + Value remoteCTAId = getRemoteCTAId(b, loc, info.CTALayout, remoteCTAIdIdx); + Value emptyBarrier = b.create( + loc, tt::PointerType::get(b.getIntegerType(64), 3), mbarTensor, stage); + + Value newNextIV = b.create(loc, nextIV, step); + Value nextLoopCond = b.create(loc, arith::CmpIPredicate::slt, + newNextIV, upperBound); + auto ifOp = b.create(loc, ArrayRef{}, nextLoopCond, + /*hasElse*/ false); + b.setInsertionPointToStart(ifOp.thenBlock()); + + b.create(loc, emptyBarrier, pred, remoteCTAId, + /*trackAsyncOp*/ false); +} + } // anonymous namespace -std::unique_ptr mlir::createTritonGPUPipelinePass(int numStages) { - return std::make_unique(numStages); +std::unique_ptr mlir::createTritonGPUPipelinePass(int numStages, + int numWarps, + int numCTAs, + int computeCapability) { + return std::make_unique(numStages, numWarps, numCTAs, + computeCapability); } diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index f568023147f2..27e97f41a11e 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -12,11 +12,11 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/RegionUtils.h" +#include "triton/Analysis/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" - #include using namespace mlir; @@ -119,526 +119,971 @@ class DecomposeDotOperand : public mlir::RewritePattern { } }; -// It's beneficial to move the conversion -// to after the reduce if necessary since it will be -// done on a rank-reduced tensor hence cheaper -class SimplifyReduceCvt : public mlir::RewritePattern { +// +class ConvertDotConvert : public mlir::RewritePattern { public: - explicit SimplifyReduceCvt(mlir::MLIRContext *context) + ConvertDotConvert(mlir::MLIRContext *context) : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), - 2, context) {} + 1, context) {} - mlir::LogicalResult + LogicalResult matchAndRewrite(mlir::Operation *op, mlir::PatternRewriter &rewriter) const override { - if (!llvm::isa(op)) + auto dstOp = cast(op); + auto dotOp = dstOp.getSrc().getDefiningOp(); + if (!dotOp) return mlir::failure(); - auto convert = llvm::cast(op); - triton::ReduceOp reduce; - for (auto &use : convert.getResult().getUses()) { - auto owner = llvm::dyn_cast(use.getOwner()); - if (!owner) { - continue; - } - - // TODO: This only moves conversions from the first argument which is - // fine for argmin/argmax but may not be optimal generally - if (convert.getResult() != owner.getOperands()[0]) { - continue; - } - reduce = owner; - break; - } - if (!reduce) + if (std::distance(dstOp->user_begin(), dstOp->user_end()) != 1 || + std::distance(dotOp->user_begin(), dotOp->user_end()) != 1) return mlir::failure(); - - SmallVector newOperands = reduce.getOperands(); - - newOperands[0] = convert.getOperand(); - auto newEncoding = - newOperands[0].getType().cast().getEncoding(); - - // this may generate unsupported conversions in the LLVM codegen - if (newEncoding.isa() || - newEncoding.isa()) { + auto cvtOp = + dotOp.getOperand(2).getDefiningOp(); + if (!cvtOp) + return mlir::failure(); + if (!cvtOp.getSrc().getDefiningOp()) return failure(); - } - - for (unsigned i = 1; i < newOperands.size(); ++i) { - auto oldTy = newOperands[i].getType().cast(); - RankedTensorType newTy = - RankedTensorType::Builder(oldTy).setEncoding(newEncoding); - - newOperands[i] = rewriter.create( - op->getLoc(), newTy, newOperands[i]); - } - - rewriter.setInsertionPoint(reduce); - auto newReduce = rewriter.create( - op->getLoc(), newOperands, reduce.getAxis()); - auto &newCombineOp = newReduce.getCombineOp(); - rewriter.cloneRegionBefore(reduce.getCombineOp(), newCombineOp, - newCombineOp.end()); - - SmallVector newRet = newReduce.getResult(); - auto oldTypes = reduce.getResult().getType(); - for (unsigned i = 0; i < reduce.getNumOperands(); ++i) { - // it's still beneficial to move the conversion - // to after the reduce if necessary since it will be - // done on a rank-reduced tensor hence cheaper - if (newRet[i].getType() != oldTypes[i]) - newRet[i] = rewriter.create( - op->getLoc(), oldTypes[i], newRet[i]); - } - rewriter.replaceAllUsesWith(reduce.getResult(), newRet); + auto dstTy = dstOp.getResult().getType().cast(); + auto srcTy = cvtOp.getOperand().getType().cast(); + if (dstTy != srcTy) + return mlir::failure(); - return success(); + auto _0f = rewriter.create( + op->getLoc(), dstTy.getElementType(), + rewriter.getZeroAttr(dstTy.getElementType())); + auto _0 = rewriter.create( + op->getLoc(), dotOp.getResult().getType(), _0f); + auto newDot = rewriter.create( + op->getLoc(), dotOp.getResult().getType(), dotOp.getOperand(0), + dotOp.getOperand(1), _0, dotOp.getAllowTF32()); + auto newCvt = rewriter.create( + op->getLoc(), dstTy, newDot.getResult()); + rewriter.replaceOpWithNewOp(op, newCvt, cvtOp.getOperand()); + return mlir::success(); } }; -// Layout conversions can't deduce their return type automatically. -// IIUC they are therefore not handled by DRR right now -class SimplifyConversion : public mlir::RewritePattern { +// Class to propagate layout globally within a function. +// The current algorithm works by analysis the IR and doing a one shot rewrite +// based on the analysis. The algorithm is as follows: +// 1. Find all the anchor ops. These are ops that have a layout we want to +// preserve. +// +// 2. Propagate the layout to every op reachable which is a transitive child of +// an anchor op until we reach a fix point. +// An op can have multiple transitive anchor parents therefore at this stage +// it may have multiple layout associated to it. +// +// 3. Resolve conflicts by deciding which of the multiple layouts the op should +// keep. If one of the parents has a different layout than what is picked a +// convert operation will be inserted. After this stage each value should have +// only one layout associated. +// +// 4. Rewrite the IR by walking the function following dominance order. Since we +// assume the IR is structured we just need to process the regions in the +// correct order. For each op rewrite it using the layout decided by the +// analysis phase. +class LayoutPropagation { public: - explicit SimplifyConversion(mlir::MLIRContext *context) - : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), - 4, context) {} - - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - if (!llvm::isa(op)) - return mlir::failure(); - auto convert = llvm::cast(op); - return ConvertLayoutOp::canonicalize(convert, rewriter); - } + // Structure to keep track of the layout associated to a value. + struct LayoutInfo { + LayoutInfo(Attribute encoding) { encodings.insert(encoding); } + LayoutInfo() {} + llvm::SmallSetVector encodings; + }; + LayoutPropagation(triton::FuncOp F) : funcOp(F) {} + // Find the anchor ops and set their layout in the data structure. + void initAnchorLayout(); + // Recursively Propagate the layout to all the users of the anchor ops until + // we reach a fix point. + void propagateLayout(); + // Add layouts given in `Info` to the uses of `value`. + SmallVector propagateToUsers(Value value, LayoutInfo &info); + // Set the encoding to all the values and fill out the values with new layout + // in `changed`. + void setEncoding(ValueRange values, LayoutInfo &info, + SmallVector &changed, Operation *op); + // Resolve cases where a value has multiple layouts associated to it. + void resolveConflicts(); + // Rewrite the IR for the full module. + void rewrite(); + // Rewrite the IR for a region. + void rewriteRegion(Region &R); + // Rewrite an op based on the layout picked by the analysis. + Operation *rewriteOp(Operation *op); + // Rewrite a for op based on the layout picked by the analysis. + Operation *rewriteForOp(scf::ForOp forOp); + Operation *rewriteWhileOp(scf::WhileOp whileOp); + Operation *rewriteIfOp(scf::IfOp ifOp); + void rewriteYieldOp(scf::YieldOp yieldOp); + void rewriteConditionOp(scf::ConditionOp conditionOp); + void rewriteReduceToScalar(Operation *reduceOp); + Operation *cloneElementwise(OpBuilder &rewriter, Operation *op, + Attribute encoding); + // Map the original value to the rewritten one. + void map(Value old, Value newV); + // Return the mapped value in the given encoding. This will insert a convert + // if the encoding is different than the encoding decided at resolve time. + Value getValueAs(Value value, Attribute encoding); + // Dump the current stage of layout information. + void dump(); + +private: + // map from value to layout information. + llvm::MapVector layouts; + // map of the values rewrite based on their encoding. + DenseMap, Value> rewriteMapping; + std::vector opToDelete; + triton::FuncOp funcOp; }; -// ----------------------------------------------------------------------------- -// -// ----------------------------------------------------------------------------- +} // namespace -// op(cvt(arg_0), arg_1, ..., arg_n) -// -> cvt(op(arg_0, cvt(arg_1), ..., cvt(arg_n))) -void pushConversionForward(triton::gpu::ConvertLayoutOp cvt, - SetVector &cvtSlices, - PatternSharedInfo &sharedInfo, - mlir::PatternRewriter &rewriter) { - auto srcEncoding = - cvt.getOperand().getType().cast().getEncoding(); - auto dstEncoding = - cvt.getResult().getType().cast().getEncoding(); - IRMapping mapping; - auto op = cvtSlices.front(); - for (Value arg : op->getOperands()) { - if (arg.getDefiningOp() == cvt) - mapping.map(arg, cvt.getOperand()); - else { - auto oldType = arg.getType().cast(); - auto newType = RankedTensorType::get( - oldType.getShape(), oldType.getElementType(), srcEncoding); - auto cvtI = rewriter.create(arg.getLoc(), - newType, arg); - if (Operation *argOp = arg.getDefiningOp()) - cvtI->moveAfter(argOp); - mapping.map(arg, cvtI); +// Look ahead to at the transitive uses and see if there is a convert to mma +// operations. +static bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) { + SmallVector queue = {op->getResult(0)}; + SetVector forwardSlice; + llvm::SmallDenseSet seen; + while (!queue.empty()) { + Value currentValue = queue.back(); + queue.pop_back(); + getForwardSlice(currentValue, &forwardSlice); + for (Operation *op : forwardSlice) { + if (auto convertOp = dyn_cast(op)) { + if (convertOp.getResult() + .getType() + .cast() + .getEncoding() == encoding) + return true; + } + auto yield = dyn_cast(op); + if (!yield) + continue; + auto forOp = dyn_cast(yield.getOperation()->getParentOp()); + if (!forOp) + continue; + for (OpOperand &operand : yield->getOpOperands()) { + Operation *def = operand.get().getDefiningOp(); + if (def && forwardSlice.count(def) && + (seen.insert(operand.get()).second == true)) + queue.push_back(forOp.getRegionIterArg(operand.getOperandNumber())); + } } } - rewriter.setInsertionPoint(op); - if (op->getNumResults() == 0) { - Operation *newOp = rewriter.clone(*op, mapping); - rewriter.eraseOp(op); - return; + return false; +} + +#ifdef USE_ROCM +// Look ahead to at the transitive uses and see if there is a convert to mfma +// operations. +// TODO: unify with hasConvertToMMATransisitiveUse? +static bool hasConvertToMFMATransisitiveUse(Operation *op, Attribute encoding) { + SmallVector queue = {op->getResult(0)}; + SetVector forwardSlice; + llvm::SmallDenseSet seen; + while (!queue.empty()) { + Value currentValue = queue.back(); + queue.pop_back(); + getForwardSlice(currentValue, &forwardSlice); + for (Operation *op : forwardSlice) { + if (auto convertOp = dyn_cast(op)) { + if (convertOp.getResult() + .getType() + .cast() + .getEncoding() == encoding) + return true; + } + auto yield = dyn_cast(op); + if (!yield) + continue; + auto forOp = dyn_cast(yield.getOperation()->getParentOp()); + if (!forOp) + continue; + for (OpOperand &operand : yield->getOpOperands()) { + Operation *def = operand.get().getDefiningOp(); + if (def && forwardSlice.count(def) && + (seen.insert(operand.get()).second == true)) + queue.push_back(forOp.getRegionIterArg(operand.getOperandNumber())); + } + } } - auto *newOp = cloneWithInferType(rewriter, op, mapping); - auto newType = newOp->getResult(0).getType().cast(); - auto newCvtType = RankedTensorType::get( - newType.getShape(), newType.getElementType(), dstEncoding); - auto newCvt = rewriter.create( - newOp->getLoc(), newCvtType, newOp->getResult(0)); - sharedInfo.cvtsPushedForwardMap[newCvt] = newCvt->getOperand(0).getDefiningOp(); - rewriter.replaceOp(op, newCvt->getResults()); + return false; +} +#endif + +// Return true if the op is an op with a layout we don't want to change. We will +// propagate the layout starting from anchor ops. +static bool isLayoutAnchor(Operation *op) { + if (isa(op)) + return isExpensiveLoadOrStore(op); + if (isa(op)) + return true; + return false; } -// -class MoveConvertOutOfIf : public mlir::RewritePattern { -public: - explicit MoveConvertOutOfIf(mlir::MLIRContext *context) - : mlir::RewritePattern(scf::IfOp::getOperationName(), 2, context) {} +void LayoutPropagation::initAnchorLayout() { + funcOp.walk([&](Operation *op) { + if (isLayoutAnchor(op)) { + for (auto result : op->getResults()) { + if (auto tensorType = result.getType().dyn_cast()) { + // Workaround, don't popagate MMA layout unless there is a convert + // back to mma further down to avoid generating reduction with MMA + // layout that may have lower performance. + // This can be improved with more aggressive backward propagation. + if (tensorType.getEncoding().isa() && + !hasConvertToMMATransisitiveUse(op, tensorType.getEncoding())) + continue; +#ifdef USE_ROCM + // Workaround to not propagate MFMA layout in case there are + // no chained dots MFMA layout is expensive to convert, so we want + // to convert it to something else as soon as possible. + // It saves LDS space in some cases. + // + // TODO: rework this heuristic if we can store MFMA layout directly + // into global memory. + if (tensorType.getEncoding().isa() && + !hasConvertToMFMATransisitiveUse(op, tensorType.getEncoding())) + continue; +#endif + layouts.insert({result, tensorType.getEncoding()}); + } + } + } + }); +} - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto ifOp = cast(*op); - // If “scf.if” defines no values, “scf.yield” will be inserted implicitly. - // However, "scf.else" is not required to be present, so we need to check - // if it exists. - auto thenYield = ifOp.thenYield(); - int numOps = thenYield.getNumOperands(); - SmallVector newThenYieldOps = thenYield.getOperands(); - SetVector thenCvts; - SmallVector newRetTypes; - - bool hasElse = !ifOp.getElseRegion().empty(); - - scf::YieldOp elseYield; - SmallVector newElseYieldOps; - SetVector elseCvts; - if (hasElse) { - elseYield = ifOp.elseYield(); - newElseYieldOps = elseYield.getOperands(); +void LayoutPropagation::setEncoding(ValueRange values, LayoutInfo &info, + SmallVector &changed, + Operation *op) { + for (Value value : values) { + if (!value.getType().isa()) + continue; + bool hasChanged = false; + for (auto encoding : info.encodings) { + auto dstEncoding = inferDstEncoding(op, encoding); + if (dstEncoding) + hasChanged |= layouts[value].encodings.insert(*dstEncoding); } + if (hasChanged) + changed.push_back(value); + } +} - IRMapping mapping; - for (size_t i = 0; i < numOps; i++) { - auto thenCvt = dyn_cast_or_null( - thenYield.getOperand(i).getDefiningOp()); - if (hasElse) { - auto elseYield = ifOp.elseYield(); - auto elseCvt = dyn_cast_or_null( - elseYield.getOperand(i).getDefiningOp()); - if (thenCvt && elseCvt && - std::distance(elseCvt->user_begin(), elseCvt->user_end()) == 1 && - std::distance(thenCvt->user_begin(), thenCvt->user_end()) == 1 && - thenCvt.getOperand().getType() == elseCvt.getOperand().getType()) { - // If thenCvt and elseCvt's type are the same, it means a single - // conversion is enough to replace both of them. We can move the - // conversion out of scf.if and replace both thenCvt and elseCvt with - // the new conversion. - mapping.map(thenCvt.getResult(), thenCvt.getOperand()); - thenCvts.insert((Operation *)thenCvt); - newRetTypes.push_back(thenCvt.getOperand().getType()); - mapping.map(elseCvt.getResult(), elseCvt.getOperand()); - elseCvts.insert((Operation *)elseCvt); - } else - // Cannot move out of scf.if because thenCvt != elseCvt - // Moving it out of scf.if will introduce a new conversion - newRetTypes.push_back(thenYield.getOperand(i).getType()); - } else { - if (thenCvt && - std::distance(thenCvt->user_begin(), thenCvt->user_end()) == 1) { - // If there's only a single use of the conversion then we can move it - mapping.map(thenCvt.getResult(), thenCvt.getOperand()); - thenCvts.insert((Operation *)thenCvt); - newRetTypes.push_back(thenCvt.getOperand().getType()); - } else - // Cannot move out of scf.if because either there's another use of - // the conversion or there's no conversion at all - newRetTypes.push_back(thenYield.getOperand(i).getType()); +SmallVector LayoutPropagation::propagateToUsers(Value value, + LayoutInfo &info) { + SmallVector changed; + for (OpOperand &use : value.getUses()) { + Operation *user = use.getOwner(); + if (auto forOp = dyn_cast(user)) { + Value arg = forOp.getRegionIterArgForOpOperand(use); + Value result = forOp.getResultForOpOperand(use); + setEncoding({arg, result}, info, changed, user); + continue; + } + if (auto whileOp = dyn_cast(user)) { + Value arg = whileOp.getBeforeArguments()[use.getOperandNumber()]; + setEncoding({arg}, info, changed, user); + continue; + } + if (auto yieldOp = dyn_cast(user)) { + auto parent = yieldOp->getParentOp(); + SmallVector valuesToPropagate; + if (isa(parent)) + valuesToPropagate.push_back(parent->getResult(use.getOperandNumber())); + if (auto forOp = dyn_cast(parent)) + valuesToPropagate.push_back( + forOp.getRegionIterArg(use.getOperandNumber())); + if (auto whileOp = dyn_cast(parent)) { + valuesToPropagate.push_back( + whileOp.getBeforeArguments()[use.getOperandNumber()]); + valuesToPropagate.push_back( + whileOp->getOperand(use.getOperandNumber())); } + if (isa(parent)) + setEncoding(valuesToPropagate, info, changed, user); + continue; } - if (mapping.getValueMap().empty()) - return mlir::failure(); - - auto newIfOp = rewriter.create(ifOp.getLoc(), newRetTypes, - ifOp.getCondition(), hasElse); - auto rematerialize = [&](Block *block, SetVector &cvts) { - for (Operation &op : block->getOperations()) { - if (cvts.contains(&op)) { - if (mapping.contains(op.getOperand(0))) - mapping.map(op.getResult(0), mapping.lookup(op.getOperand(0))); + if (auto conditionOp = dyn_cast(user)) { + auto whileOp = cast(conditionOp->getParentOp()); + // Skip arg 0 as it is the condition. + unsigned argIndex = use.getOperandNumber() - 1; + Value afterArg = whileOp.getAfterArguments()[argIndex]; + Value result = whileOp->getResult(argIndex); + setEncoding({afterArg, result}, info, changed, user); + continue; + } + // Workaround: don't propagate through truncI + if (isa(user)) + continue; + if (user->hasTrait() || + user->hasTrait() || + isa(user)) { +#ifdef USE_ROCM + if (auto convertOp = dyn_cast(user)) { + if (triton::gpu::isSharedEncoding(convertOp.getResult()) || + triton::gpu::isSharedEncoding(convertOp.getOperand())) continue; - } - rewriter.clone(op, mapping); } - }; - rewriter.setInsertionPointToEnd(newIfOp.thenBlock()); - rematerialize(ifOp.thenBlock(), thenCvts); - if (hasElse) { - rewriter.setInsertionPointToEnd(newIfOp.elseBlock()); - rematerialize(ifOp.elseBlock(), elseCvts); +#endif + setEncoding(user->getResults(), info, changed, user); + continue; } + } + return changed; +} + +void LayoutPropagation::propagateLayout() { + SmallVector queue; + for (auto it : layouts) { + queue.push_back(it.first); + } + while (!queue.empty()) { + Value currentValue = queue.back(); + LayoutInfo info = layouts[currentValue]; + queue.pop_back(); + SmallVector changed = propagateToUsers(currentValue, info); + queue.insert(queue.end(), changed.begin(), changed.end()); + } +} - rewriter.setInsertionPointAfter(newIfOp); - SmallVector newRetValues = newIfOp.getResults(); - for (size_t i = 0; i < numOps; i++) { - if (newIfOp.getResult(i).getType() != ifOp.getResult(i).getType()) { - newRetValues[i] = rewriter.create( - newIfOp.getLoc(), ifOp.getResult(i).getType(), - newIfOp.getResult(i)); +void LayoutPropagation::resolveConflicts() { + for (auto &it : layouts) { + LayoutInfo &info = it.second; + if (info.encodings.size() <= 1) + continue; + // Hacky resolve, prefer block encoding. + // TODO: add a proper heuristic. + Attribute encoding = *info.encodings.begin(); + for (Attribute e : info.encodings) { + if (e.isa()) { + encoding = e; + break; } } + info.encodings.clear(); + info.encodings.insert(encoding); + } +} - rewriter.replaceOp(op, newRetValues); - return mlir::success(); +void LayoutPropagation::dump() { + for (auto it : layouts) { + llvm::errs() << "Value: "; + OpPrintingFlags flags; + flags.skipRegions(); + it.first.print(llvm::errs(), flags); + llvm::errs() << " \n encoding:\n"; + for (auto encoding : it.second.encodings) { + encoding.print(llvm::errs()); + llvm::errs() << "\n"; + } + llvm::errs() << "--\n"; } -}; +} -// -class RematerializeForward : public mlir::RewritePattern { - PatternSharedInfo &sharedInfo; +void LayoutPropagation::rewrite() { rewriteRegion(funcOp->getRegion(0)); } -public: - explicit RematerializeForward(mlir::MLIRContext *context, PatternSharedInfo &sharedInfo) - : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), - 1, context), sharedInfo(sharedInfo) {} +static bool reduceToScalar(Operation *op) { + // For reductions returning a scalar we can change the src encoding without + // affecting the output. + return isa(op) && + !op->getResultTypes()[0].isa(); +} - mlir::LogicalResult - matchAndRewrite(mlir::Operation *cvtOp, - mlir::PatternRewriter &rewriter) const override { - auto cvt = dyn_cast(*cvtOp); - auto srcEncoding = - cvt.getOperand().getType().cast().getEncoding(); - auto dstEncoding = - cvt.getResult().getType().cast().getEncoding(); - if (srcEncoding.isa() || - dstEncoding.isa()) - return failure(); - // heuristics for flash attention - if (srcEncoding.isa()) - return failure(); - // For cases like: - // %0 = convert_layout %arg0 - // We should try to move %0 out of scf.for first, if it couldn't be moved - // out additional conversions will be added to the loop body. - if (!cvt.getOperand().getDefiningOp() && - isa(cvt->getParentOp())) - return failure(); +void LayoutPropagation::rewriteRegion(Region ®ion) { + SmallVector queue = {®ion}; + while (!queue.empty()) { + Region *currentRegion = queue.back(); + queue.pop_back(); + for (Operation &op : currentRegion->getOps()) { + bool needRewrite = false; + SmallVector results = op.getResults(); + for (Value result : results) { + auto it = layouts.find(result); + // If we haven't mapped this value skip. + if (it == layouts.end()) + continue; + LayoutInfo &info = it->second; + assert(info.encodings.size() == 1 && + "we should have resolved to a single encoding"); + auto encoding = result.getType().cast().getEncoding(); + // If the encoding is already what we want skip. + if (encoding == *info.encodings.begin()) + continue; + needRewrite = true; + } + if (needRewrite) { + Operation *newOp = rewriteOp(&op); + for (Region &R : newOp->getRegions()) + queue.push_back(&R); + } else if (auto yieldOp = dyn_cast(&op)) { + rewriteYieldOp(yieldOp); + } else if (auto conditionOp = dyn_cast(&op)) { + rewriteConditionOp(conditionOp); + } else if (reduceToScalar(&op)) { + rewriteReduceToScalar(&op); + } else { + // If we don't need to rewrite the op we still need to remap the + // operands. + for (OpOperand &operand : op.getOpOperands()) { + auto it = layouts.find(operand.get()); + if (it == layouts.end()) + continue; + Attribute encoding = + operand.get().getType().cast().getEncoding(); + Value newOperand = getValueAs(operand.get(), encoding); + op.setOperand(operand.getOperandNumber(), newOperand); + } + for (Region &R : op.getRegions()) + queue.push_back(&R); + } + } + } + for (Operation *op : llvm::reverse(opToDelete)) + op->erase(); +} - SetVector cvtSlices; - auto filter = [&](Operation *op) { - return op->getBlock() == cvt->getBlock() && - !isa(op) && - !(isa(op) && - !op->getResult(0).getType().isa()); - }; - mlir::getForwardSlice(cvt.getResult(), &cvtSlices, {filter}); - if (cvtSlices.empty()) - return failure(); +void LayoutPropagation::map(Value old, Value newV) { + rewriteMapping[{old, newV.getType().cast().getEncoding()}] = + newV; +} - for (Operation *op : cvtSlices) { - // don't rematerialize anything expensive - if (isExpensiveToRemat(op, srcEncoding)) - return failure(); - // don't rematerialize non-element-wise - if (!op->hasTrait() && - !op->hasTrait() && - !isa(op)) - return failure(); - // don't rematerialize if it adds an extra conversion that can't - // be removed - for (Value arg : op->getOperands()) { - Operation *argOp = arg.getDefiningOp(); - SetVector processed; - SetVector layout; - llvm::MapVector toConvert; - int numAddedConvs = simulateBackwardRematerialization( - argOp, processed, layout, toConvert, srcEncoding); - if (argOp && !isa(argOp) && - cvtSlices.count(argOp) == 0 && numAddedConvs > 0) - return failure(); - } +Value LayoutPropagation::getValueAs(Value value, Attribute encoding) { + if (auto tensorType = value.getType().dyn_cast()) { + Value rewrittenValue; + auto layoutIt = layouts.find(value); + if (layoutIt == layouts.end()) { + rewrittenValue = value; + } else { + assert(layoutIt->second.encodings.size() == 1 && + "we should have resolved to a single encoding"); + Attribute encodingPicked = *(layoutIt->second.encodings.begin()); + if (encodingPicked == tensorType.getEncoding()) + rewrittenValue = value; + else + rewrittenValue = rewriteMapping[{value, encodingPicked}]; } + assert(rewrittenValue); + if (rewrittenValue.getType().cast().getEncoding() == + encoding) + return rewrittenValue; + OpBuilder rewriter(value.getContext()); + rewriter.setInsertionPointAfterValue(rewrittenValue); + auto tmpType = RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); + Value converted = rewriter.create( + value.getLoc(), tmpType, rewrittenValue); + // TODO: we could cache the conversion. + return converted; + } + return value; +} - // Call SimplifyReduceCvt instead of the general push conversion forward - if (isa(cvtSlices.front())) - return failure(); +Operation *LayoutPropagation::cloneElementwise(OpBuilder &rewriter, + Operation *op, + Attribute encoding) { + Operation *newOp = rewriter.clone(*op); + for (OpOperand &operand : op->getOpOperands()) + newOp->setOperand( + operand.getOperandNumber(), + getValueAs(operand.get(), *inferSrcEncoding(op, encoding))); + for (unsigned i = 0, e = op->getNumResults(); i < e; ++i) { + auto origType = op->getResult(i).getType().dyn_cast(); + if (!origType) + continue; + auto newType = RankedTensorType::get(origType.getShape(), + origType.getElementType(), encoding); + newOp->getResult(i).setType(newType); + } + return newOp; +} - pushConversionForward(cvt, cvtSlices, sharedInfo, rewriter); - return success(); +Operation *LayoutPropagation::rewriteForOp(scf::ForOp forOp) { + SmallVector operands; + OpBuilder rewriter(forOp); + for (auto [operand, result] : + llvm::zip(forOp.getInitArgs(), forOp.getResults())) { + Value convertedOperand = operand; + if (layouts.count(result)) + convertedOperand = + getValueAs(operand, *layouts[result].encodings.begin()); + operands.push_back(convertedOperand); + } + auto newForOp = rewriter.create( + forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + forOp.getStep(), operands); + + newForOp.getBody()->getOperations().splice( + newForOp.getBody()->getOperations().begin(), + forOp.getBody()->getOperations()); + + for (auto [oldResult, newResult] : + llvm::zip(forOp.getResults(), newForOp.getResults())) { + if (oldResult.getType() == newResult.getType()) { + oldResult.replaceAllUsesWith(newResult); + continue; + } + map(oldResult, newResult); } -}; -// Layout conversions are expensive. They require going through -// shared memory, which is orders of magnitude slower than -// other non-i/o operations in the dialect. -// It therefore makes sense to remove them whenever possible, -// even if it means rematerializing all values whose definitions -// are reachable from it without passing through any memory operation. -class RematerializeBackward : public mlir::RewritePattern { - PatternSharedInfo &sharedInfo; + for (auto [oldArg, newArg] : llvm::zip(forOp.getBody()->getArguments(), + newForOp.getBody()->getArguments())) { + if (oldArg.getType() == newArg.getType()) { + oldArg.replaceAllUsesWith(newArg); + continue; + } + map(oldArg, newArg); + } + return newForOp.getOperation(); +} -public: - explicit RematerializeBackward(mlir::MLIRContext *context, PatternSharedInfo &sharedInfo) - : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), - 3, context), sharedInfo(sharedInfo) {} +Operation *LayoutPropagation::rewriteWhileOp(scf::WhileOp whileOp) { + SmallVector operands; + SmallVector returnTypes; + OpBuilder rewriter(whileOp); + for (auto [operand, arg] : + llvm::zip(whileOp->getOperands(), whileOp.getBeforeArguments())) { + Value convertedOperand = operand; + if (layouts.count(arg)) + convertedOperand = getValueAs(operand, *layouts[arg].encodings.begin()); + operands.push_back(convertedOperand); + } + for (Value ret : whileOp.getResults()) { + auto it = layouts.find(ret); + if (it == layouts.end()) { + returnTypes.push_back(ret.getType()); + continue; + } + auto origType = ret.getType().dyn_cast(); + auto newType = + RankedTensorType::get(origType.getShape(), origType.getElementType(), + it->second.encodings[0]); + returnTypes.push_back(newType); + } + auto newWhileOp = + rewriter.create(whileOp.getLoc(), returnTypes, operands); + SmallVector argsTypesBefore; + for (Value operand : operands) + argsTypesBefore.push_back(operand.getType()); + SmallVector bbArgLocsBefore(argsTypesBefore.size(), + whileOp.getLoc()); + SmallVector bbArgLocsAfter(returnTypes.size(), whileOp.getLoc()); + rewriter.createBlock(&newWhileOp.getBefore(), {}, argsTypesBefore, + bbArgLocsBefore); + rewriter.createBlock(&newWhileOp.getAfter(), {}, returnTypes, bbArgLocsAfter); + + for (int i = 0; i < whileOp.getNumRegions(); ++i) { + newWhileOp->getRegion(i).front().getOperations().splice( + newWhileOp->getRegion(i).front().getOperations().begin(), + whileOp->getRegion(i).front().getOperations()); + } - mlir::LogicalResult - matchAndRewrite(mlir::Operation *cvt, - mlir::PatternRewriter &rewriter) const override { - if (!llvm::isa(cvt)) - return mlir::failure(); + auto remapArg = [&](Value oldVal, Value newVal) { + if (oldVal.getType() == newVal.getType()) + oldVal.replaceAllUsesWith(newVal); + else + map(oldVal, newVal); + }; + for (auto [oldResult, newResult] : + llvm::zip(whileOp.getResults(), newWhileOp.getResults())) + remapArg(oldResult, newResult); + for (auto [oldArg, newArg] : + llvm::zip(whileOp.getBeforeArguments(), newWhileOp.getBeforeArguments())) + remapArg(oldArg, newArg); + for (auto [oldArg, newArg] : + llvm::zip(whileOp.getAfterArguments(), newWhileOp.getAfterArguments())) + remapArg(oldArg, newArg); + return newWhileOp.getOperation(); +} - auto it = sharedInfo.cvtsPushedForwardMap.find(cvt); - if (it != sharedInfo.cvtsPushedForwardMap.end() && - it->second == cvt->getOperand(0).getDefiningOp()) - return mlir::failure(); +Operation *LayoutPropagation::rewriteIfOp(scf::IfOp ifOp) { + SmallVector operands; + OpBuilder rewriter(ifOp); + SmallVector newResultTypes(ifOp->getResultTypes()); + for (unsigned i = 0, e = ifOp->getNumResults(); i < e; ++i) { + auto it = layouts.find(ifOp->getResult(i)); + if (it == layouts.end()) + continue; + auto origType = ifOp->getResult(i).getType().cast(); + Attribute encoding = *(it->second.encodings.begin()); + newResultTypes[i] = RankedTensorType::get( + origType.getShape(), origType.getElementType(), encoding); + } + auto newIfOp = rewriter.create(ifOp.getLoc(), newResultTypes, + ifOp.getCondition(), true, true); + newIfOp.getThenRegion().takeBody(ifOp.getThenRegion()); + newIfOp.getElseRegion().takeBody(ifOp.getElseRegion()); + for (auto [oldResult, newResult] : + llvm::zip(ifOp.getResults(), newIfOp.getResults())) { + if (oldResult.getType() == newResult.getType()) { + oldResult.replaceAllUsesWith(newResult); + continue; + } + map(oldResult, newResult); + } + return newIfOp.getOperation(); +} - // we don't touch block arguments - Operation *op = cvt->getOperand(0).getDefiningOp(); - if (!op) - return mlir::failure(); - // we don't want to rematerialize any conversion to/from shared - if (triton::gpu::isSharedEncoding(cvt->getResults()[0]) || - triton::gpu::isSharedEncoding(cvt->getOperand(0))) - return mlir::failure(); - // we don't handle conversions to DotOperandEncodingAttr - // this is a heuristics to accommodate fused attention - auto targetType = cvt->getResultTypes()[0].cast(); - if (targetType.getEncoding().isa()) - return mlir::failure(); - // DFS - SetVector processed; - SetVector layout; - llvm::MapVector toConvert; - if (simulateBackwardRematerialization(cvt, processed, layout, toConvert, - targetType.getEncoding()) > 0) - return mlir::failure(); +void LayoutPropagation::rewriteYieldOp(scf::YieldOp yieldOp) { + Operation *parentOp = yieldOp->getParentOp(); + for (OpOperand &operand : yieldOp->getOpOperands()) { + Type yieldType = operand.get().getType(); + if (isa(parentOp)) + yieldType = parentOp->getResult(operand.getOperandNumber()).getType(); + if (auto whileOp = dyn_cast(parentOp)) + yieldType = + whileOp.getBeforeArguments()[operand.getOperandNumber()].getType(); + auto tensorType = yieldType.dyn_cast(); + if (!tensorType) + continue; + Value newOperand = getValueAs(operand.get(), tensorType.getEncoding()); + yieldOp->setOperand(operand.getOperandNumber(), newOperand); + } +} - IRMapping mapping; - rematerializeConversionChain(toConvert, rewriter, processed, mapping); - rewriter.replaceOp(cvt, mapping.lookup(cvt->getOperand(0))); +void LayoutPropagation::rewriteConditionOp(scf::ConditionOp conditionOp) { + scf::WhileOp whileOp = cast(conditionOp->getParentOp()); + for (unsigned i = 1; i < conditionOp->getNumOperands(); ++i) { + OpOperand &operand = conditionOp->getOpOperand(i); + Type argType = whileOp->getResult(operand.getOperandNumber() - 1).getType(); + auto tensorType = argType.dyn_cast(); + if (!tensorType) + continue; + Value newOperand = getValueAs(operand.get(), tensorType.getEncoding()); + conditionOp->setOperand(operand.getOperandNumber(), newOperand); + } +} - return mlir::success(); +void LayoutPropagation::rewriteReduceToScalar(Operation *reduceOp) { + OpBuilder rewriter(reduceOp); + Attribute srcEncoding; + // Since all the operands need to have the same encoding pick the first one + // and use it for all the operands. + for (Value operand : reduceOp->getOperands()) { + auto it = layouts.find(operand); + if (it != layouts.end()) { + srcEncoding = it->second.encodings[0]; + break; + } } -}; + if (!srcEncoding) + return; + for (OpOperand &operand : reduceOp->getOpOperands()) { + Value newOperand = getValueAs(operand.get(), srcEncoding); + reduceOp->setOperand(operand.getOperandNumber(), newOperand); + } +} -// ----------------------------------------------------------------------------- -// -// ----------------------------------------------------------------------------- +Operation *LayoutPropagation::rewriteOp(Operation *op) { + opToDelete.push_back(op); + if (auto forOp = dyn_cast(op)) + return rewriteForOp(forOp); + if (auto whileOp = dyn_cast(op)) + return rewriteWhileOp(whileOp); + if (auto ifOp = dyn_cast(op)) + return rewriteIfOp(ifOp); + OpBuilder rewriter(op); + Attribute encoding = *layouts[op->getResult(0)].encodings.begin(); + if (auto convertOp = dyn_cast(op)) { + Attribute srcEncoding = + convertOp.getOperand().getType().cast().getEncoding(); + auto it = layouts.find(convertOp.getOperand()); + if (it != layouts.end()) + srcEncoding = *(it->second.encodings.begin()); + Value src = getValueAs(convertOp.getOperand(), srcEncoding); + auto tensorType = op->getResult(0).getType().cast(); + auto newType = RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); + auto cvt = rewriter.create(op->getLoc(), + newType, src); + map(op->getResult(0), cvt.getResult()); + return cvt.getOperation(); + } + if (canFoldIntoConversion(op, encoding)) { + Operation *newOp = rewriter.clone(*op); + auto tensorType = op->getResult(0).getType().cast(); + auto newType = RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); + auto cvt = rewriter.create( + op->getLoc(), newType, newOp->getResult(0)); + map(op->getResult(0), cvt.getResult()); + return cvt.getOperation(); + } + if (op->hasTrait() || + op->hasTrait() || + isa( + op)) { + Operation *newOp = cloneElementwise(rewriter, op, encoding); + for (auto [oldResult, newResult] : + llvm::zip(op->getResults(), newOp->getResults())) + map(oldResult, newResult); + return newOp; + } + assert(0 && "unexpected op in rewrite"); + return nullptr; +} -class MoveConvertOutOfLoop : public mlir::RewritePattern { - PatternSharedInfo &sharedInfo; +static bool canBeRemat(Operation *op) { + if (isa(op)) + return !isExpensiveLoadOrStore(op); + if (isa(op)) + return false; + if (isa(op)) + return false; + + return true; +} -public: - explicit MoveConvertOutOfLoop(mlir::MLIRContext *context, - PatternSharedInfo &sharedInfo) - : mlir::RewritePattern(scf::ForOp::getOperationName(), 1, context), - sharedInfo(sharedInfo) {} - - SmallVector - rematerializeForLoop(mlir::PatternRewriter &rewriter, scf::ForOp &forOp, - size_t i, RankedTensorType newType, - triton::gpu::ConvertLayoutOp origConversion) const { - // Rewrite init argument - Type origType = forOp.getInitArgs()[i].getType(); - SmallVector newInitArgs = forOp.getInitArgs(); - newInitArgs[i] = rewriter.create( - newInitArgs[i].getLoc(), newType, newInitArgs[i]); - // Clone for loop - auto newForOp = rewriter.create( - forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), - forOp.getStep(), newInitArgs); - newForOp->moveBefore(forOp); - rewriter.setInsertionPointToStart(newForOp.getBody()); - IRMapping mapping; - for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) - mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); - mapping.map(origConversion.getResult(), newForOp.getRegionIterArgs()[i]); - - mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); - for (Operation &op : forOp.getBody()->without_terminator()) { - if (&op == (Operation *)(&origConversion)) - continue; - Operation *newOp = rewriter.clone(op, mapping); +// Replace ForOp with a new ForOp with extra operands. The YieldOp is not +// updated and needs to be updated separatly for the loop to be correct. +static scf::ForOp replaceForOpWithNewSignature(OpBuilder &rewriter, + scf::ForOp loop, + ValueRange newIterOperands) { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(loop); + + // Create a new loop before the existing one, with the extra operands. + rewriter.setInsertionPoint(loop); + auto operands = llvm::to_vector<4>(loop.getIterOperands()); + operands.append(newIterOperands.begin(), newIterOperands.end()); + scf::ForOp newLoop = rewriter.create( + loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(), + operands); + newLoop.getBody()->erase(); + + newLoop.getLoopBody().getBlocks().splice( + newLoop.getLoopBody().getBlocks().begin(), + loop.getLoopBody().getBlocks()); + for (Value operand : newIterOperands) + newLoop.getBody()->addArgument(operand.getType(), operand.getLoc()); + + for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front( + loop.getNumResults()))) + std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); + return newLoop; +} + +static void rewriteSlice(SetVector &slice, + DenseMap &layout, + ConvertLayoutOp convertOp, IRMapping &mapping) { + SetVector opsToRewrite; + for (Value v : slice) { + if (v.getDefiningOp()) { + opsToRewrite.insert(v.getDefiningOp()); + } else { + opsToRewrite.insert(v.cast().getOwner()->getParentOp()); + // We also need to rewrite the yield op. + opsToRewrite.insert(v.cast().getOwner()->getTerminator()); } - // create yield, inserting conversions if necessary - auto yieldOp = forOp.getBody()->getTerminator(); - SmallVector newYieldArgs; - for (Value arg : yieldOp->getOperands()) - newYieldArgs.push_back(mapping.lookup(arg)); - if (newYieldArgs[i].getType() != newType) - newYieldArgs[i] = rewriter.create( - yieldOp->getLoc(), newType, newYieldArgs[i]); - rewriter.create(forOp.getLoc(), newYieldArgs); - - // replace - SmallVector newResults = newForOp->getResults(); - newResults[i] = rewriter.create( - newForOp.getLoc(), origType, newForOp->getResult(i)); - newResults[i].getDefiningOp()->moveAfter(newForOp); - return newResults; } - - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto forOp = cast(op); - auto iterArgs = forOp.getRegionIterArgs(); - for (const auto &iterArg : llvm::enumerate(iterArgs)) { - // skip non-tensor types - if (!iterArg.value().getType().isa()) - continue; - SmallVector cvts; - if (canMoveOutOfLoop(iterArg.value(), cvts).failed()) - continue; - // check - for (auto *op : cvts) { - auto cvt = dyn_cast(op); - auto it = sharedInfo.cvtsPushedForwardMap.find(cvt); - if (it != sharedInfo.cvtsPushedForwardMap.end()) - return mlir::failure(); - auto targetType = op->getResultTypes()[0].cast(); - auto newFor = rematerializeForLoop(rewriter, forOp, iterArg.index(), - targetType, cvt); - rewriter.replaceOp(forOp, newFor); - return success(); + opsToRewrite = multiRootTopologicalSort(opsToRewrite); + + SmallVector deadLoops; + OpBuilder builder(slice.begin()->getContext()); + for (Operation *op : opsToRewrite) { + if (auto forOp = dyn_cast(op)) { + // Keep a mapping of the operands index to the new operands index. + SmallVector> argMapping; + SmallVector newOperands; + for (auto arg : forOp.getRegionIterArgs()) { + if (slice.count(arg)) { + OpOperand &initVal = forOp.getOpOperandForRegionIterArg(arg); + argMapping.push_back( + std::make_pair(*forOp.getIterArgNumberForOpOperand(initVal), + forOp.getNumIterOperands() + newOperands.size())); + newOperands.push_back(mapping.lookup(initVal.get())); + } + } + // Create a new for loop with the new operands. + scf::ForOp newForOp = + replaceForOpWithNewSignature(builder, forOp, newOperands); + deadLoops.push_back(forOp.getOperation()); + Block &loopBody = *newForOp.getBody(); + for (auto m : argMapping) { + mapping.map(newForOp.getResult(m.first), newForOp.getResult(m.second)); + int numIndVars = newForOp.getNumInductionVars(); + mapping.map(loopBody.getArgument(m.first + numIndVars), + loopBody.getArgument(m.second + numIndVars)); + } + continue; + } + builder.setInsertionPoint(op); + if (auto yieldOp = dyn_cast(op)) { + auto yieldOperands = llvm::to_vector(yieldOp.getOperands()); + for (Value operand : yieldOp.getOperands()) { + if (slice.count(operand) == 0) + continue; + yieldOperands.push_back(mapping.lookup(operand)); } + builder.create(op->getLoc(), yieldOperands); + op->erase(); + continue; + } + if (isa(op)) { + Operation *newOp = builder.clone(*op); + auto tensorType = op->getResult(0).getType().cast(); + auto newType = RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), + layout[op->getResult(0)]); + auto cvt = builder.create( + op->getLoc(), newType, newOp->getResult(0)); + mapping.map(op->getResult(0), cvt.getResult()); + continue; + } + Operation *newOp = builder.clone(*op, mapping); + for (auto [old, newV] : llvm::zip(op->getResults(), newOp->getResults())) { + auto it = layout.find(old); + if (it == layout.end()) + continue; + auto newType = RankedTensorType::get( + old.getType().cast().getShape(), + old.getType().cast().getElementType(), it->second); + newV.setType(newType); } + } + convertOp.replaceAllUsesWith(mapping.lookup(convertOp.getOperand())); + convertOp.erase(); + for (Operation *op : deadLoops) + op->erase(); +} + +static void rewriteSlice(SetVector &slice, + DenseMap &layout, + ConvertLayoutOp convertOp) { + IRMapping mapping; + rewriteSlice(slice, layout, convertOp, mapping); +} + +static LogicalResult getRematerializableSlice( + Value root, Attribute rootEncoding, SetVector &slice, + DenseMap &layout, + std::function stopPropagation = nullptr) { + LogicalResult result = getConvertBackwardSlice(root, slice, rootEncoding, + layout, stopPropagation); + if (result.failed() || slice.empty()) return failure(); + + // Check if all the operations in the slice can be rematerialized. + for (Value v : slice) { + if (Operation *op = v.getDefiningOp()) { + if (!canBeRemat(op)) + return failure(); + } } -}; + return success(); +} -// -class ConvertDotConvert : public mlir::RewritePattern { -public: - ConvertDotConvert(mlir::MLIRContext *context) - : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), - 1, context) {} +static void backwardRematerialization(ConvertLayoutOp convertOp) { + // we don't want to rematerialize any conversion to/from shared + if (triton::gpu::isSharedEncoding(convertOp.getResult()) || + triton::gpu::isSharedEncoding(convertOp.getOperand())) + return; + // we don't handle conversions to DotOperandEncodingAttr + // this is a heuristics to accommodate fused attention + auto targetType = convertOp->getResultTypes()[0].cast(); + if (targetType.getEncoding().isa()) + return; - LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto dstOp = cast(op); - auto dotOp = - dyn_cast_or_null(dstOp.getSrc().getDefiningOp()); - if (!dotOp) - return mlir::failure(); - if (std::distance(dstOp->user_begin(), dstOp->user_end()) != 1 || - std::distance(dotOp->user_begin(), dotOp->user_end()) != 1) - return mlir::failure(); - auto cvtOp = dyn_cast_or_null( - dotOp.getOperand(2).getDefiningOp()); - if (!cvtOp) - return mlir::failure(); - auto loadOp = - dyn_cast_or_null(cvtOp.getSrc().getDefiningOp()); - if (!loadOp) - return mlir::failure(); - auto dstTy = dstOp.getResult().getType().cast(); - auto srcTy = cvtOp.getOperand().getType().cast(); - if (dstTy != srcTy) - return mlir::failure(); + // 1. Take a backward slice of all the tensor dependencies that can be + // rematerialized. + SetVector slice; + DenseMap layout; + LogicalResult result = getRematerializableSlice( + convertOp.getOperand(), targetType.getEncoding(), slice, layout); + if (result.failed()) + return; - // TODO: int tensor cores - auto out_dtype = dstTy.getElementType().cast(); - APFloat value(0.0f); - if (out_dtype.isBF16()) - value = APFloat(APFloat::IEEEhalf(), APInt(16, 0)); - else if (out_dtype.isF16()) - value = APFloat(APFloat::IEEEhalf(), APInt(16, 0)); - else if (out_dtype.isF32()) - value = APFloat(0.0f); - else - llvm_unreachable("unsupported data type"); + // 2. Rewrite the slice. + rewriteSlice(slice, layout, convertOp); +} - auto _0f = - rewriter.create(op->getLoc(), value, out_dtype); - auto _0 = rewriter.create( - op->getLoc(), dotOp.getResult().getType(), _0f); - auto newDot = rewriter.create( - op->getLoc(), dotOp.getResult().getType(), dotOp.getOperand(0), - dotOp.getOperand(1), _0, dotOp.getAllowTF32()); - auto newCvt = rewriter.create( - op->getLoc(), dstTy, newDot.getResult()); - rewriter.replaceOpWithNewOp(op, newCvt, cvtOp.getOperand()); - return mlir::success(); +// For convert left we try to hoist them above type extension to reduce the cost +// of the convert. +static void hoistConvertOnTopOfExt(ConvertLayoutOp convertOp) { + // we don't want to rematerialize any conversion to/from shared + if (triton::gpu::isSharedEncoding(convertOp.getResult()) || + triton::gpu::isSharedEncoding(convertOp.getOperand())) + return; + // we don't handle conversions to DotOperandEncodingAttr + // this is a heuristics to accommodate fused attention + auto targetType = convertOp->getResultTypes()[0].cast(); + if (targetType.getEncoding().isa()) + return; + + auto isExtOp = [](Operation *op) { + return isa(op); + }; + // 1. Take a backward slice of all the tensor dependencies. + SetVector slice; + DenseMap layout; + LogicalResult result = getRematerializableSlice( + convertOp.getOperand(), targetType.getEncoding(), slice, layout, isExtOp); + if (result.failed()) + return; + + Operation *extOp = nullptr; + unsigned sliceSize = slice.size(); + for (unsigned i = 0; i < sliceSize; i++) { + Value v = slice[i]; + Operation *op = v.getDefiningOp(); + if (!op) + continue; + if (isExtOp(op)) { + SetVector tempSlice; + DenseMap tempLayout; + LogicalResult result = getRematerializableSlice( + op->getOperand(0), layout[v], tempSlice, tempLayout); + // If we can rematerialize the rest of the ext slice we can ignore this + // ext as it won't need a convert. + if (result.succeeded()) { + slice.insert(tempSlice.begin(), tempSlice.end()); + layout.insert(tempLayout.begin(), tempLayout.end()); + continue; + } + // Only apply it if there is a single ext op otherwise we would have to + // duplicate the convert. + if (extOp != nullptr) + return; + extOp = op; + } } -}; -} // namespace + if (extOp == nullptr) + return; + // Move the convert before the ext op and rewrite the slice. + OpBuilder builder(extOp); + auto tensorType = extOp->getOperand(0).getType().cast(); + auto newType = + RankedTensorType::get(tensorType.getShape(), tensorType.getElementType(), + layout[extOp->getResult(0)]); + auto newConvertOp = builder.create( + convertOp.getLoc(), newType, extOp->getOperand(0)); + IRMapping mapping; + mapping.map(extOp->getOperand(0), newConvertOp.getResult()); + // 3. Rewrite the slice. + rewriteSlice(slice, layout, convertOp, mapping); +} + +static void backwardRematerialization(ModuleOp module) { + SmallVector convertOps; + module.walk( + [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); }); + for (ConvertLayoutOp convertOp : convertOps) { + backwardRematerialization(convertOp); + } +} + +static void hoistConvert(ModuleOp module) { + SmallVector convertOps; + module.walk( + [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); }); + for (ConvertLayoutOp convertOp : convertOps) { + hoistConvertOnTopOfExt(convertOp); + } +} #define GEN_PASS_CLASSES #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" @@ -653,23 +1098,45 @@ class TritonGPURemoveLayoutConversionsPass MLIRContext *context = &getContext(); ModuleOp m = getOperation(); - mlir::RewritePatternSet patterns(context); - PatternSharedInfo sharedInfo; - - patterns.add(context); - patterns.add(context); - patterns.add(context, sharedInfo); - patterns.add(context, sharedInfo); - patterns.add(context, sharedInfo); - patterns.add(context); - patterns.add(context); - patterns.add(context); + // 1. Propagate layout forward starting from "anchor" ops. + m.walk([](triton::FuncOp funcOp) { + LayoutPropagation layoutPropagation(funcOp); + layoutPropagation.initAnchorLayout(); + layoutPropagation.propagateLayout(); + layoutPropagation.resolveConflicts(); + layoutPropagation.rewrite(); + }); + + mlir::RewritePatternSet cleanUpPatterns(context); + ConvertLayoutOp::getCanonicalizationPatterns(cleanUpPatterns, context); + if (mlir::applyPatternsAndFoldGreedily(m, std::move(cleanUpPatterns)) + .failed()) { + signalPassFailure(); + } - if (mlir::applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) { + // 2. For convert ops left try to rematerialize the slice of producer + // operation to avoid having to convert. + backwardRematerialization(m); + // 3. For converts left try to hoist them above cast generating larger size + // types in order to reduce the cost of the convert op. + hoistConvert(m); + + mlir::RewritePatternSet decomposePatterns(context); + decomposePatterns.add(context); + decomposePatterns.add(context); + if (mlir::applyPatternsAndFoldGreedily(m, std::move(decomposePatterns)) + .failed()) { signalPassFailure(); } - if (fixupLoops(m).failed()) { + // 4. Apply clean up patterns to remove remove dead convert and dead code + // generated by the previous transformations. + mlir::RewritePatternSet cleanUpPatterns2(context); + populateForOpDeadArgumentElimination(cleanUpPatterns2); + scf::ForOp::getCanonicalizationPatterns(cleanUpPatterns2, context); + ConvertLayoutOp::getCanonicalizationPatterns(cleanUpPatterns2, context); + if (mlir::applyPatternsAndFoldGreedily(m, std::move(cleanUpPatterns2)) + .failed()) { signalPassFailure(); } } diff --git a/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp b/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp index 57af32eb30bc..3998397a8e67 100644 --- a/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp @@ -17,6 +17,7 @@ #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" #define GEN_PASS_CLASSES #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" @@ -47,6 +48,12 @@ class TritonGPUReorderInstructionsPass // Sink conversions into loops when they will increase // register pressure DenseMap opToMove; + auto moveAfter = [](Operation *lhs, Operation *rhs) { + auto lhsId = getWSRoleId(lhs); + auto rhsId = getWSRoleId(rhs); + if (lhsId == rhsId) + lhs->moveAfter(rhs); + }; m.walk([&](triton::gpu::ConvertLayoutOp op) { if (!willIncreaseRegisterPressure(op)) return; @@ -81,7 +88,7 @@ class TritonGPUReorderInstructionsPass Operation *argOp = op.getOperand().getDefiningOp(); if (!argOp) return; - op->moveAfter(argOp); + moveAfter(op, argOp); }); // Move transpositions just after their definition opToMove.clear(); @@ -89,7 +96,7 @@ class TritonGPUReorderInstructionsPass Operation *argOp = op.getOperand().getDefiningOp(); if (!argOp) return; - op->moveAfter(argOp); + moveAfter(op, argOp); }); // Move `dot` operand so that conversions to opIdx=1 happens after // conversions to opIdx=0 @@ -122,7 +129,7 @@ class TritonGPUReorderInstructionsPass // after the conversion to OpIdx=0. if (!dom.dominates(op.getOperation(), AOp.getOperation())) return; - op->moveAfter(AOp); + moveAfter(op, AOp); }); return; } diff --git a/lib/Dialect/TritonGPU/Transforms/StreamPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/StreamPipeline.cpp index a99d283b692c..22369cb9f0fa 100644 --- a/lib/Dialect/TritonGPU/Transforms/StreamPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/StreamPipeline.cpp @@ -52,7 +52,7 @@ class LoopPipeliner { scf::ForOp pplForOp; /// Loads to be pipelined - SetVector validLoads; + SetVector validLoads; /// The value that each load will be mapped to (after layout conversion) DenseMap convertMapping; /// load => buffer @@ -89,6 +89,8 @@ class LoopPipeliner { /// Dependency ops by program order SmallVector orderedDeps; + SetVector currentDeps; + /// block arguments that loads depend on SetVector depArgs; @@ -98,26 +100,28 @@ class LoopPipeliner { /// operations that loads depend on SetVector depOps; - /// Collect all pipelinable ops - LogicalResult collectOps(SetVector &ops); - /// Collect values that `v` depends on and are defined inside the loop - void collectValueDep(Value v, int stage, SetVector &opDeps); + void collectValueDep(Value v, int stage, SetVector &deps, + SetVector &args); /// Collect all op dependencies void collectDeps(SetVector &ops, - MapVector> &opDeps); + MapVector> &opDeps); - /// Check if none of the ops has valid uses - LogicalResult checkOpUses(SetVector &ops); + void collectDepChain(Operation *op, SetVector &ops); + + /// Check if none of the for-ops has valid uses + LogicalResult checkOpUses(); /// Check if ops have dependencies that are not pipelinable - void checkOpDeps(SetVector &ops); + LogicalResult checkOpDeps(); void createBufferTypes(); void createOrderedDeps(); + void createCurrentDeps(); + /// Return the stage at which `v` is defined prior to `stage` int getValueDefStage(Value v, int stage); @@ -149,6 +153,8 @@ class LoopPipeliner { void cloneCurrentBody(OpBuilder &builder); void storeNextBuffer(OpBuilder &builder); + bool isLoadChain(Operation *op) const; + /// Assemble `pplForOp`'s yield op void finalizeYield(OpBuilder &builder); @@ -173,158 +179,142 @@ class LoopPipeliner { friend struct PipelinePass; }; -/// Collect loads to pipeline. Return success if we can pipeline this loop -LogicalResult LoopPipeliner::collectOps(SetVector &ops) { - ModuleOp moduleOp = forOp->getParentOfType(); - ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp); - - // We cannot use forOp.walk(...) here because we only want to visit the - // operations in the loop body block. Nested blocks are handled separately. - for (Operation &op : forOp) - if (auto loadOp = dyn_cast(&op)) { - // pipeline all loads - ops.insert(loadOp); - } - - if (ops.empty()) - return failure(); - else - return success(); -} - void LoopPipeliner::collectValueDep(Value v, int stage, - SetVector &deps) { - // Loop-invariant value, skip - if (v.getParentRegion() != &forOp.getLoopBody()) - return; - - if (deps.contains(v)) - return; - + SetVector &deps, + SetVector &args) { // Since we only need to peel the loop numStages-1 times, don't worry // about depends that are too far away if (stage < 0) return; - if (auto arg = v.dyn_cast()) { + // Loop-invariant value, skip + if (v.getParentRegion() != &forOp.getLoopBody()) + return; + + if (Operation *op = v.getDefiningOp()) { + if (!deps.contains(op)) { + deps.insert(op); + for (Value opr : op->getOperands()) + collectValueDep(opr, stage, deps, args); + } + } else if (auto arg = v.dyn_cast()) { if (arg.getArgNumber() > 0) { - deps.insert(v); + args.insert(arg); collectValueDep(yieldOp->getOperand(arg.getArgNumber() - 1), stage - 1, - deps); + deps, args); } - } else { // value - deps.insert(v); - for (Value op : v.getDefiningOp()->getOperands()) - collectValueDep(op, stage, deps); } } void LoopPipeliner::collectDeps( SetVector &ops, - MapVector> &valueDeps) { + MapVector> &valueDeps) { for (auto op : ops) { for (Value v : op->getOperands()) { - SetVector deps; - collectValueDep(v, numStages - 1, deps); + SetVector deps; + SetVector args; + collectValueDep(v, numStages - 1, deps, args); valueDeps[op] = deps; } } } -LogicalResult LoopPipeliner::checkOpUses(SetVector &ops) { - DenseSet invalidOps; +LogicalResult LoopPipeliner::checkOpUses() { + SetVector ops; + // We cannot use forOp.walk(...) here because we only want to visit the + // operations in the loop body block. Nested blocks are handled separately. + for (Operation &op : forOp) { + if (auto loadOp = dyn_cast(&op)) + ops.insert(&op); + } + // Collect all ops' dependencies - MapVector> opDeps; + MapVector> opDeps; collectDeps(ops, opDeps); for (Operation *op : ops) { - if (auto loadOp = dyn_cast(op)) { - // Don't pipeline valid loads that depend on other valid loads - // (Because if a valid load depends on another valid load, this load needs - // to wait on the other load in the prologue, which is against the point - // of the pipeline pass) - bool isCandidate = true; - for (Operation *other : ops) - if (isa(other)) - if (opDeps[op].contains(other->getResult(0))) { - isCandidate = false; - break; - } - // We only pipeline loads that have one covert_layout (to dot_op) use - // TODO: lift this constraint in the future - if (isCandidate && loadOp.getResult().hasOneUse()) { - isCandidate = false; - Operation *use = *loadOp.getResult().getUsers().begin(); - - // Advance to the first conversion as long as the use resides in shared - // memory and it has a single use itself - while (use) { - if (use->getNumResults() != 1 || !use->getResult(0).hasOneUse()) - break; - auto tensorType = - use->getResult(0).getType().dyn_cast(); - if (!tensorType.getEncoding().isa()) - break; - use = *use->getResult(0).getUsers().begin(); + auto loadOp = dyn_cast(op); + // Don't pipeline valid loads that depend on other valid loads + // (Because if a valid load depends on another valid load, this load needs + // to wait on the other load in the prologue, which is against the point + // of the pipeline pass) + bool isCandidate = true; + for (Operation *other : ops) + if (isa(other)) + if (opDeps[op].contains(other)) { + isCandidate = false; + break; } + // We only pipeline loads that have one covert_layout (to dot_op) use + // TODO: lift this constraint in the future + if (isCandidate && loadOp.getResult().hasOneUse()) { + isCandidate = false; + Operation *use = *loadOp.getResult().getUsers().begin(); + + // Advance to the first conversion as long as the use resides in shared + // memory and it has a single use itself + while (use) { + if (use->getNumResults() != 1 || !use->getResult(0).hasOneUse()) + break; + auto tensorType = + use->getResult(0).getType().dyn_cast(); + if (!tensorType || !tensorType.getEncoding().isa()) + break; + use = *use->getResult(0).getUsers().begin(); + } - if (auto convertLayout = llvm::dyn_cast(use)) - if (auto tensorType = convertLayout.getResult() - .getType() - .dyn_cast()) - if (auto dotOpEnc = tensorType.getEncoding() - .dyn_cast()) { - isCandidate = true; - convertMapping[loadOp] = convertLayout; - } - } else - isCandidate = false; - - if (!isCandidate) - invalidOps.insert(loadOp); - else - validLoads.insert(loadOp); - } - } + // TODO: handle fp_to_fp conversions in between + if (auto convertLayout = llvm::dyn_cast(use)) + if (auto tensorType = convertLayout.getResult() + .getType() + .dyn_cast()) + if (auto dotOpEnc = tensorType.getEncoding() + .dyn_cast()) { + isCandidate = true; + convertMapping[loadOp] = convertLayout; + } + } else + isCandidate = false; - for (Operation *op : invalidOps) - ops.remove(op); + if (isCandidate) + validLoads.insert(op); + } - if (ops.empty()) - return failure(); - else - return success(); + return validLoads.empty() ? failure() : success(); } -void LoopPipeliner::checkOpDeps(SetVector &ops) { +LogicalResult LoopPipeliner::checkOpDeps() { /// arg => source operand defined stages DenseMap> immediateArgStages; SetVector nonImmediateDepArgs; SetVector nonImmediateOps; - for (Operation *op : ops) { + for (Operation *op : validLoads) { for (Value v : op->getOperands()) { - SetVector deps; - collectValueDep(v, numStages - 1, deps); + SetVector deps; + SetVector args; + collectValueDep(v, numStages - 1, deps, args); int defStage = getValueDefStage(v, numStages - 1); - assert(defStage >= 0 && - "newLoopArgs has null args without a define op. Consider either " - "rewrite the loop to reduce cross iteration dependencies or " - "increase the num_stages value."); - for (auto dep : deps) { - auto immediate = deps.front().isa(); - if (auto arg = dyn_cast(dep)) { - depArgs.insert(arg); - if (immediate) - immediateArgStages[arg].insert(defStage); - else - nonImmediateDepArgs.insert(arg); - } else { - depOps.insert(dep.getDefiningOp()); - if (immediate) - immediateOpStages[dep.getDefiningOp()].insert(defStage); - else - nonImmediateOps.insert(dep.getDefiningOp()); - } + if (defStage < 0) { + // assert(defStage >= 0 && + // "newLoopArgs has null args without a define op. Consider either " + // "rewrite the loop to reduce cross iteration dependencies or " + // "increase the num_stages value."); + return failure(); + } + bool immediate = args.size() > 0; + for (auto *dep : deps) { + depOps.insert(dep); + if (immediate) + immediateOpStages[dep].insert(defStage); + else + nonImmediateOps.insert(dep); + } + for (auto arg : args) { + depArgs.insert(arg); + if (immediate) + immediateArgStages[arg].insert(defStage); + else + nonImmediateDepArgs.insert(arg); } } } @@ -356,6 +346,7 @@ void LoopPipeliner::checkOpDeps(SetVector &ops) { "removing pre/post load instructions dependency on this " "operation."); } + return success(); } // helpers @@ -410,27 +401,46 @@ void LoopPipeliner::createBufferTypes() { ty.getShape().end()); Type eType = ty.getElementType(); auto blockedEnc = ty.getEncoding().cast(); + auto CTALayout = ttg::getCTALayout(ty.getEncoding()); // unsigned bitWidth = dotOpEnc.getMMAv2kWidth() // ? 32 / dotOpEnc.getMMAv2kWidth() // : ty.getElementType().getIntOrFloatBitWidth(); auto sharedEnc = ttg::SharedEncodingAttr::get(ty.getContext(), dotOpEnc, ty.getShape(), - ttg::getOrder(ty.getEncoding()), eType); + ttg::getOrder(ty.getEncoding()), CTALayout, eType); loadsBufferType[loadOp] = RankedTensorType::get(bufferShape, eType, sharedEnc); } } void LoopPipeliner::createOrderedDeps() { - for (Operation &op : forOp.getLoopBody().front()) { // @@@ front necessary? + for (Operation &op : forOp.getBody()->without_terminator()) { if (depOps.contains(&op)) orderedDeps.push_back(&op); - else if (op.getNumResults() > 0 && validLoads.contains(op.getResult(0))) + else if (op.getNumResults() > 0 && validLoads.contains(&op)) orderedDeps.push_back(&op); } assert(depOps.size() + validLoads.size() == orderedDeps.size() && "depOps contains invalid values"); } +void LoopPipeliner::collectDepChain(Operation *op, SetVector &ops) { + if (op->getNumResults() == 1 && validLoads.contains(op)) + return; + if (!ops.contains(op)) { + ops.insert(op); + for (Value opr : op->getOperands()) + if (Operation *oprOp = opr.getDefiningOp()) + collectDepChain(oprOp, ops); + } +} + +void LoopPipeliner::createCurrentDeps() { + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!llvm::is_contained(orderedDeps, &op)) + collectDepChain(&op, currentDeps); + } +} + int LoopPipeliner::getValueDefStage(Value v, int stage) { if (stage < 0) return -1; @@ -444,21 +454,18 @@ int LoopPipeliner::getValueDefStage(Value v, int stage) { } LogicalResult LoopPipeliner::initialize() { - // All ops that maybe pipelined - SetVector ops; - - if (collectOps(ops).failed()) + if (checkOpUses().failed()) return failure(); - if (checkOpUses(ops).failed()) + if (checkOpDeps().failed()) return failure(); - checkOpDeps(ops); - createBufferTypes(); createOrderedDeps(); + createCurrentDeps(); + return success(); } @@ -490,6 +497,20 @@ Value LoopPipeliner::getLoadMask(triton::LoadOp loadOp, Value mappedMask, return mappedMask; } +bool LoopPipeliner::isLoadChain(Operation *op) const { + if (auto cvtOp = dyn_cast(op)) { + Value loadVal = cvtOp.getSrc(); + if (auto f2fOp = dyn_cast(op)) + loadVal = f2fOp.getFrom(); + if (validLoads.contains(loadVal.getDefiningOp())) { + auto cvtDstTy = cvtOp.getResult().getType().cast(); + if (cvtDstTy.getEncoding().isa()) + return true; + } + } + return false; +} + void LoopPipeliner::emitPrologue() { /// forOp block args => forOp operands /// forOp iterator => lower bound @@ -508,7 +529,7 @@ void LoopPipeliner::emitPrologue() { // Emit Iteration 0 loads, etc for (Operation *op : orderedDeps) { Operation *newOp = nullptr; - if (validLoads.contains(op->getResult(0))) { + if (validLoads.contains(op)) { auto loadOp = cast(op); // Load from global -> regs auto newLoadOp = cloneWithInferType(builder, op, prologueMap); @@ -532,16 +553,16 @@ void LoopPipeliner::emitPrologue() { void LoopPipeliner::emitEpilogue(DenseMap &newResults) { if (!peelLastIter) return; - OpBuilder builder(forOp); - builder.setInsertionPointAfter(forOp); + OpBuilder builder(pplForOp); + builder.setInsertionPointAfter(pplForOp); IRMapping epilogueMap; // Map 'for' iteration args to pipelined-for results auto args = forOp.getRegionIterArgs(); for (uint32_t i = 0; i < args.size(); ++i) epilogueMap.map(args[i], pplForOp.getResult(i)); - for (uint32_t i = 0; i < validLoads.size(); ++i) - epilogueMap.map(validLoads[i], pplForOp.getResult(bufferIdx + i)); + for (auto load : llvm::enumerate(validLoads)) + epilogueMap.map(load.value()->getResult(0), pplForOp.getResult(bufferIdx + load.index())); // Map IV to original upper bound (ie. last iteration) epilogueMap.map(forOp.getInductionVar(), forOp.getUpperBound()); @@ -549,16 +570,11 @@ void LoopPipeliner::emitEpilogue(DenseMap &newResults) { // Clone the loop body after the new ForOp // , replace original args with results of the new ForOp. for (Operation &op : forOp.getBody()->without_terminator()) { - if (!llvm::is_contained(orderedDeps, &op)) { + if (currentDeps.contains(&op)) { Operation *newOp = nullptr; - auto cvtOp = dyn_cast(op); - if (cvtOp && validLoads.contains(cvtOp.getSrc())) { - auto cvtDstTy = cvtOp.getResult().getType().cast(); - if (cvtDstTy.getEncoding().isa()) { - newOp = builder.clone(op, epilogueMap); - } - } - if (newOp == nullptr) + if (isLoadChain(&op)) + newOp = builder.clone(op, epilogueMap); + else newOp = cloneWithInferType(builder, &op, epilogueMap); // substitute for these results for the results of the new for loop for (const auto &pair : llvm::zip(op.getResults(), newOp->getResults())) { @@ -588,8 +604,8 @@ SmallVector LoopPipeliner::collectNewLoopArgs() { // Shared mem locations from iteration 0 bufferIdx = newLoopArgs.size(); - for (auto loadOp : validLoads) - newLoopArgs.push_back(loadsBuffer[loadOp]); + for (auto *loadOp : validLoads) + newLoopArgs.push_back(loadsBuffer[loadOp->getResult(0)]); // Loop carried vals depArgsBeginIdx = newLoopArgs.size(); @@ -620,8 +636,8 @@ scf::ForOp LoopPipeliner::cloneForOp(ArrayRef newLoopArgs, for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) curMapping.map(arg.value(), pplForOp.getRegionIterArgs()[arg.index()]); uint32_t bufIdx = bufferIdx; - for (auto loadOp : validLoads) - curMapping.map(loadOp, pplForOp.getRegionIterArgs()[bufIdx++]); + for (auto *loadOp : validLoads) + curMapping.map(loadOp->getResult(0), pplForOp.getRegionIterArgs()[bufIdx++]); curMapping.map(forOp.getInductionVar(), pplForOp.getInductionVar()); nextMapping = curMapping; @@ -661,7 +677,7 @@ void LoopPipeliner::prefetchNextBuffer(OpBuilder &builder) { // Emit prefetch loads of next buffer before compute of current buffer for (Operation *op : orderedDeps) { Operation *nextOp = nullptr; - if (validLoads.contains(op->getResult(0))) { + if (validLoads.contains(op)) { // Update loading mask auto loadOp = llvm::cast(op); auto mask = loadOp.getMask(); @@ -700,21 +716,12 @@ void LoopPipeliner::cloneCurrentBody(OpBuilder &builder) { auto loc = forOp.getLoc(); // only add instructions that are not part of the restructuring for (Operation &op : forOp.getBody()->without_terminator()) { - if (!llvm::is_contained(orderedDeps, &op)) { + if (currentDeps.contains(&op)) { Operation *newOp = nullptr; - auto cvtOp = dyn_cast(op); - if (cvtOp && validLoads.contains(cvtOp.getSrc())) { - auto cvtDstTy = cvtOp.getResult().getType().cast(); - if (cvtDstTy.getEncoding().isa()) - newOp = builder.clone(op, curMapping); - } - if (newOp == nullptr) + if (isLoadChain(&op)) + newOp = builder.clone(op, curMapping); + else newOp = cloneWithInferType(builder, &op, curMapping); - } else { - // hack for yield operands - if (auto ttadd = dyn_cast(op)) { - curMapping.map(ttadd.getResult(), curMapping.lookup(ttadd.getPtr())); - } } } } @@ -722,7 +729,7 @@ void LoopPipeliner::cloneCurrentBody(OpBuilder &builder) { void LoopPipeliner::storeNextBuffer(OpBuilder &builder) { // Store the next buffer at the end of the loop body for the next iteration for (Operation *op : orderedDeps) { - if (!validLoads.contains(op->getResult(0))) { + if (!validLoads.contains(op)) { if (immediateOpStages[op].contains(numStages - 2)) { Operation *nextOp = builder.clone(*op, nextMapping); if (auto loadOp = dyn_cast(op)) { @@ -740,12 +747,12 @@ void LoopPipeliner::storeNextBuffer(OpBuilder &builder) { } // PL loads -> store next to shared - for (auto loadOp : validLoads) { - Value loadVal = nextMapping.lookup(loadOp); + for (auto *loadOp : validLoads) { + Value loadVal = nextMapping.lookup(loadOp->getResult(0)); // then store regs -> shared Value storeBuf = pplForOp.getRegionIterArgs()[bufferIdx + nextBuffers.size()]; auto cvt = builder.create( - loadOp.getLoc(), storeBuf.getType(), loadVal); + loadOp->getLoc(), storeBuf.getType(), loadVal); nextBuffers.push_back(cvt); } @@ -758,8 +765,12 @@ void LoopPipeliner::storeNextBuffer(OpBuilder &builder) { void LoopPipeliner::finalizeYield(OpBuilder &builder) { SmallVector yieldValues; - for (Value v : yieldOp->getOperands()) - yieldValues.push_back(curMapping.lookup(v)); + for (const auto &opr : llvm::enumerate(yieldOp->getOperands())) { + if (curMapping.contains(opr.value())) + yieldValues.push_back(curMapping.lookup(opr.value())); + else + yieldValues.push_back(pplForOp.getRegionIterArgs()[opr.index()]); + } for (Value nextBuffer : nextBuffers) yieldValues.push_back(nextBuffer); diff --git a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp index 6af4dbdf96e1..c345f1b87354 100644 --- a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp +++ b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp @@ -12,9 +12,13 @@ using namespace mlir::triton::gpu; // TypeConverter // TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, - int numWarps, int threadsPerWarp) - : context(context), numWarps(numWarps), threadsPerWarp(threadsPerWarp) { + int numWarps, int threadsPerWarp, + int numCTAs) + : context(context), numWarps(numWarps), threadsPerWarp(threadsPerWarp), + numCTAs(numCTAs) { addConversion([](Type type) { return type; }); + + // Add encoding for tensor addConversion([this](RankedTensorType tensorType) -> RankedTensorType { // types with encoding are already in the right format // TODO: check for layout encodings more specifically @@ -30,10 +34,24 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, llvm::SmallVector sizePerThread(rank, 1); Attribute encoding = triton::gpu::BlockedEncodingAttr::get( this->context, shape, sizePerThread, order, this->numWarps, - this->threadsPerWarp); + this->threadsPerWarp, this->numCTAs); return RankedTensorType::get(shape, tensorType.getElementType(), encoding); }); + // Add encoding for tensor pointer + addConversion([this](triton::PointerType ptrType) -> triton::PointerType { + // Check whether tensor pointer `tt.ptr>` + auto pointeeTensorType = + ptrType.getPointeeType().dyn_cast(); + if (pointeeTensorType == nullptr) + return ptrType; + + // Add layout into the tensor + auto convertedTensorType = convertType(pointeeTensorType); + return triton::PointerType::get(convertedTensorType, + ptrType.getAddressSpace()); + }); + // // Materializations // diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 95130a3f13fd..6e92fb2901ae 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -5,95 +5,303 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include namespace mlir { -namespace { +SmallVector mmaVersionToInstrShape(int version, + const ArrayRef &shape, + RankedTensorType type) { + if (version == 1) + return {16, 16}; + else if (version == 2) + return {16, 8}; + else if (version == 3) { + unsigned k = 256 / type.getElementTypeBitWidth(); + if (shape[0] % 64 != 0 || shape[1] % 8 != 0) { + assert(false && "type not supported"); + return {0, 0, 0}; + } + auto eltType = type.getElementType(); + SmallVector validN; -class FixupLoop : public mlir::RewritePattern { - -public: - explicit FixupLoop(mlir::MLIRContext *context) - : mlir::RewritePattern(scf::ForOp::getOperationName(), 2, context) {} - - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto forOp = cast(op); - - // Rewrite init argument - SmallVector newInitArgs = forOp.getInitArgs(); - bool shouldRematerialize = false; - for (size_t i = 0; i < newInitArgs.size(); i++) { - if (newInitArgs[i].getType() != forOp.getRegionIterArgs()[i].getType() || - newInitArgs[i].getType() != forOp.getResultTypes()[i]) { - shouldRematerialize = true; - break; + // MMAv3 with larger instruction shape is preferred. + if (eltType.isFloat8E5M2() || eltType.isFloat8E4M3FNUZ() || + eltType.isF16() || eltType.isBF16() || eltType.isF32()) { + validN.assign({256, 248, 240, 232, 224, 216, 208, 200, 192, 184, 176, + 168, 160, 152, 144, 136, 128, 120, 112, 104, 96, 88, + 80, 72, 64, 56, 48, 40, 32, 24, 16, 8}); + } + + if (eltType.isInteger(8)) { + validN.assign({224, 208, 192, 176, 160, 144, 128, 112, 96, 80, 64, 48, 32, + 24, 16, 8}); + } + + for (auto n : validN) { + if (shape[1] % n == 0) { + return {16, n, k}; } } - if (!shouldRematerialize) - return failure(); - scf::ForOp newForOp = rewriter.create( - forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), - forOp.getStep(), newInitArgs); - newForOp->moveBefore(forOp); - rewriter.setInsertionPointToStart(newForOp.getBody()); - IRMapping mapping; - for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) - mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); - mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); - - for (Operation &op : forOp.getBody()->getOperations()) { - rewriter.clone(op, mapping); + assert(false && "type not supported"); + return {0, 0, 0}; + } else { + assert(false && "version not supported"); + return {0, 0}; + } +} + +bool isLoadFromTensorPtr(triton::LoadOp op) { + return mlir::triton::isTensorPointerType(op.getPtr().getType()); +} + +bool isStoreToTensorPtr(triton::StoreOp op) { + return mlir::triton::isTensorPointerType(op.getPtr().getType()); +} + +Operation *getFirstUser(Value v) { + DenseMap operationId; + v.getParentBlock()->walk( + [&](Operation *op) { operationId[op] = operationId.size(); }); + size_t minId = std::numeric_limits::max(); + Operation *firstUser = nullptr; + for (Operation *user : v.getUsers()) { + assert(operationId.find(user) != operationId.end()); + size_t userId = operationId[user]; + if (userId < minId) { + minId = userId; + firstUser = user; } - rewriter.replaceOp(forOp, newForOp.getResults()); - return success(); } -}; + assert(firstUser); + return firstUser; +} -} // namespace +triton::gpu::SharedEncodingAttr getSharedEncoding(RankedTensorType tensorTy) { + auto blockedLayout = + tensorTy.getEncoding().cast(); + return triton::gpu::SharedEncodingAttr::get( + tensorTy.getContext(), tensorTy.getShape(), blockedLayout.getOrder(), + blockedLayout.getCTALayout(), tensorTy.getElementType()); +} -LogicalResult fixupLoops(ModuleOp mod) { - auto *ctx = mod.getContext(); - mlir::RewritePatternSet patterns(ctx); - patterns.add(ctx); - if (applyPatternsAndFoldGreedily(mod, std::move(patterns)).failed()) - return failure(); - return success(); +//===----------------------------------------------------------------------===// +// GraphDumper +//===----------------------------------------------------------------------===// + +GraphDumper::NodeInfo GraphDumper::onValue(Value value) const { + return {{"shape", "box"}, {"style", "filled"}, {"fillcolor", "white"}}; } -// -------------------------------------------------------------------------- // +GraphDumper::NodeInfo GraphDumper::onOperation(Operation *op) const { + return {{"shape", "ellipse"}, {"style", "filled"}, {"fillcolor", "white"}}; +} + +std::string GraphDumper::dump(triton::FuncOp func) const { + llvm::SetVector values; + llvm::SetVector operations; + + func.walk([&](Operation *op) { + operations.insert(op); + for (Value operand : op->getOperands()) + values.insert(operand); + for (Value result : op->getResults()) + values.insert(result); + }); + + std::ostringstream oss; + oss << "// Generated by Triton GraphDumper\n" + << "\n" + << "digraph {\n"; -// TODO: Interface -LogicalResult invertEncoding(Attribute targetEncoding, Operation *op, - Attribute &ret) { - ret = targetEncoding; - if (auto expand_dims = dyn_cast(op)) { - ret = triton::gpu::SliceEncodingAttr::get( - op->getContext(), expand_dims.getAxis(), targetEncoding); + oss << " // Value Nodes\n"; + for (Value value : values) + oss << " " << emitValueNode(value) << "\n"; + oss << "\n"; + + oss << " // Operation Nodes\n"; + for (Operation *op : operations) + oss << " " << emitOperationNode(op) << "\n"; + oss << "\n"; + + oss << " // Edges\n"; + for (Operation *op : operations) { + for (Value operand : op->getOperands()) + oss << " " << emitEdge(getUniqueId(operand), getUniqueId(op)) << "\n"; + for (Value result : op->getResults()) + oss << " " << emitEdge(getUniqueId(op), getUniqueId(result)) << "\n"; } - if (auto reduce = dyn_cast(op)) { - auto sliceEncoding = - targetEncoding.dyn_cast(); - if (!sliceEncoding) - return failure(); - if (sliceEncoding.getDim() != reduce.getAxis()) - return failure(); - ret = sliceEncoding.getParent(); + + oss << "}\n"; + return oss.str(); +} + +void GraphDumper::dumpToFile(triton::FuncOp func, + const std::string &filename) const { + std::ofstream ofs(filename); + ofs << dump(func); +} + +std::string GraphDumper::getShapeStr(const Type &type) const { + std::ostringstream oss; + oss << "["; + if (auto tensorTy = type.dyn_cast()) { + auto shape = tensorTy.getShape(); + for (unsigned i = 0; i < shape.size(); ++i) { + if (i > 0) + oss << ", "; + oss << shape[i]; + } } - if (isa(op)) { - return failure(); + oss << "]"; + return oss.str(); +} + +std::string GraphDumper::getUniqueId(Value value) const { + std::ostringstream oss; + oss << value.getImpl(); + return oss.str(); +} + +std::string GraphDumper::getUniqueId(Operation *op) const { + std::ostringstream oss; + oss << op; + return oss.str(); +} + +std::string GraphDumper::emitNode(const std::string &id, + const GraphDumper::NodeInfo info) const { + std::ostringstream oss; + oss << "\"" << id << "\" ["; + for (auto it = info.begin(); it != info.end(); ++it) { + if (it != info.begin()) + oss << ", "; + oss << it->first << " = \"" << it->second << "\""; } - return success(); + oss << "];"; + return oss.str(); +} + +std::string GraphDumper::emitEdge(const std::string &srcId, + const std::string &destId) const { + std::ostringstream oss; + oss << "\"" << srcId << "\" -> \"" << destId << "\";"; + return oss.str(); +} + +std::string GraphDumper::emitValueNode(Value value) const { + NodeInfo info = onValue(value); + if (info.find("label") == info.end()) { + std::string shapeStr = getShapeStr(value.getType()); + if (auto arg = value.dyn_cast()) + info["label"] = + "BlockArg" + std::to_string(arg.getArgNumber()) + " " + shapeStr; + else + info["label"] = shapeStr; + } + return emitNode(getUniqueId(value), info); +} + +std::string GraphDumper::emitOperationNode(Operation *op) const { + NodeInfo info = onOperation(op); + if (info.find("label") == info.end()) + info["label"] = op->getName().getStringRef().str(); + return emitNode(getUniqueId(op), info); +} + +//===----------------------------------------------------------------------===// +// GraphLayoutMarker +//===----------------------------------------------------------------------===// + +GraphDumper::NodeInfo GraphLayoutMarker::onValue(Value value) const { + std::string color = getColor(value.getType()); + return {{"shape", "box"}, {"style", "filled"}, {"fillcolor", color}}; +} + +std::string GraphLayoutMarker::getColor(const Type &type) const { + if (auto tensorTy = type.dyn_cast()) { + auto layout = tensorTy.getEncoding(); + if (layout.isa()) + return "green"; + else if (layout.isa()) + return "yellow"; + else if (layout.isa()) + return "lightslateblue"; + else if (layout.isa()) + return "orange"; + else if (layout.isa()) + return "orangered"; + else + assert(0 && "Unrecognized layout"); + } else { + return "white"; + } +} +// -------------------------------------------------------------------------- // + +static std::optional inferDstEncoding(triton::ReduceOp op, + Attribute encoding) { + return triton::gpu::SliceEncodingAttr::get(op->getContext(), op.getAxis(), + encoding); +} + +static std::optional inferDstEncoding(triton::ExpandDimsOp op, + Attribute encoding) { + auto sliceEncoding = encoding.dyn_cast(); + if (!sliceEncoding) + return std::nullopt; + if (op.getAxis() != sliceEncoding.getDim()) + return std::nullopt; + return sliceEncoding.getParent(); +} + +static std::optional inferSrcEncoding(triton::ReduceOp op, + Attribute encoding) { + auto sliceEncoding = encoding.dyn_cast(); + if (!sliceEncoding) + return std::nullopt; + if (op.getAxis() != sliceEncoding.getDim()) + return std::nullopt; + return sliceEncoding.getParent(); +} + +static std::optional inferSrcEncoding(triton::ExpandDimsOp op, + Attribute encoding) { + return triton::gpu::SliceEncodingAttr::get(op->getContext(), op.getAxis(), + encoding); +} + +std::optional inferSrcEncoding(Operation *op, Attribute encoding) { + if (auto reduceOp = dyn_cast(op)) + return inferSrcEncoding(reduceOp, encoding); + if (auto expand = dyn_cast(op)) + return inferSrcEncoding(expand, encoding); + if (isa(op)) + return std::nullopt; + return encoding; +} + +std::optional inferDstEncoding(Operation *op, Attribute encoding) { + if (auto reduceOp = dyn_cast(op)) + return inferDstEncoding(reduceOp, encoding); + if (auto expand = dyn_cast(op)) + return inferDstEncoding(expand, encoding); + if (isa(op)) + return std::nullopt; + return encoding; } -bool isExpensiveLoadOrStore(Operation *op, Attribute &targetEncoding) { - // Case 1: A size 1 tensor is not expensive since all threads will load the +bool isExpensiveLoadOrStore(Operation *op) { + // Case 1: Pointer of tensor is always expensive + auto operandType = op->getOperand(0).getType(); + if (triton::isTensorPointerType(operandType)) + return true; + // Case 2a: A size 1 tensor is not expensive since all threads will load the // same if (isSingleValue(op->getOperand(0))) return false; - // Case 2: Tensor of pointers has more threads than elements + // Case 2b: Tensor of pointers has more threads than elements // we can presume a high hit-rate that makes it cheap to load auto ptrType = op->getOperand(0).getType().cast(); auto mod = op->getParentOfType(); @@ -108,7 +316,7 @@ bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding) { if (!op) return true; if (isa(op)) - return isExpensiveLoadOrStore(op, targetEncoding); + return isExpensiveLoadOrStore(op); if (isa(op)) return triton::gpu::isExpensiveCat(cast(op), targetEncoding); if (isa(op)) return !triton::gpu::isExpensiveCat(cast(op), targetEncoding); - return isa(op); -} - -int simulateBackwardRematerialization( - Operation *initOp, SetVector &processed, - SetVector &layout, llvm::MapVector &toConvert, - Attribute targetEncoding) { - // DFS - std::vector> queue; - queue.emplace_back(initOp, targetEncoding); - // We want to see the effect of converting `initOp` to a new layout - // so we initialize `numCvts = 1`. - int numCvts = 1; - while (!queue.empty()) { - Operation *currOp; - Attribute currLayout; - std::tie(currOp, currLayout) = queue.back(); - queue.pop_back(); - // If the current operation is expensive to rematerialize, - // we stop everything - if (isExpensiveToRemat(currOp, currLayout)) - break; - // A conversion will be removed here (i.e. transferred to operands) - numCvts -= 1; - // Done processing - processed.insert(currOp); - layout.insert(currLayout); - // Add all operands to the queue - for (Value argI : currOp->getOperands()) { - Attribute newEncoding; - // Cannot invert the current encoding for this operand - // we stop everything - if (failed(invertEncoding(currLayout, currOp, newEncoding))) - return INT_MAX; - if (toConvert.count(argI) && toConvert[argI] != newEncoding) - return INT_MAX; - Operation *opArgI = argI.getDefiningOp(); - toConvert.insert({argI, newEncoding}); - // 1. Only convert RankedTensorType - // 2. Skip if there's no defining op - // 3. Skip if the defining op has already been processed - // 4. Skip or the defining op is in a different block - if (!argI.getType().isa() || !opArgI || - processed.contains(opArgI) || - opArgI->getBlock() != currOp->getBlock()) - continue; - // If the conversion can be folded into opArgI then - // we don't count this conversion as expensive - if (canFoldConversion(opArgI, newEncoding)) - continue; - - // We add one expensive conversion for the current operand - numCvts += 1; - queue.emplace_back(opArgI, newEncoding); + if (auto convert = dyn_cast(op)) { + if (targetEncoding.isa()) { + auto srcEncoding = + convert.getOperand().getType().cast().getEncoding(); + if (targetEncoding != srcEncoding) + return false; } + return true; } - // return net number of conversions - return numCvts; + return isa(op); } // @@ -224,128 +384,249 @@ Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op, return newOp; } -void rematerializeConversionChain( - const llvm::MapVector &toConvert, - mlir::PatternRewriter &rewriter, SetVector &processed, - IRMapping &mapping) { - SmallVector sortedValues; - SetVector tmp; - for (auto &item : toConvert) { - Value v = item.first; - if (v.getDefiningOp()) - tmp.insert(v.getDefiningOp()); - else - sortedValues.push_back(v); +LogicalResult +getConvertBackwardSlice(Value root, SetVector &slice, + Attribute rootEncoding, + DenseMap &layout, + std::function stopPropagation) { + SmallVector> queue = {{root, rootEncoding}}; + while (!queue.empty()) { + auto [currentValue, encoding] = queue.back(); + queue.pop_back(); + if (!currentValue.getType().isa()) + continue; + // Skip propagating through for op results for now. + // TODO: enable this based on needs. + if (currentValue.getDefiningOp()) + return failure(); + slice.insert(currentValue); + layout[currentValue] = encoding; + if (auto *definingOp = currentValue.getDefiningOp()) { + if (canFoldIntoConversion(definingOp, encoding)) + continue; + if (stopPropagation && stopPropagation(definingOp)) + continue; + if (isa(definingOp)) + return failure(); + for (Value operand : definingOp->getOperands()) { + auto srcEncoding = inferSrcEncoding(definingOp, encoding); + if (!srcEncoding) + return failure(); + if (slice.count(operand) == 0) + queue.push_back({operand, *srcEncoding}); + } + continue; + } + auto blockArg = cast(currentValue); + Block *block = blockArg.getOwner(); + Operation *parentOp = block->getParentOp(); + if (auto forOp = dyn_cast(parentOp)) { + OpOperand &initOperand = forOp.getOpOperandForRegionIterArg(blockArg); + Value yieldOperand = forOp.getBody()->getTerminator()->getOperand( + blockArg.getArgNumber() - forOp.getNumInductionVars()); + queue.push_back({initOperand.get(), encoding}); + queue.push_back({yieldOperand, encoding}); + continue; + } + // TODO: add support for WhileOp and other region types. + return failure(); + } + return success(); +} + +// TODO(thomas): this is duplicated with what is in GPUToLLVM +// Convert an \param index to a multi-dim coordinate given \param shape and +// \param order. +SmallVector delinearize(OpBuilder &b, Location loc, Value linear, + ArrayRef shape, + ArrayRef order) { + unsigned rank = shape.size(); + assert(rank == order.size()); + auto reordered = reorder(shape, order); + auto reorderedMultiDim = delinearize(b, loc, linear, reordered); + SmallVector multiDim(rank); + for (unsigned i = 0; i < rank; ++i) { + multiDim[order[i]] = reorderedMultiDim[i]; } - tmp = mlir::multiRootTopologicalSort(tmp); - for (Operation *op : tmp) - sortedValues.push_back(op->getResult(0)); - - for (Value currOperand : sortedValues) { - Value origOperand = currOperand; - // unpack information - Attribute targetLayout = toConvert.lookup(currOperand); - // rematerialize the operand if necessary - Operation *currOperation = currOperand.getDefiningOp(); - if (processed.contains(currOperation)) { - Operation *newOperation = - cloneWithInferType(rewriter, currOperation, mapping); - newOperation->moveAfter(currOperation); - currOperation = newOperation; - currOperand = currOperation->getResult(0); + return multiDim; +} + +SmallVector delinearize(OpBuilder &b, Location loc, Value linear, + ArrayRef shape) { + unsigned rank = shape.size(); + assert(rank > 0); + SmallVector multiDim(rank); + if (rank == 1) { + multiDim[0] = linear; + } else { + Value remained = linear; + for (auto &&en : llvm::enumerate(shape.drop_back())) { + auto dimSize = b.create(loc, en.value(), 32); + multiDim[en.index()] = b.create(loc, remained, dimSize); + remained = b.create(loc, remained, dimSize); + } + multiDim[rank - 1] = remained; + } + return multiDim; +} + +Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, + ArrayRef shape, ArrayRef order) { + return linearize(b, loc, reorder(multiDim, order), + reorder(shape, order)); +} + +Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, + ArrayRef shape) { + auto rank = multiDim.size(); + Value linear = b.create(loc, 0, 32); + if (rank > 0) { + linear = multiDim.back(); + for (auto [dim, dimShape] : + llvm::reverse(llvm::zip(multiDim.drop_back(), shape.drop_back()))) { + Value dimSize = b.create(loc, dimShape, 32); + linear = b.create( + loc, b.create(loc, linear, dimSize), dim); } - // compute target type for the layout cast - auto currType = currOperand.getType().cast(); - auto newType = RankedTensorType::get( - currType.getShape(), currType.getElementType(), targetLayout); - auto newOperand = rewriter.create( - currOperand.getLoc(), newType, currOperand); - if (currOperation) - newOperand->moveAfter(currOperation); - else { - Block *block = currOperand.cast().getOwner(); - newOperand->moveBefore(block, block->begin()); + } + return linear; +} + +std::optional getWSAgentId(Operation *op) { + int prevAgentId = -1; + if (auto attr = op->getAttrOfType("async_agent")) { + for (auto agentId : attr.getValues()) { + assert(prevAgentId == -1 && "support at most one agent id"); + prevAgentId = agentId; } - mapping.map(origOperand, newOperand); } + if (prevAgentId == -1) + return std::nullopt; + return prevAgentId; } -LogicalResult canMoveOutOfLoop(BlockArgument arg, - SmallVector &cvts) { - auto parentOp = arg.getOwner()->getParentOp(); - // Don't move if arg is defined in a while loop - if (isa(parentOp)) - return failure(); - // Skip if arg is not defined in scf.for - if (!isa(parentOp)) - return success(); - auto forOp = cast(parentOp); - // We only move `iterArg` out of the loop if - // 1. There is no conversion - // 2. There is only a single conversion - // 3. Moving this conversion out of the loop will not generate any extra - // non-removable conversion - SetVector cvtTypes; - SetVector others; - auto oldType = arg.getType().cast(); - for (auto user : arg.getUsers()) { - if (isa(user)) { - // Don't move if the conversion target is a dot operand or shared memory - auto newType = user->getResults()[0].getType().cast(); - if (oldType.getEncoding().isa() && - newType.getEncoding().isa()) { +std::optional getWSRoleId(Operation *op) { + if (!op->hasAttr("agent.mutex_role")) + return std::nullopt; + return op->getAttrOfType("agent.mutex_role").getInt(); +} + +void setRoleId(Operation *op, int roleId) { + auto attr = IntegerAttr::get(IntegerType::get(op->getContext(), 32), roleId); + op->setAttr("agent.mutex_role", attr); +} + +namespace { + +/// Detect dead arguments in scf.for op by assuming all the values are dead and +/// propagate liveness property. +struct ForOpDeadArgElimination : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::ForOp forOp, + PatternRewriter &rewriter) const final { + Block &block = *forOp.getBody(); + auto yieldOp = cast(block.getTerminator()); + // Assume that nothing is live at the beginning and mark values as live + // based on uses. + DenseSet aliveValues; + SmallVector queue; + // Helper to mark values as live and add them to the queue of value to + // propagate if it is the first time we detect the value as live. + auto markLive = [&](Value val) { + if (!forOp->isAncestor(val.getParentRegion()->getParentOp())) + return; + if (aliveValues.insert(val).second) + queue.push_back(val); + }; + // Mark all yield operands as live if the associated forOp result has any + // use. + for (auto result : llvm::enumerate(forOp.getResults())) { + if (!result.value().use_empty()) + markLive(yieldOp.getOperand(result.index())); + } + if (aliveValues.size() == forOp.getNumResults()) + return failure(); + // Operations with side-effects are always live. Mark all theirs operands as + // live. + block.walk([&](Operation *op) { + if (!isa(op) && !wouldOpBeTriviallyDead(op)) { + for (Value operand : op->getOperands()) + markLive(operand); + } + }); + // Propagate live property until reaching a fixed point. + while (!queue.empty()) { + Value value = queue.pop_back_val(); + if (auto nestedFor = value.getDefiningOp()) { + auto result = value.cast(); + OpOperand &forOperand = nestedFor.getOpOperandForResult(result); + markLive(forOperand.get()); + auto nestedYieldOp = + cast(nestedFor.getBody()->getTerminator()); + Value nestedYieldOperand = + nestedYieldOp.getOperand(result.getResultNumber()); + markLive(nestedYieldOperand); continue; } - if (newType.getEncoding().isa()) { - if (newType.getEncoding() - .cast() - .getVec() == 1) - continue; + if (auto nestedIf = value.getDefiningOp()) { + auto result = value.cast(); + for (scf::YieldOp nestedYieldOp : + {nestedIf.thenYield(), nestedIf.elseYield()}) { + Value nestedYieldOperand = + nestedYieldOp.getOperand(result.getResultNumber()); + markLive(nestedYieldOperand); + } + continue; } - cvts.emplace_back(user); - cvtTypes.insert(newType); - } else - others.insert(user); - } - // First condition - if (cvts.empty()) - return success(); - if (cvtTypes.size() == 1) { - // Second condition - if (others.empty()) - return success(); - // Third condition - part 1: - // If the other or the cvt is in the different block, we cannot push the - // conversion forward or backward - for (auto *cvt : cvts) { - if (cvt->getBlock() != forOp.getBody()) - return failure(); - } - auto targetEncoding = cvtTypes.front().getEncoding(); - for (auto *other : others) { - // Third condition - part 2: - // If the other non-cvt op is in the different block, we cannot push the - // conversion forward or backward - if (other->getBlock() != forOp.getBody()) - return failure(); - // Third condition - part 3: - // %0 (enc1) = cvt %arg (enc0) - // other %0 (enc1), %1 (enc0) => other %0 (enc1), %1 (enc1) - // Check if %2 (enc1) = cvt %1 (enc0) can be eliminated - SetVector processed; - SetVector layout; - llvm::MapVector toConvert; - for (auto operand : other->getOperands()) { - auto argOp = operand.getDefiningOp(); - if (argOp && !isa(argOp) && - simulateBackwardRematerialization(argOp, processed, layout, - toConvert, targetEncoding) > 0) + if (Operation *def = value.getDefiningOp()) { + // TODO: support while ops. + if (isa(def)) return failure(); + for (Value operand : def->getOperands()) + markLive(operand); + continue; + } + // If an argument block is live then the associated yield operand and + // forOp operand are live. + auto arg = value.cast(); + if (auto forOwner = dyn_cast(arg.getOwner()->getParentOp())) { + if (arg.getArgNumber() < forOwner.getNumInductionVars()) + continue; + unsigned iterIdx = arg.getArgNumber() - forOwner.getNumInductionVars(); + Value yieldOperand = + forOwner.getBody()->getTerminator()->getOperand(iterIdx); + markLive(yieldOperand); + markLive(forOwner.getIterOperands()[iterIdx]); } } + SmallVector deadArg; + for (auto yieldOperand : llvm::enumerate(yieldOp->getOperands())) { + if (aliveValues.contains(yieldOperand.value())) + continue; + if (yieldOperand.value() == block.getArgument(yieldOperand.index() + 1)) + continue; + deadArg.push_back(yieldOperand.index()); + } + if (deadArg.empty()) + return failure(); + rewriter.updateRootInPlace(forOp, [&]() { + // For simplicity we just change the dead yield operand to use the + // associated argument and leave the operations and argument removal to + // dead code elimination. + for (unsigned deadArgIdx : deadArg) { + BlockArgument arg = block.getArgument(deadArgIdx + 1); + yieldOp.setOperand(deadArgIdx, arg); + } + }); return success(); } - return failure(); +}; + +} // namespace + +void populateForOpDeadArgumentElimination(RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); } } // namespace mlir diff --git a/lib/Dialect/TritonNvidiaGPU/CMakeLists.txt b/lib/Dialect/TritonNvidiaGPU/CMakeLists.txt new file mode 100644 index 000000000000..9f57627c321f --- /dev/null +++ b/lib/Dialect/TritonNvidiaGPU/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/lib/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt b/lib/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt new file mode 100644 index 000000000000..99f2ef6b702b --- /dev/null +++ b/lib/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt @@ -0,0 +1,14 @@ +add_mlir_dialect_library(TritonNvidiaGPUIR + Dialect.cpp + Ops.cpp + Traits.cpp + Types.cpp + + DEPENDS + TritonNvidiaGPUTableGen + TritonNvidiaGPUAttrDefsIncGen + + LINK_LIBS PUBLIC + TritonIR + TritonGPUIR +) diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp new file mode 100644 index 000000000000..0a982ce0572a --- /dev/null +++ b/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include + +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" + +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.cpp.inc" + +using namespace mlir; +using namespace mlir::triton::nvidia_gpu; + +//===----------------------------------------------------------------------===// +// Attribute methods +//===----------------------------------------------------------------------===// +#define GET_ATTRDEF_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.cpp.inc" + +//===----------------------------------------------------------------------===// + +void TritonNvidiaGPUDialect::initialize() { + registerTypes(); + + addAttributes< +#define GET_ATTRDEF_LIST +#include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.cpp.inc" + >(); + addOperations< +#define GET_OP_LIST +#include "triton/Dialect/TritonNvidiaGPU/IR/Ops.cpp.inc" +#include "triton/Dialect/TritonNvidiaGPU/IR/OpsEnums.cpp.inc" + >(); +} + +// verify TritonNvidiaGPU ops +LogicalResult +TritonNvidiaGPUDialect::verifyOperationAttribute(Operation *op, + NamedAttribute attr) { + // TODO: fill this. + return success(); +} diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp new file mode 100644 index 000000000000..f11f0278133e --- /dev/null +++ b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "mlir/IR/Builders.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +#define GET_OP_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/IR/Ops.cpp.inc" + +namespace mlir { +namespace triton { +namespace nvidia_gpu { + +///--- DotAsyncOp --- +mlir::LogicalResult DotAsyncOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // type is the same as the accumulator + auto accTy = operands[2].getType().cast(); + inferredReturnTypes.push_back(accTy); + + // verify encodings + auto aEnc = operands[0].getType().cast().getEncoding(); + auto bEnc = operands[1].getType().cast().getEncoding(); + auto retEnc = accTy.getEncoding(); + if (aEnc) { + assert(bEnc); + Dialect &dialect = aEnc.getDialect(); + auto interface = dyn_cast(&dialect); + if (interface->inferDotOpEncoding(aEnc, 0, retEnc, location).failed()) + return mlir::failure(); + if (interface->inferDotOpEncoding(bEnc, 1, retEnc, location).failed()) + return mlir::failure(); + } + return mlir::success(); +} + +///--- Async related ops --- +void GetAgentIdOp::build(::mlir::OpBuilder &builder, + ::mlir::OperationState &state) { + build(builder, state, builder.getI32Type()); +} + +void CreateTokenOp::build(::mlir::OpBuilder &builder, + ::mlir::OperationState &state, uint32_t num) { + auto tokenType = TokenType::get(builder.getContext()); + auto resultType = RankedTensorType::get({num}, tokenType); + build(builder, state, resultType, num); +} + +void GetMutexRoleIdOp::build(::mlir::OpBuilder &builder, + ::mlir::OperationState &state, uint32_t num) { + build(builder, state, builder.getI32Type(), num); +} + +void CreateMutexOp::build(::mlir::OpBuilder &builder, + ::mlir::OperationState &state) { + build(builder, state, MutexType::get(builder.getContext())); +} + +} // namespace nvidia_gpu +} // namespace triton +} // namespace mlir diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Traits.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Traits.cpp new file mode 100644 index 000000000000..8360eea33244 --- /dev/null +++ b/lib/Dialect/TritonNvidiaGPU/IR/Traits.cpp @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "triton/Dialect/TritonNvidiaGPU/IR/Traits.h" +#include "triton/Analysis/Utility.h" + +mlir::LogicalResult +mlir::OpTrait::impl::verifySource1IsSharedEncoding(Operation *op) { + if (failed(verifyAtLeastNOperands(op, 2))) + return failure(); + + if (!mlir::triton::gpu::isSharedEncoding(op->getOperand(1))) + return op->emitOpError() << "requires operand 1 to be shared encoding"; + + return success(); +}; diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Types.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Types.cpp new file mode 100644 index 000000000000..326f4948a113 --- /dev/null +++ b/lib/Dialect/TritonNvidiaGPU/IR/Types.cpp @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "triton/Dialect/TritonNvidiaGPU/IR/Types.h" +#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc` +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc` + +using namespace mlir; +using namespace mlir::triton::nvidia_gpu; + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/IR/Types.cpp.inc" + +//===----------------------------------------------------------------------===// +// Triton Dialect +//===----------------------------------------------------------------------===// +void ::mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect::registerTypes() { + addTypes< +#define GET_TYPEDEF_LIST +#include "triton/Dialect/TritonNvidiaGPU/IR/Types.cpp.inc" + >(); +} diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt new file mode 100644 index 000000000000..53674ebfc673 --- /dev/null +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt @@ -0,0 +1,23 @@ +add_mlir_dialect_library(TritonNvidiaGPUTransforms + MaterializeLoadStore.cpp + PlanCTA.cpp + WSDecomposing.cpp + WSFeasibilityChecking.cpp + WSPipeline.cpp + WSMutex.cpp + WSMaterialization.cpp + WSFixupMissingAttrs.cpp + FenceInsertion.cpp + RewriteTensorPointer.cpp + Utility.cpp + + DEPENDS + TritonNvidiaGPUTransformsIncGen + + LINK_LIBS PUBLIC + TritonIR + TritonGPUIR + TritonGPUTransforms + TritonNvidiaGPUIR + MLIRTransformUtils +) diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp new file mode 100644 index 000000000000..3175fbbfb018 --- /dev/null +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp @@ -0,0 +1,73 @@ +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" +#include "llvm/Support/Debug.h" + +//===----------------------------------------------------------------------===// +// +// This pass works after all other passes, inserting fences to ensure that +// memory operations are properly ordered acorss genric and async proxy. +// +//===----------------------------------------------------------------------===// + +using namespace mlir; +namespace tt = ::mlir::triton; +namespace ttg = ::mlir::triton::gpu; +namespace ttng = ::mlir::triton::nvidia_gpu; + +#define GEN_PASS_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" + +using ::mlir::triton::gpu::SharedEncodingAttr; + +namespace { + +struct FenceInsertionPass + : public TritonGPUFenceInsertionBase { + +public: + FenceInsertionPass() = default; + FenceInsertionPass(int computeCapability) { + this->computeCapability = computeCapability; + } + // TODO: support more patterns to insert fences + // only support insertion between convert layout ops and dot ops to protect + // flashattention + void runOnOperation() override { + // Only insert fences for compute capability 9.0 + if (computeCapability < 90) + return; + ModuleOp mod = getOperation(); + mod.walk([&](Operation *op) { + if (isa(op)) { + auto a = op->getOperand(0); + auto b = op->getOperand(1); + auto mmaEncoding = op->getResult(0) + .getType() + .cast() + .getEncoding() + .dyn_cast(); + auto isHopperEncoding = mmaEncoding && mmaEncoding.isHopper(); + if (isHopperEncoding && (isa(a.getDefiningOp()) && + ttg::isSharedEncoding(a)) || + (isa(b.getDefiningOp()) && + ttg::isSharedEncoding(b))) { + + // TODO: check whether cluster fence is needed + OpBuilder builder(op); + builder.create(op->getLoc(), + false /*bCluster*/); + } + } + }); + } +}; + +} // namespace + +std::unique_ptr +mlir::createTritonNvidiaGPUFenceInsertionPass(int computeCapability) { + return std::make_unique(computeCapability); +} diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/MaterializeLoadStore.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/MaterializeLoadStore.cpp new file mode 100644 index 000000000000..d8429e95c54b --- /dev/null +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/MaterializeLoadStore.cpp @@ -0,0 +1,212 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/Support/Debug.h" +#include + +//===----------------------------------------------------------------------===// +// +// This pass works after pipeline pass, converts the remaining tt.LoadOp taking +// ptr as input into ttg.InsertSliceAsyncOp and emit proper barriers +// +//===----------------------------------------------------------------------===// + +using namespace mlir; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; + +#define GEN_PASS_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" + +using ::mlir::triton::gpu::BlockedEncodingAttr; +using ::mlir::triton::gpu::getCTALayout; +using ::mlir::triton::gpu::MmaEncodingAttr; +using ::mlir::triton::gpu::SharedEncodingAttr; + +namespace { + +struct MaterializeLoadStorePass + : public MaterializeLoadStoreBase { + +public: + MaterializeLoadStorePass() = default; + MaterializeLoadStorePass(int numWarps, int computeCapability) { + this->numWarps = numWarps; + this->computeCapability = computeCapability; + } + + void runOnOperation() override { + SmallVector worklists; + getOperation()->walk([&](mlir::triton::LoadOp load) -> void { + if (isLoadFromTensorPtr(load)) { + worklists.push_back(load); + } + }); + for (auto load : worklists) { + materializeLoadTilePtr(load); + } + + SmallVector storeOpWorklists; + getOperation()->walk([&](mlir::triton::StoreOp store) -> void { + if (isStoreToTensorPtr(store)) { + storeOpWorklists.push_back(store); + } + }); + for (auto store : storeOpWorklists) { + materializeStoreTilePtr(store); + } + } + +private: + void materializeLoadTilePtr(mlir::triton::LoadOp load); + void materializeStoreTilePtr(mlir::triton::StoreOp store); +}; + +void MaterializeLoadStorePass::materializeLoadTilePtr( + mlir::triton::LoadOp load) { + if (computeCapability < 90) + return; + if (!::triton::tools::getBoolEnv("ENABLE_TMA")) + return; + auto loc = load.getLoc(); + OpBuilder b(load); + auto loadTy = load.getType().dyn_cast(); + auto loadShape = loadTy.getShape(); + auto CTASplitNum = ttg::getCTASplitNum(loadTy.getEncoding()); + auto shapePerSlice = ttg::getShapePerCTA(CTASplitNum, loadShape); + auto elemTy = loadTy.getElementType(); + assert(loadTy); + SmallVector bufferShape(loadShape.begin(), loadShape.end()); + bufferShape.insert(bufferShape.begin(), 1); + + auto sharedEncoding = getSharedEncoding(loadTy); + auto bufferTy = RankedTensorType::get(bufferShape, elemTy, sharedEncoding); + Value buffer = b.create(loc, bufferTy); + unsigned elems = std::accumulate(shapePerSlice.begin(), shapePerSlice.end(), + 1, std::multiplies{}); + elems *= (elemTy.getIntOrFloatBitWidth() / 8); + auto mBarrierTy = mlir::triton::PointerType::get(b.getIntegerType(64), 3); + Value mBarrier = b.create(loc, mBarrierTy, 1); + Value _0 = b.create(loc, 0, 32); + Value threadId = b.create(loc); + Value pred = + b.create(loc, arith::CmpIPredicate::eq, threadId, _0); + b.create(loc, mBarrier, pred, /*remoteCtaId*/ nullptr, + /*trackAsyncOp*/ false, elems); + Value inserted = b.create( + loc, bufferTy, load.getPtr(), buffer, + /*index*/ _0, mBarrier, load.getMask(), load.getOther(), load.getCache(), + load.getEvict(), load.getIsVolatile(), + /*axis*/ 0); + auto extractedTy = RankedTensorType::get(loadShape, elemTy, sharedEncoding); + Value extracted = b.create( + loc, extractedTy, inserted, + SmallVector{b.getI64IntegerAttr(0), b.getI64IntegerAttr(0), + b.getI64IntegerAttr(0)}, + SmallVector{b.getI64IntegerAttr(1), + b.getI64IntegerAttr(loadShape[0]), + b.getI64IntegerAttr(loadShape[1])}, + SmallVector{b.getI64IntegerAttr(1), b.getI64IntegerAttr(1), + b.getI64IntegerAttr(1)}); + + Value phase = b.create(loc, 0, 1); + b.create(loc, mBarrier, phase); + Value newValue = + b.create(loc, load.getType(), extracted); + load.getResult().replaceAllUsesWith(newValue); + load->erase(); +} + +void MaterializeLoadStorePass::materializeStoreTilePtr( + mlir::triton::StoreOp store) { + if (computeCapability < 90 || !::triton::tools::getBoolEnv("ENABLE_TMA")) + return; + auto loc = store.getLoc(); + OpBuilder builder(store); + auto value = store.getValue(); + auto dst = store.getPtr(); + + auto cvtOp = llvm::dyn_cast_or_null( + value.getDefiningOp()); + if (cvtOp) { + auto srcTy = cvtOp.getOperand().getType().cast(); + auto dstTy = cvtOp.getResult().getType().cast(); + auto elemTy = srcTy.getElementType(); + auto srcMmaLayout = srcTy.getEncoding().dyn_cast(); + auto dstBlockedLayout = dstTy.getEncoding().dyn_cast(); + auto truncFOP = llvm::dyn_cast_or_null( + cvtOp.getOperand().getDefiningOp()); + unsigned numElems = ttg::getTotalElemsPerThread(srcTy); + auto inOrd = ttg::getOrder(srcTy.getEncoding()); + auto outOrd = ttg::getOrder(dstTy.getEncoding()); + if (srcMmaLayout && srcMmaLayout.isHopper() && dstBlockedLayout && + truncFOP && elemTy.getIntOrFloatBitWidth() == 16 && numElems >= 16 && + inOrd == outOrd) { + builder.create(loc, dst, cvtOp.getOperand()); + builder.create(loc); + builder.create(loc, 0); + store->erase(); + return; + } + } + + auto *ctx = store.getContext(); + auto storeTy = value.getType().dyn_cast(); + assert(storeTy); + auto storeElemTy = storeTy.getElementType(); + auto ctaLayout = getCTALayout(storeTy.getEncoding()); + auto storeShape = storeTy.getShape(); + SmallVector bufferShape(storeShape.begin(), storeShape.end()); + auto rank = storeShape.size(); + // The order of smem should be consistent with gmem. + auto makeTensorPtrOp = getMakeTensorPtrOp(dst); + SmallVector sharedOrder; + for (auto o : makeTensorPtrOp.getOrder()) { + sharedOrder.emplace_back(o); + } + auto sharedEncoding = SharedEncodingAttr::get(ctx, storeShape, sharedOrder, + ctaLayout, storeElemTy); + auto bufferTy = + RankedTensorType::get(bufferShape, storeElemTy, sharedEncoding); + Value cvt = builder.create(loc, bufferTy, value); + builder.create(loc, dst, cvt); + builder.create(loc); + builder.create(loc, 0); + store->erase(); +} + +} // anonymous namespace + +std::unique_ptr +mlir::createTritonNvidiaGPUMaterializeLoadStorePass(int numWarps, + int computeCapability) { + return std::make_unique(numWarps, + computeCapability); +} diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp new file mode 100644 index 000000000000..d79da1ee9961 --- /dev/null +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp @@ -0,0 +1,1024 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" +#include + +#define GEN_PASS_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" + +namespace { + +using namespace mlir; +namespace ttg = ::mlir::triton::gpu; +namespace ttng = ::mlir::triton::nvidia_gpu; + +// TODO: use ConvertLayoutOp +using CastOp = ::mlir::UnrealizedConversionCastOp; + +unsigned getNumUsers(Value value) { + return std::distance(value.user_begin(), value.user_end()); +} + +Type replaceLayout(const Type &type, const Attribute &newLayout) { + Type curType = type; + auto ptrTy = curType.dyn_cast(); + if (ptrTy) + curType = ptrTy.getPointeeType(); + if (auto tensorTy = curType.dyn_cast()) + curType = RankedTensorType::get(tensorTy.getShape(), + tensorTy.getElementType(), newLayout); + if (ptrTy) + curType = triton::PointerType::get(curType, ptrTy.getAddressSpace()); + return curType; +} + +Attribute replaceCTALayout(Attribute layout, llvm::ArrayRef shape, + const ttg::CTALayoutAttr &newCTALayout) { + if (auto blockedLayout = layout.dyn_cast()) { + return ttg::BlockedEncodingAttr::get( + layout.getContext(), shape, blockedLayout.getSizePerThread(), + blockedLayout.getOrder(), ttg::getNumWarpsPerCTA(layout), 32, + newCTALayout); + } else if (auto sliceLayout = layout.dyn_cast()) { + return ttg::SliceEncodingAttr::get( + layout.getContext(), sliceLayout.getDim(), + replaceCTALayout(sliceLayout.getParent(), shape, newCTALayout)); + } else { + // Other layouts are generated by passes after PlanCTAPass + assert(0 && "replaceCTALayout not implemented"); + } +} + +class CTAPlanner { +public: + CTAPlanner(ttng::ClusterInfo *clusterInfo_); + ~CTAPlanner(); + + void run(triton::FuncOp &funcOp); + +private: + CastOp markBackward(CastOp cast) const; + CastOp markForward(CastOp cast) const; + bool isBackward(CastOp cast) const; + bool isForward(CastOp cast) const; + + void setTiling(llvm::ArrayRef CTAsPerCGA); + bool processDot(triton::FuncOp &funcOp); + bool processReduce(triton::FuncOp &funcOp); + void processStoreLikeOps(triton::FuncOp &funcOp); + + bool propagate(CastOp cast); + bool propagateBackward(CastOp cast); + bool propagateForward(CastOp cast); + + void eraseCastOp(CastOp cast); + void eraseCastOpFromQueue(CastOp cast); + void eraseCastOpsFromQueue(llvm::ArrayRef casts); + + void insertCasts(Operation *op, llvm::ArrayRef newOperandLayouts, + llvm::ArrayRef newResultLayouts); + void eliminateAdjacentCasts(CastOp cast0, CastOp cast1); + + bool isLoadStoreOp(Operation *op) const; + bool processLoadStore(Operation *op, Attribute layout); + + bool isElementwiseOp(Operation *op) const; + bool processElementwise(Operation *op, Attribute layout); + + bool processConstant(arith::ConstantOp constant, Attribute layout); + bool processSplat(triton::SplatOp splat, Attribute layout); + bool processMakeRange(triton::MakeRangeOp makeRange, Attribute layout); + bool processMakeTensorPtr(triton::MakeTensorPtrOp makeTensorPtr, + Attribute layout); + + bool processBroadcast(triton::BroadcastOp broadcast, Attribute layout); + bool processExpandDimsBackward(triton::ExpandDimsOp expandDims, + Attribute newResultLayout); + bool processExpandDimsForward(triton::ExpandDimsOp expandDims, + Attribute newSrcLayout); + + bool processConvertLayoutBackward(ttg::ConvertLayoutOp convertLayout, + CastOp cast); + bool processConvertLayoutForward(ttg::ConvertLayoutOp convertLayout, + CastOp cast); + + bool processIfOp(scf::IfOp ifOp, int index, const Type &newType); + bool processForOp(scf::ForOp forOp, int index, const Type &newType); + + bool processIfOpBackward(scf::IfOp ifOp, CastOp cast); + bool processForOpBackward(scf::ForOp forOp, CastOp cast); + bool processBlockArgBackward(BlockArgument arg, CastOp cast); + bool processForOpForward(scf::ForOp forOp, CastOp cast); + bool processYieldOpForward(scf::YieldOp yieldOp, CastOp cast); + + bool processOpFallback(Operation *op); + + bool processMultiUsersBackward(Value input, CastOp cast); + bool processMultiUsersForward(Value output, CastOp cast); + + // This flag indicates whether clusterInfo needs to be deleted in the + // destructor of CTAPlanner. The flag `ownInfo` is set to false when a + // non-null pointer to clusterInfo is passed to the constructor of CTAPlanner. + // Otherwise, a self-managed ClusterInfo will be created and the ownInfo will + // be set to true. + bool ownInfo; + ttng::ClusterInfo *clusterInfo; + bool tiled; + unsigned step; + unsigned stepUnchanged; + std::queue queue; +}; + +CTAPlanner::CTAPlanner(ttng::ClusterInfo *clusterInfo_) + : ownInfo(false), clusterInfo(clusterInfo_), tiled(false), step(0), + stepUnchanged(0) { + if (clusterInfo == nullptr) { + clusterInfo = new ttng::ClusterInfo(); + ownInfo = true; + } +} + +CTAPlanner::~CTAPlanner() { + if (ownInfo) { + delete clusterInfo; + // Actually not necessary but safer + ownInfo = false; + clusterInfo = nullptr; + } +} + +void CTAPlanner::run(triton::FuncOp &funcOp) { + assert(!tiled && "Please create a new CTAPlanner"); + static const unsigned maxSteps = 10000; + + auto nextStep = [&]() { + ++step; + assert(step < maxSteps && "Maximum number of steps exceeded"); + }; + + processDot(funcOp); + nextStep(); + + processReduce(funcOp); + nextStep(); + + if (!tiled) { + processStoreLikeOps(funcOp); + nextStep(); + } + + while (!queue.empty()) { + CastOp cast = queue.front(); + queue.pop(); + bool changed = propagate(cast); + if (changed) { + stepUnchanged = 0; + } else { + queue.push(cast); + ++stepUnchanged; + } + nextStep(); + } +} + +CastOp CTAPlanner::markBackward(CastOp cast) const { + cast->setAttr("direction", StringAttr::get(cast.getContext(), "backward")); + return cast; +} + +CastOp CTAPlanner::markForward(CastOp cast) const { + cast->setAttr("direction", StringAttr::get(cast.getContext(), "forward")); + return cast; +} + +bool CTAPlanner::isBackward(CastOp cast) const { + return cast->getAttrOfType("direction") == "backward"; +} + +bool CTAPlanner::isForward(CastOp cast) const { + return cast->getAttrOfType("direction") == "forward"; +} + +void CTAPlanner::setTiling(llvm::ArrayRef CTAsPerCGA) { + assert(!tiled && "CTA tiling is already determinted"); + assert(clusterInfo && "ClusterInfo pointer is null"); + assert(CTAsPerCGA.size() <= 3 && "setTiling not implemented"); + if (CTAsPerCGA.size() > 0) + clusterInfo->clusterDimX = CTAsPerCGA[0]; + if (CTAsPerCGA.size() > 1) + clusterInfo->clusterDimY = CTAsPerCGA[1]; + if (CTAsPerCGA.size() > 2) + clusterInfo->clusterDimZ = CTAsPerCGA[2]; + tiled = true; +} + +bool CTAPlanner::processDot(triton::FuncOp &funcOp) { + // TODO: This is a naive implementation and should be refactored + auto getCTATiling = [](int64_t M, int64_t N, int64_t K, + unsigned numCTAs) -> std::pair { + unsigned splitM = std::clamp(M / 64, 1, numCTAs); + unsigned splitN = numCTAs / splitM; + return {splitM, splitN}; + }; + + funcOp.walk([&](triton::DotOp dot) { + MLIRContext *ctx = dot.getContext(); + + auto aTy = dot.getA().getType().cast(); + auto bTy = dot.getB().getType().cast(); + auto dTy = dot.getD().getType().cast(); + + assert(aTy.getEncoding().isa() && + bTy.getEncoding().isa() && + dTy.getEncoding().isa() && + "PlanCTAPass should follow immediately after CoalescePass"); + + auto aLayout = aTy.getEncoding().cast(); + auto bLayout = bTy.getEncoding().cast(); + auto dLayout = dTy.getEncoding().cast(); + + unsigned M = dTy.getShape()[0]; + unsigned N = dTy.getShape()[1]; + unsigned K = aTy.getShape()[1]; + + unsigned splitM, splitN; + std::tie(splitM, splitN) = getCTATiling(M, N, K, ttg::getNumCTAs(dLayout)); + // FIXME: Should consider IR with more than one DotOps + setTiling({splitM, splitN, 1}); + + auto newCTALayout = ttg::CTALayoutAttr::get(ctx, {splitM, splitN}, + {splitM, splitN}, {1, 0}); + auto newDLayout = ttg::BlockedEncodingAttr::get( + ctx, dTy.getShape(), dLayout.getSizePerThread(), dLayout.getOrder(), + ttg::getNumWarpsPerCTA(dLayout), 32, newCTALayout); + auto newALayout = ttg::DotOperandEncodingAttr::get(ctx, aLayout.getOpIdx(), + newDLayout, 0); + auto newBLayout = ttg::DotOperandEncodingAttr::get(ctx, bLayout.getOpIdx(), + newDLayout, 0); + + insertCasts(dot.getOperation(), {newALayout, newBLayout, newDLayout}, + {newDLayout}); + }); + + return true; +} + +bool CTAPlanner::processReduce(triton::FuncOp &funcOp) { + ModuleOp mod = funcOp->getParentOfType(); + unsigned numCTAs = ttg::TritonGPUDialect::getNumCTAs(mod); + + funcOp.walk([&](triton::ReduceOp reduce) { + MLIRContext *context = reduce.getContext(); + Value src = reduce.getOperands()[0]; + unsigned axis = reduce.getAxis(); + + auto srcTy = src.getType().cast(); + auto srcShape = srcTy.getShape(); + auto srcLayout = srcTy.getEncoding(); + + auto rank = srcShape.size(); + auto order = ttg::getOrder(srcLayout); + auto sizePerThread = ttg::getSizePerThread(srcLayout); + auto CTAOrder = ttg::getCTAOrder(srcLayout); + + llvm::SmallVector CTAsPerCGA(rank, 0); + unsigned remainingCTAs = numCTAs; + for (int i = rank - 1; i >= 0; --i) { + unsigned dim = order[i]; + if (dim == axis) { + CTAsPerCGA[dim] = 1; + } else { + CTAsPerCGA[dim] = std::min(srcShape[dim] / sizePerThread[dim], + remainingCTAs); + remainingCTAs /= CTAsPerCGA[dim]; + } + } + + for (int i = rank - 1; i >= 0; --i) { + unsigned dim = order[i]; + if (dim != axis) { + CTAsPerCGA[dim] *= remainingCTAs; + break; + } + } + + llvm::SmallVector CTASplitNum = CTAsPerCGA; + + // If numCTAs > 1 and the only dimension is the reduced dimension, after the + // above two for-loops, CTAsPerCGA = [0] and remainingCTAs = numCTAs. We set + // CTAsPerCGA[0] = numCTAs and keep CTASplitNum[0] = 1 to ensure that no + // cross-CTA reduction is required, although this will introduce duplicated + // calculation + if (remainingCTAs > 0) + CTAsPerCGA[order[rank - 1]] *= remainingCTAs; + + auto CTALayout = + ttg::CTALayoutAttr::get(context, CTAsPerCGA, CTASplitNum, CTAOrder); + if (!tiled) + setTiling(CTALayout.getCTAsPerCGA()); + auto newSrcLayout = replaceCTALayout(srcLayout, srcShape, CTALayout); + auto newResultLayout = + ttg::SliceEncodingAttr::get(context, axis, newSrcLayout); + unsigned numOperands = reduce.getNumOperands(); + SmallVector newSrcLayoutVec(numOperands, newSrcLayout); + SmallVector newResultLayoutVec(numOperands, newResultLayout); + + insertCasts(reduce.getOperation(), newSrcLayoutVec, newResultLayoutVec); + }); + return true; +} + +void CTAPlanner::processStoreLikeOps(triton::FuncOp &funcOp) { + assert(!tiled && "CTA tiling is already determinted"); + + llvm::SmallVector stores; + funcOp.walk([&](Operation *op) { + if (llvm::isa( + op)) + stores.push_back(op); + }); + assert(stores.size() > 0 && "Cannot find store-like ops"); + + ttg::CTALayoutAttr CTALayout; + for (Operation *store : stores) { + if (auto tensorTy = + store->getOperand(0).getType().dyn_cast()) { + if (!tiled) { + // Use CTA tiling of the first store-like op as global CTA tiling + CTALayout = ttg::getCTALayout(tensorTy.getEncoding()); + setTiling(CTALayout.getCTAsPerCGA()); + } + auto newLayout = replaceCTALayout(tensorTy.getEncoding(), + tensorTy.getShape(), CTALayout); + processElementwise(store, newLayout); + } + } + + // If all store-like ops are processing scalar values and no ReduceOp is + // found, we can conclude that this is an all-scalar computation, since + // ReduceOp is the only op that converts tensor values to scalar values. + if (!tiled) + setTiling({1, 1, 1}); +} + +bool CTAPlanner::propagate(CastOp cast) { + return isBackward(cast) ? propagateBackward(cast) : propagateForward(cast); +} + +bool CTAPlanner::propagateBackward(CastOp cast) { + Value input = cast.getOperand(0); + Value output = cast.getResult(0); + unsigned numUsers = getNumUsers(input); + if (numUsers == 0) { + assert(0 && "Unreachable branch"); + } else if (numUsers == 1) { + Type outTy = output.getType(); + if (auto ptrTy = outTy.dyn_cast()) + outTy = ptrTy.getPointeeType(); + Attribute layout = outTy.cast().getEncoding(); + Operation *op = input.getDefiningOp(); + if (op == nullptr) { + assert(input.isa() && + "Unexpected Value without defining op"); + processBlockArgBackward(input.cast(), cast); + } else if (auto prevCast = llvm::dyn_cast(op)) { + eliminateAdjacentCasts(prevCast, cast); + } else if (isLoadStoreOp(op)) { + processLoadStore(op, layout); + } else if (isElementwiseOp(op)) { + processElementwise(op, layout); + } else if (auto constant = llvm::dyn_cast(op)) { + processConstant(constant, layout); + } else if (auto splat = llvm::dyn_cast(op)) { + processSplat(splat, layout); + } else if (auto makeRange = llvm::dyn_cast(op)) { + processMakeRange(makeRange, layout); + } else if (auto makeTensorPtr = + llvm::dyn_cast(op)) { + processMakeTensorPtr(makeTensorPtr, layout); + } else if (llvm::isa(op)) { + // ptr operand and result have the same layout, while other operands are + // scalar values + processElementwise(op, layout); + } else if (auto broadcast = llvm::dyn_cast(op)) { + processBroadcast(broadcast, layout); + } else if (auto expandDims = llvm::dyn_cast(op)) { + processExpandDimsBackward(expandDims, layout); + } else if (auto ifOp = llvm::dyn_cast(op)) { + processIfOpBackward(ifOp, cast); + } else if (auto forOp = llvm::dyn_cast(op)) { + processForOpBackward(forOp, cast); + } else if (auto convertLayout = llvm::dyn_cast(op)) { + return processConvertLayoutBackward(convertLayout, cast); + } else { + // Keep original layouts. This may result in a loss of performance. + return processOpFallback(op); + } + return true; + } else { + return processMultiUsersBackward(input, cast); + } +} + +bool CTAPlanner::propagateForward(CastOp cast) { + Value input = cast.getOperand(0); + Value output = cast.getResult(0); + unsigned numUsers = getNumUsers(output); + if (numUsers == 0) { + cast.erase(); + } else if (numUsers == 1) { + Type inTy = input.getType(); + if (auto ptrTy = inTy.dyn_cast()) + inTy = ptrTy.getPointeeType(); + Attribute layout = inTy.cast().getEncoding(); + Operation *op = *output.user_begin(); + if (auto nextCast = llvm::dyn_cast(op)) { + eliminateAdjacentCasts(cast, nextCast); + } else if (isLoadStoreOp(op)) { + processLoadStore(op, layout); + } else if (isElementwiseOp(op)) { + processElementwise(op, layout); + } else if (llvm::isa(op)) { + // ptr operand and result have the same layout, while other operands are + // scalar values + processElementwise(op, layout); + } else if (auto convertLayout = llvm::dyn_cast(op)) { + return processConvertLayoutForward(convertLayout, cast); + } else if (auto forOp = llvm::dyn_cast(op)) { + processForOpForward(forOp, cast); + } else if (auto yieldOp = llvm::dyn_cast(op)) { + processYieldOpForward(yieldOp, cast); + } else { + // Keep original layouts. This may result in a loss of performance. + return processOpFallback(op); + } + } else { + processMultiUsersForward(output, cast); + } + return true; +} + +void CTAPlanner::eraseCastOp(CastOp cast) { + Value output = cast.getResult(0); + assert(getNumUsers(output) == 0 && + "Cannot erase CastOp because it is still in use"); + cast.erase(); +} + +void CTAPlanner::eraseCastOpFromQueue(CastOp cast) { + eraseCastOpsFromQueue({cast}); +} + +void CTAPlanner::eraseCastOpsFromQueue(llvm::ArrayRef casts) { + llvm::DenseSet erased; + for (CastOp cast : casts) { + eraseCastOp(cast); + erased.insert(cast); + } + + decltype(queue) tempQueue; + std::swap(queue, tempQueue); + + // This is only a naive implementation. Should refactor with linked-list. + while (!tempQueue.empty()) { + auto cast = tempQueue.front(); + tempQueue.pop(); + if (!erased.contains(cast)) + queue.push(cast); + } +} + +void CTAPlanner::insertCasts(Operation *op, + llvm::ArrayRef newOperandLayouts, + llvm::ArrayRef newResultLayouts) { + assert(op->getNumOperands() == newOperandLayouts.size() && + "NumOperands mismatched"); + assert(op->getNumResults() == newResultLayouts.size() && + "NumResults mismatched"); + + Location loc = op->getLoc(); + OpBuilder builder(op->getContext()); + + builder.setInsertionPoint(op); + for (unsigned i = 0; i < op->getNumOperands(); ++i) { + Value operand = op->getOperand(i); + auto operandTy = operand.getType(); + if (triton::isTensorOrTensorPointerType(operandTy)) { + operandTy = replaceLayout(operandTy, newOperandLayouts[i]); + auto cast = markBackward(builder.create(loc, operandTy, operand)); + op->setOperand(i, cast.getResult(0)); + queue.push(cast); + } + } + + builder.setInsertionPointAfter(op); + for (unsigned i = 0; i < op->getNumResults(); ++i) { + Value result = op->getResult(i); + auto resultTy = result.getType(); + if (triton::isTensorOrTensorPointerType(resultTy)) { + resultTy = replaceLayout(resultTy, newResultLayouts[i]); + auto cast = + markForward(builder.create(loc, result.getType(), result)); + result.setType(resultTy); + result.replaceAllUsesExcept(cast.getResult(0), cast.getOperation()); + queue.push(cast); + } + } +} + +void CTAPlanner::eliminateAdjacentCasts(CastOp cast0, CastOp cast1) { + assert(cast0.getResult(0) == cast1.getOperand(0) && + "The two casts are not adjacent"); + assert(isForward(cast0) && isBackward(cast1) && + "Expected pattern of adjacent casts: forward + backward"); + + Value input = cast0.getOperand(0); + Value output = cast1.getResult(0); + + if (input.getType() == output.getType()) { + output.replaceAllUsesWith(input); + eraseCastOpsFromQueue({cast1, cast0}); + } else { + OpBuilder builder(cast1.getOperation()); + auto cvt = builder.create(cast1.getLoc(), + output.getType(), input); + output.replaceAllUsesWith(cvt.getResult()); + eraseCastOpsFromQueue({cast1, cast0}); + } +} + +bool CTAPlanner::isLoadStoreOp(Operation *op) const { + return llvm::isa(op); +} + +bool CTAPlanner::processLoadStore(Operation *op, Attribute layout) { + // Special logic for: + // LoadOp -> SliceLayout + // Transform to: + // LoadOp -> originalLayout -> ConvertLayout(DSmem) -> SliceLayout + if (auto sliceLayout = layout.dyn_cast()) { + auto dim = sliceLayout.getDim(); + auto CTAsPerCGA = ttg::getCTAsPerCGA(sliceLayout.getParent()); + if (CTAsPerCGA[dim] > 1) { + // Find an input or output value of LoadOp or StoreOp to get its layout + Value val = + op->getNumResults() > 0 ? op->getResult(0) : op->getOperand(0); + Attribute originalLayout = + val.getType().cast().getEncoding(); + // Insert casts using originalLayout. Adjacent casts will be eliminated + // and generate a ConvertLayoutOp with DSmem access + return processLoadStore(op, originalLayout); + } + } + + auto CTALayout = ttg::getCTALayout(layout); + + llvm::SmallVector newOperandLayouts; + for (unsigned i = 0; i < op->getNumOperands(); ++i) { + auto type = op->getOperand(i).getType(); + if (auto ptrTy = type.dyn_cast()) + type = ptrTy.getPointeeType(); + auto tensorTy = type.cast(); + auto newLayout = replaceCTALayout(tensorTy.getEncoding(), + tensorTy.getShape(), CTALayout); + newOperandLayouts.push_back(newLayout); + } + + llvm::SmallVector newResultLayouts; + for (unsigned i = 0; i < op->getNumResults(); ++i) { + auto type = op->getResult(i).getType(); + if (auto ptrTy = type.dyn_cast()) + type = ptrTy.getPointeeType(); + auto tensorTy = type.cast(); + auto newLayout = replaceCTALayout(tensorTy.getEncoding(), + tensorTy.getShape(), CTALayout); + newResultLayouts.push_back(newLayout); + } + + insertCasts(op, newOperandLayouts, newResultLayouts); + return true; +} + +bool CTAPlanner::isElementwiseOp(Operation *op) const { + if (llvm::isa(op)) + return true; + if (llvm::isa(op)) + return true; + if (llvm::isa(op)) + return true; + if (auto externElementwiseOp = dyn_cast(op)) + return externElementwiseOp.getPure(); + if (llvm::isa(op)) + return true; + return false; +} + +bool CTAPlanner::processElementwise(Operation *op, Attribute layout) { + llvm::SmallVector newOperandLayouts(op->getNumOperands(), layout); + llvm::SmallVector newResultLayouts(op->getNumResults(), layout); + insertCasts(op, newOperandLayouts, newResultLayouts); + return true; +} + +bool CTAPlanner::processConstant(arith::ConstantOp constant, Attribute layout) { + if (auto tensorTy = + constant.getResult().getType().dyn_cast()) { + if (auto attr = constant.getValue().dyn_cast()) { + + auto newTensorTy = RankedTensorType::get( + tensorTy.getShape(), tensorTy.getElementType(), layout); + constant.setValueAttr( + SplatElementsAttr::get(newTensorTy, attr.getSplatValue())); + } + } + insertCasts(constant.getOperation(), {}, {layout}); + return true; +} + +bool CTAPlanner::processSplat(triton::SplatOp splat, Attribute layout) { + insertCasts(splat.getOperation(), {{}}, {layout}); + return true; +} + +bool CTAPlanner::processMakeRange(triton::MakeRangeOp makeRange, + Attribute layout) { + insertCasts(makeRange.getOperation(), {}, {layout}); + return true; +} + +bool CTAPlanner::processMakeTensorPtr(triton::MakeTensorPtrOp makeTensorPtr, + Attribute layout) { + // All inputs of `makeTensorPtr` are scalar types + llvm::SmallVector dummyInAttrs(makeTensorPtr.getNumOperands(), {}); + insertCasts(makeTensorPtr.getOperation(), dummyInAttrs, {layout}); + return true; +} + +bool CTAPlanner::processBroadcast(triton::BroadcastOp broadcast, + Attribute layout) { + insertCasts(broadcast.getOperation(), {layout}, {layout}); + return true; +} + +bool CTAPlanner::processExpandDimsBackward(triton::ExpandDimsOp expandDims, + Attribute newResultLayout) { + auto newSrcLayout = ttg::SliceEncodingAttr::get( + newResultLayout.getContext(), expandDims.getAxis(), newResultLayout); + insertCasts(expandDims.getOperation(), {newSrcLayout}, {newResultLayout}); + return true; +} + +bool CTAPlanner::processExpandDimsForward(triton::ExpandDimsOp expandDims, + Attribute newSrcLayout) { + assert(0 && "processExpandDimsForward not implemented yet"); + return true; +} + +bool CTAPlanner::processConvertLayoutBackward( + ttg::ConvertLayoutOp convertLayout, CastOp cast) { + Value src = convertLayout.getSrc(); + Value result = convertLayout.getResult(); + assert(getNumUsers(result) == 1 && + "Expect to call processMultiUsersBackward first"); + result.replaceAllUsesWith(src); + convertLayout.erase(); + queue.push(cast); + return true; +} + +bool CTAPlanner::processConvertLayoutForward(ttg::ConvertLayoutOp convertLayout, + CastOp cast) { + Value src = convertLayout.getSrc(); + Value result = convertLayout.getResult(); + assert(getNumUsers(src) == 1 && + "Expect to call processMultiUsersForward first"); + src.setType(result.getType()); + result.replaceAllUsesWith(src); + convertLayout.erase(); + queue.push(cast); + return true; +} + +bool CTAPlanner::processIfOp(scf::IfOp ifOp, int index, const Type &newType) { + // Check index + assert(index < ifOp.getNumResults() && "Invalid result index of IfOp"); + assert(index < ifOp.thenYield().getNumOperands() && + "Invalid operand index of YieldOp"); + assert(index < ifOp.elseYield().getNumOperands() && + "Invalid operand index of YieldOp"); + + Location loc = ifOp.getLoc(); + OpBuilder builder(ifOp.getContext()); + + // Insert forward cast after ifOp + Value result = ifOp.getResult(index); + builder.setInsertionPointAfter(ifOp.getOperation()); + auto newCast = + markForward(builder.create(loc, result.getType(), result)); + result.setType(newType); + result.replaceAllUsesExcept(newCast.getResult(0), newCast.getOperation()); + queue.push(newCast); + + // Insert backward casts before yield + for (scf::YieldOp yield : {ifOp.thenYield(), ifOp.elseYield()}) { + Value yieldSrc = yield.getOperand(index); + builder.setInsertionPoint(yield.getOperation()); + newCast = markBackward(builder.create(loc, newType, yieldSrc)); + yield->setOperand(index, newCast.getResult(0)); + queue.push(newCast); + } + + return true; +} + +bool CTAPlanner::processForOp(scf::ForOp forOp, int index, + const Type &newType) { + Block *body = forOp.getBody(); + auto yield = llvm::cast(forOp.getBody()->getTerminator()); + + // Check index + assert(index + forOp.getNumControlOperands() < forOp.getNumOperands() && + "Invalid operand index of ForOp"); + assert(index + forOp.getNumInductionVars() < body->getNumArguments() && + "Invalid block arg index of ForOp"); + assert(index < yield.getNumOperands() && "Invalid operand index of YieldOp"); + assert(index < forOp.getNumResults() && "Invalid result index of IfOp"); + + Location loc = forOp.getLoc(); + OpBuilder builder(forOp.getContext()); + + // Insert backward cast before forOp + OpOperand &operand = + forOp->getOpOperand(index + forOp.getNumControlOperands()); + builder.setInsertionPoint(forOp.getOperation()); + auto newCast = + markBackward(builder.create(loc, newType, operand.get())); + operand.set(newCast.getResult(0)); + queue.push(newCast); + + // Insert forward cast after block arg + Value arg = body->getArgument(index + forOp.getNumInductionVars()); + builder.setInsertionPointToStart(body); + newCast = markForward(builder.create(loc, arg.getType(), arg)); + arg.setType(newType); + arg.replaceAllUsesExcept(newCast.getResult(0), newCast.getOperation()); + queue.push(newCast); + + // Insert backward cast before yield + Value yieldSrc = yield.getOperand(index); + builder.setInsertionPoint(yield.getOperation()); + newCast = markBackward(builder.create(loc, newType, yieldSrc)); + yield->setOperand(index, newCast.getResult(0)); + queue.push(newCast); + + // Insert forward cast after forOp + Value result = forOp.getResult(index); + builder.setInsertionPointAfter(forOp.getOperation()); + newCast = markForward(builder.create(loc, result.getType(), result)); + result.setType(newType); + result.replaceAllUsesExcept(newCast.getResult(0), newCast.getOperation()); + queue.push(newCast); + + return true; +} + +int findResultIndex(Operation *op, Value result) { + for (int i = 0; i < op->getNumResults(); ++i) + if (op->getResult(i) == result) + return i; + assert(0 && "Invalid index of op result"); + return -1; +} + +bool CTAPlanner::processIfOpBackward(scf::IfOp ifOp, CastOp cast) { + int index = findResultIndex(ifOp.getOperation(), cast.getOperand(0)); + auto newType = cast.getResult(0).getType(); + return processIfOp(ifOp, index, newType); +} + +bool CTAPlanner::processForOpBackward(scf::ForOp forOp, CastOp cast) { + int index = findResultIndex(forOp.getOperation(), cast.getOperand(0)); + auto newType = cast.getResult(0).getType(); + return processForOp(forOp, index, newType); +} + +bool CTAPlanner::processBlockArgBackward(BlockArgument arg, CastOp cast) { + if (auto forOp = llvm::dyn_cast(arg.getOwner()->getParentOp())) { + int index = int(arg.getArgNumber()) - forOp.getNumInductionVars(); + auto newType = cast.getResult(0).getType(); + return processForOp(forOp, index, newType); + } else { + assert(0 && "Unexpected parent op of block argument"); + return true; + } +} + +bool CTAPlanner::processForOpForward(scf::ForOp forOp, CastOp cast) { + int index = cast.getResult(0).use_begin()->getOperandNumber() - + forOp.getNumControlOperands(); + auto newType = cast.getOperand(0).getType(); + return processForOp(forOp, index, newType); +} + +bool CTAPlanner::processYieldOpForward(scf::YieldOp yieldOp, CastOp cast) { + int index = cast.getResult(0).use_begin()->getOperandNumber(); + auto newType = cast.getOperand(0).getType(); + if (auto ifOp = llvm::dyn_cast(yieldOp->getParentOp())) + return processIfOp(ifOp, index, newType); + else if (auto forOp = llvm::dyn_cast(yieldOp->getParentOp())) + return processForOp(forOp, index, newType); + else + assert(0 && "Unexpected parent op of YieldOp"); + return true; +} + +bool CTAPlanner::processOpFallback(Operation *op) { + Location loc = op->getLoc(); + OpBuilder builder(op->getContext()); + + builder.setInsertionPoint(op); + for (unsigned i = 0; i < op->getNumOperands(); ++i) { + Value operand = op->getOperand(i); + auto operandTy = operand.getType(); + if (triton::isTensorOrTensorPointerType(operandTy)) { + auto cast = markBackward(builder.create(loc, operandTy, operand)); + op->setOperand(i, cast.getResult(0)); + queue.push(cast); + } + } + + builder.setInsertionPointAfter(op); + for (unsigned i = 0; i < op->getNumResults(); ++i) { + Value result = op->getResult(i); + auto resultTy = result.getType(); + if (triton::isTensorOrTensorPointerType(resultTy)) { + auto cast = markForward(builder.create(loc, resultTy, result)); + result.replaceAllUsesExcept(cast.getResult(0), cast.getOperation()); + queue.push(cast); + } + } + + return true; +} + +bool CTAPlanner::processMultiUsersBackward(Value input, CastOp cast) { + Location loc = input.getLoc(); + OpBuilder builder(input.getContext()); + + llvm::DenseMap> typeToIndices; + for (OpOperand &operand : input.getUses()) { + auto brotherCast = llvm::dyn_cast(operand.getOwner()); + if (!brotherCast) { + if (stepUnchanged <= queue.size()) + return false; + builder.setInsertionPoint(operand.getOwner()); + brotherCast = markBackward( + builder.create(loc, cast.getResult(0).getType(), input)); + auto newCast = markForward(builder.create( + loc, input.getType(), brotherCast.getResult(0))); + operand.set(newCast.getResult(0)); + queue.push(brotherCast); + queue.push(newCast); + } + auto type = brotherCast.getResult(0).getType(); + typeToIndices[type].push_back(brotherCast); + } + + bool first = true; + for (auto it : typeToIndices) { + Type &type = it.first; + llvm::SmallVector &casts = it.second; + Value newInput = input; + if (!first) { + if (Operation *defOp = input.getDefiningOp()) { + builder.setInsertionPointAfter(defOp); + Operation *clonedOp = builder.clone(*defOp); + newInput = clonedOp->getResult(0); + } else { + assert(0 && "Layout conflict for block arg"); // TODO + } + } + first = false; + if (Operation *defOp = newInput.getDefiningOp()) { + builder.setInsertionPointAfter(defOp); + } else { + assert(newInput.isa() && + "Unexpected Value without defining op"); + builder.setInsertionPointToStart( + newInput.cast().getOwner()); + } + auto newCast = markBackward(builder.create(loc, type, newInput)); + queue.push(newCast); + auto newResult = newCast.getResult(0); + for (CastOp &brotherCast : casts) { + brotherCast.getResult(0).replaceAllUsesWith(newResult); + eraseCastOpFromQueue(brotherCast); + } + } + return true; +} + +bool CTAPlanner::processMultiUsersForward(Value castResult, CastOp cast) { + Value castSrc = cast.getOperand(0); + + Location loc = cast.getLoc(); + OpBuilder builder(cast.getContext()); + builder.setInsertionPointAfter(cast.getOperation()); + + while (!castResult.use_empty()) { + auto newCast = + markForward(builder.create(loc, castResult.getType(), castSrc)); + castResult.use_begin()->set(newCast.getResult(0)); + queue.push(newCast); + } + + eraseCastOp(cast); + return true; +} + +struct PlanCTAPass : public TritonGPUPlanCTAPassBase { + PlanCTAPass(ttng::ClusterInfo *clusterInfo_ = nullptr) + : clusterInfo(clusterInfo_) {} + + void runOnOperation() override { + ModuleOp mod = getOperation(); + + // Skip PlanCTAPass when numCTAs == 1 + if (ttg::TritonGPUDialect::getNumCTAs(mod) == 1) + return; + + mod.walk([&](triton::FuncOp funcOp) { + CTAPlanner planner(clusterInfo); + planner.run(funcOp); + + // FIXME: Clone funcOp so that the IR change can be identified after + // PlanCTAPass. Without this, the change after PlanCTAPass will not be + // displayed when MLIR_ENABLE_DUMP=1. This is not reasonable and should + // be fixed later. + OpBuilder builder(funcOp); + builder.clone(*funcOp.getOperation()); + funcOp.erase(); + }); + } + + ttng::ClusterInfo *clusterInfo; +}; + +} // namespace + +std::unique_ptr +mlir::createTritonNvidiaGPUPlanCTAPass(ttng::ClusterInfo *clusterInfo) { + return std::make_unique(clusterInfo); +} + +/* TODO + * - Use ConvertLayoutOp instead of UnrealizedConversionCastOp. + * - Move PlanCTAPass to the front of CoalescePass. + * - Design better tiling strategy for DotOp and ReduceOp. + * - Consider cases where there are more than one DotOps. + * - Use better data structure for erasing CastOps from queue (linked list?). + * - Process eliminable CastOps in higher priority. + * - Fix the clone func bug in PlanCTAPass::runOnOperation. + * - Add some comments to introduce the overall idea of this pass. + * - Add some lit tests for this pass. + */ diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteTensorPointer.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteTensorPointer.cpp new file mode 100644 index 000000000000..e13cf8bd9179 --- /dev/null +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteTensorPointer.cpp @@ -0,0 +1,839 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "mlir/Pass/Pass.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" +#include "triton/Tools/Sys/GetEnv.hpp" + +#include +#include + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; + +#define GEN_PASS_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" + +namespace { +bool isDivisible(Value v, unsigned divisor) { + if (auto op = v.getDefiningOp()) { + return op.getValue().dyn_cast().getValue().getZExtValue() % + divisor == + 0; + } + if (v.getDefiningOp() && + isa(v.getDefiningOp())) { + return isDivisible(v.getDefiningOp()->getOperand(0), divisor); + } else if (v.getParentBlock()->isEntryBlock() && v.isa()) { + BlockArgument blockArg = v.cast(); + Operation *parentOp = blockArg.getOwner()->getParentOp(); + auto func = dyn_cast(parentOp); + assert(func); + if (auto attr = func.getArgAttrOfType(blockArg.getArgNumber(), + "tt.max_divisibility")) + return attr.getValue().getZExtValue() % divisor == 0; + return false; + } else if (v.getParentBlock()->isEntryBlock() && (!v.isa())) { + // in entryblock but not BlockArgument + return isDivisible(v.getDefiningOp()->getOperand(0), divisor); + } else if (!v.getParentBlock()->isEntryBlock()) { + // in non-entryblock + return isDivisible(v.getDefiningOp()->getOperand(0), divisor); + } else { + llvm::report_fatal_error( + "Operand of `MakeTensorPtrOp` is not the function's argument"); + return false; + } +} + +bool shouldRemove(tt::MakeTensorPtrOp &op, int computeCapability) { + if (computeCapability < 90 || !::triton::tools::getBoolEnv("ENABLE_TMA")) + return true; + auto resType = op.getResult() + .getType() + .cast() + .getPointeeType() + .cast(); + auto elemType = resType.getElementType(); + auto ord = op.getOrder(); + auto stride = op.getStrides(); + auto shape = ttg::getShapePerCTA(resType); + // TMA load/store requires the box dimension to be more than 32 bytes. + // Because we only support 32B-swizzle, 64B-swizzle and 128B-swizzleon for + // now. Remove this constraint when we support non-swizzle smem. + bool boxDimSwizzle = + shape[ord[0]] >= (256 / elemType.getIntOrFloatBitWidth()); + // we only support TMA load with 2D tensor for now. + // TMA load/store requires the stride to be divisible by 16 bytes. + bool strideDivisible = false; + if (stride.size() == 2) + strideDivisible = + isDivisible(stride[ord[1]], 128 / elemType.getIntOrFloatBitWidth()); + bool enableTMA = ::triton::tools::getBoolEnv("ENABLE_TMA"); + return !(boxDimSwizzle && strideDivisible && enableTMA); +} + +// TODO: When encoding exists use triton::gpu::CmpIOp as arith::CmpIOp doesn't +// play well with encoding attributes. Move back to arith::CmpIOp when this pass +// moves back to triton IR level. +Value createCmpOp(OpBuilder &builder, Location loc, RankedTensorType type, + arith::CmpIPredicate pred, Value lhs, Value rhs) { + if (type.getEncoding()) + return builder.create(loc, type, pred, lhs, rhs); + return builder.create(loc, type, pred, lhs, rhs); +} + +/// An additional struct to record the meta information of operations +/// with tensor pointers +struct RewritedInfo { +private: + Value base; + SmallVector shape; + SmallVector strides; + SmallVector offsets; + ArrayRef tensorShape; + Attribute layout; + + // A cache to avoid generating the same offset with range + DenseMap cachedOffsetWithRange; + + template + SmallVector insertOne(ArrayRef vec, unsigned axis) const { + SmallVector res(vec.begin(), vec.end()); + res.insert(res.begin() + axis, 1); + return res; + } + + // Example: order = [ 0, 2, 1, 3], dim = 2 + // resOrder = [2, 0, 3, 1, 4] + SmallVector insertOrder(ArrayRef order, + unsigned axis) const { + SmallVector resOrder(order.begin(), order.end()); + for (unsigned i = 0; i < resOrder.size(); ++i) + if (resOrder[i] >= axis) + ++resOrder[i]; + resOrder.insert(resOrder.begin(), axis); + return resOrder; + } + +public: + RewritedInfo() = default; + + RewritedInfo(const RewritedInfo &other) = default; + + RewritedInfo(Value base, const SmallVector &shape, + const SmallVector &strides, + const SmallVector &offsets, + const ArrayRef &tensorShape, Attribute layout) + : base(base), shape(shape), strides(strides), offsets(offsets), + tensorShape(tensorShape), layout(layout) { + assert(shape.size() == strides.size() && shape.size() == offsets.size() && + shape.size() == tensorShape.size()); + } + + unsigned int length() const { return shape.size(); } + + Value getOffset(unsigned i) { return offsets[i]; } + + SmallVector getOffsets() { return offsets; } + + void setOffset(unsigned i, Value newOffset) { + offsets[i] = newOffset; + cachedOffsetWithRange.clear(); + } + + void setOffsets(const SmallVector &newOffsets) { + offsets = newOffsets; + cachedOffsetWithRange.clear(); + } + + void setEncoding(Attribute newLayout) { layout = newLayout; } + + Value getExpandedOffsetWithRange(OpBuilder &builder, const Location &loc, + unsigned i) { + if (cachedOffsetWithRange.count(i)) + return cachedOffsetWithRange[i]; + + // Add range + auto indexI32RowType = + RankedTensorType::get({tensorShape[i]}, builder.getI32Type(), layout); + auto indexRowType = + RankedTensorType::get({tensorShape[i]}, builder.getI64Type(), layout); + Value splatOffset = + builder.create(loc, indexRowType, offsets[i]); + Value range = builder.create(loc, indexI32RowType, 0, + tensorShape[i]); + Value i64Range = builder.create(loc, indexRowType, range); + + // Expand dimensions + Value expandedResult = + builder.create(loc, splatOffset, i64Range); + for (int axis = 0; axis < tensorShape.size(); ++axis) { + if (axis == i) + continue; + + if (layout) { + auto argEncoding = layout.cast(); + auto retSizePerThread = insertOne(argEncoding.getSizePerThread(), axis); + auto retThreadsPerWarp = + insertOne(argEncoding.getThreadsPerWarp(), axis); + auto retWarpsPerCTA = insertOne(argEncoding.getWarpsPerCTA(), axis); + auto retOrder = insertOrder(argEncoding.getOrder(), axis); + + auto argCTALayout = argEncoding.getCTALayout(); + auto retCTAsPerCGA = insertOne(argCTALayout.getCTAsPerCGA(), axis); + auto retCTASplitNum = insertOne(argCTALayout.getCTASplitNum(), axis); + auto retCTAOrder = insertOrder(argCTALayout.getCTAOrder(), axis); + + auto retCTALayout = ttg::CTALayoutAttr::get( + loc.getContext(), retCTAsPerCGA, retCTASplitNum, retCTAOrder); + + auto retEncoding = ttg::BlockedEncodingAttr::get( + loc.getContext(), retSizePerThread, retThreadsPerWarp, + retWarpsPerCTA, retOrder, retCTALayout); + + auto newArgEncoding = + ttg::SliceEncodingAttr::get(loc.getContext(), axis, retEncoding); + auto newArgType = RankedTensorType::get(indexRowType.getShape(), + indexRowType.getElementType(), + newArgEncoding); + Value newArg = builder.create(loc, newArgType, + expandedResult); + expandedResult = builder.create(loc, newArg, axis); + } else + expandedResult = + builder.create(loc, expandedResult, axis); + } + + return cachedOffsetWithRange[i] = expandedResult; + } + + Value generatePtr(OpBuilder &builder, const Location &loc) { + assert(tensorShape.size() == offsets.size() && + tensorShape.size() == strides.size()); + auto ptrType = base.getType().cast(); + auto ptrTensorType = RankedTensorType::get(tensorShape, ptrType, layout); + + // Generate offsets per dimension + Value ptr = builder.create(loc, ptrTensorType, base); + for (unsigned i = 0; i < tensorShape.size(); ++i) { + auto offsetWithRange = getExpandedOffsetWithRange(builder, loc, i); + // We must splat strides into the expanded shape not a row for retaining + // the divisibility information given by strides + Value splatStride = builder.create( + loc, offsetWithRange.getType(), strides[i]); + Value offsetWithStride = + builder.create(loc, offsetWithRange, splatStride); + auto offsetType = offsetWithRange.getType().cast(); + auto indexTensorType = RankedTensorType::get( + tensorShape, offsetType.getElementType(), offsetType.getEncoding()); + Value broadcasted = builder.create(loc, indexTensorType, + offsetWithStride); + if (offsetType.getEncoding() != ptrTensorType.getEncoding()) { + auto newArgType = + RankedTensorType::get(tensorShape, offsetType.getElementType(), + ptrTensorType.getEncoding()); + broadcasted = + builder.create(loc, newArgType, broadcasted); + } + // Add to the pointer + ptr = builder.create(loc, ptrTensorType, ptr, broadcasted); + } + + return ptr; + } + + Value generateMask(OpBuilder &builder, const Location &loc, + const std::optional> &boundaryCheck) { + if (!boundaryCheck.has_value() || boundaryCheck.value().empty()) + return {}; + + // Generate mask per dimension + auto maskTensorType = + RankedTensorType::get(tensorShape, builder.getI1Type(), layout); + Value mask; + for (auto i : boundaryCheck.value()) { + auto offsetWithRange = getExpandedOffsetWithRange(builder, loc, i); + auto offsetType = offsetWithRange.getType().cast(); + RankedTensorType cmpTensorType = RankedTensorType::get( + offsetType.getShape(), builder.getI1Type(), offsetType.getEncoding()); + + // Compare with lower bound + Value lowerBound = builder.create( + loc, 0, offsetType.getElementType()); + Value splatLowerBound = builder.create( + loc, offsetWithRange.getType(), lowerBound); + Value cmpLower = + createCmpOp(builder, loc, cmpTensorType, arith::CmpIPredicate::sge, + offsetWithRange, splatLowerBound); + + // Compare with upper bound + Value splatUpperBound = + builder.create(loc, offsetWithRange.getType(), shape[i]); + Value cmpUpper = + createCmpOp(builder, loc, cmpTensorType, arith::CmpIPredicate::slt, + offsetWithRange, splatUpperBound); + + // And and broadcast + Value andResult = builder.create(loc, cmpLower, cmpUpper); + if (offsetType.getEncoding() != maskTensorType.getEncoding()) { + auto newArgType = + RankedTensorType::get(offsetType.getShape(), builder.getI1Type(), + maskTensorType.getEncoding()); + andResult = + builder.create(loc, newArgType, andResult); + } + + Value broadcasted = + builder.create(loc, maskTensorType, andResult); + + // And up all results + if (!mask) { + mask = broadcasted; + } else { + mask = builder.create(loc, mask, broadcasted); + } + } + + return mask; + } + + Value generateOther(OpBuilder &builder, const Location &loc, + const std::optional &padding) { + if (!padding.has_value()) + return Value(); + + // Create element attribute + auto elementType = base.getType().cast().getPointeeType(); + auto otherTensorType = + RankedTensorType::get(tensorShape, elementType, layout); + + // Set zero padding value + TypedAttr attr = + elementType.isIntOrIndex() + ? builder.getIntegerAttr(elementType, 0).cast() + : builder.getFloatAttr(elementType, 0).cast(); + + // Float NaN padding case + if (padding.value() == tt::PaddingOption::PAD_NAN) { + assert(!elementType.isIntOrIndex()); + auto apNaN = llvm::APFloat::getNaN( + attr.cast().getValue().getSemantics()); + attr = builder.getFloatAttr(elementType, apNaN); + } + + // Create tensor + Value constant = builder.create(loc, attr); + return builder.create(loc, otherTensorType, constant); + } +}; +} // namespace + +class TritonGPURewriteTensorPointerPass + : public TritonGPURewriteTensorPointerBase< + TritonGPURewriteTensorPointerPass> { +private: + int computeCapability; + DenseMap rewritedInfo; + +public: + explicit TritonGPURewriteTensorPointerPass(int computeCapability) + : computeCapability(computeCapability) {} + + static bool needRewrite(Operation *op, const DenseSet &valueToRemove) { + if (auto ifOp = dyn_cast(op)) { + if (op->getNumResults() == 0) + return false; + Operation *thenYield = ifOp.thenYield().getOperation(); + if (!ifOp.getElseRegion().empty()) { + Operation *elseYield = ifOp.elseYield().getOperation(); + for (unsigned i = 0; i < thenYield->getNumOperands(); ++i) { + bool thenNeedRewrite = valueToRemove.count(thenYield->getOperand(i)); + bool elseNeedRewrite = valueToRemove.count(elseYield->getOperand(i)); + assert(!(thenNeedRewrite ^ elseNeedRewrite) && + "For IfOp, operand(i) of thenYield and operand(i) of " + "elseYield should be either all need rewrite or all not"); + } + } + op = thenYield; + } + return std::any_of(op->getOperands().begin(), op->getOperands().end(), + [&valueToRemove](Value operand) { + return tt::isTensorPointerType(operand.getType()) && + valueToRemove.count(operand); + }); + } + + static SmallVector + generateNewOperands(const SmallVector &oldOperands, unsigned index, + const SmallVector &newValues) { + assert(index < oldOperands.size()); + SmallVector newOperands; + for (int i = 0; i < index; ++i) + newOperands.push_back(oldOperands[i]); + for (auto value : newValues) + newOperands.push_back(value); + for (auto i = index + 1; i < oldOperands.size(); ++i) + newOperands.push_back(oldOperands[i]); + return newOperands; + } + + Operation *rewriteMakeTensorPtrOp(OpBuilder &builder, tt::MakeTensorPtrOp op, + std::stack &eraser, + const DenseSet &valueToRemove) { + if (!valueToRemove.count(op.getResult())) + return nullptr; + // Save info for later use + auto ptrType = op.getResult().getType().cast(); + auto tensorType = ptrType.getPointeeType().cast(); + + // Cast I32 offsets into I64 + SmallVector i64Offsets; + for (auto offset : op.getOffsets()) { + auto i64Offset = builder.create( + op.getLoc(), builder.getI64Type(), offset); + i64Offsets.push_back(i64Offset); + } + + // Save information + rewritedInfo[op.getResult()] = + RewritedInfo(op.getBase(), op.getShape(), op.getStrides(), i64Offsets, + tensorType.getShape(), tensorType.getEncoding()); + + // Erase the original operation + eraser.push(op); + return nullptr; + } + + Operation *rewriteAdvanceOp(OpBuilder &builder, tt::AdvanceOp op, + std::stack &eraser, + const DenseSet &valueToRemove) { + if (!valueToRemove.count(op.getResult())) { + return nullptr; + } + // Get info from previous results + assert(rewritedInfo.count(op.getPtr())); + auto info = rewritedInfo[op.getPtr()]; + + // Calculate new offsets + assert(info.length() == op.getOffsets().size()); + SmallVector newOffsets; + for (int i = 0; i < info.length(); ++i) { + Value i64Offset = builder.create( + op.getLoc(), builder.getI64Type(), op.getOffsets()[i]); + Value newOffset = builder.create( + op.getLoc(), info.getOffset(i), i64Offset); + newOffsets.push_back(newOffset); + } + + // Save info for later use + info.setOffsets(newOffsets); + rewritedInfo[op.getResult()] = info; + + // Erase the original operation + eraser.push(op); + return nullptr; + } + + Operation *rewriteLoadStoreOp(OpBuilder &builder, Operation *op, + std::stack &eraser, + const DenseSet &valueToRemove) { + if (!valueToRemove.count(op->getOperand(0))) + return nullptr; + + // We only have to rewrite load/stores with tensor pointers + auto ptr = op->getOperand(0); + if (!tt::isTensorPointerType(ptr.getType())) + return nullptr; + + // Get info from previous results + assert(rewritedInfo.count(ptr)); + auto info = rewritedInfo[ptr]; + + // Load/store with tensor pointers implicitly will check the bound while + // accessing memory, so we should set `mask` and `other` (according to the + // padding). Also note that load with tensor pointers do not have `mask` and + // `other` while building IR from Python AST + std::optional> boundaryCheck; + if (auto loadOp = dyn_cast(op)) { + assert(!loadOp.getMask() && !loadOp.getOther()); + boundaryCheck = loadOp.getBoundaryCheck(); + if (auto valueType = + dyn_cast(loadOp.getResult().getType())) + info.setEncoding(valueType.getEncoding()); + } else if (auto storeOp = dyn_cast(op)) { + assert(!storeOp.getMask()); + boundaryCheck = storeOp.getBoundaryCheck(); + if (auto valueType = + dyn_cast(storeOp.getValue().getType())) + info.setEncoding(valueType.getEncoding()); + } + + // Generate new `ptr`, `mask` and `other` + auto newPtr = info.generatePtr(builder, op->getLoc()); + auto newMask = info.generateMask(builder, op->getLoc(), boundaryCheck); + Value newOther; + if (auto loadOp = dyn_cast(op)) + newOther = info.generateOther(builder, op->getLoc(), loadOp.getPadding()); + + // Create a new operation + if (auto loadOp = dyn_cast(op)) { + auto newResult = builder.create( + loadOp.getLoc(), loadOp.getResult().getType(), newPtr, newMask, + newOther, loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(), + loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); + op->getResult(0).replaceAllUsesWith(newResult); + } else if (auto storeOp = dyn_cast(op)) { + builder.create(storeOp.getLoc(), newPtr, storeOp.getValue(), + newMask, storeOp.getCache(), + storeOp.getEvict()); + } + + // Erase the original operation + eraser.push(op); + return nullptr; + } + + Operation *rewriteForOp(OpBuilder &builder, scf::ForOp op, + std::stack &eraser, + DenseSet &valueToRemove) { + // Generate new iteration operands and set rewrited information + SmallVector oldIterOperands = op.getIterOperands(); + SmallVector newIterOperands = op.getIterOperands(); + for (unsigned i = 0, oldI = 0, size = op.getNumIterOperands(); i < size; + ++i, ++oldI) { + if (!tt::isTensorPointerType(newIterOperands[i].getType())) + continue; + if (!valueToRemove.count(newIterOperands[i])) + continue; + + // Expand the tensor pointer into offsets + assert(rewritedInfo.count(newIterOperands[i])); + auto info = rewritedInfo[newIterOperands[i]]; + newIterOperands = + generateNewOperands(newIterOperands, i, info.getOffsets()); + i += info.length() - 1; + size += info.length() - 1; + } + + // Rebuild the loop type + auto newForOp = builder.create(op.getLoc(), op.getLowerBound(), + op.getUpperBound(), op.getStep(), + newIterOperands); + + // Create value mapping. Note that for tensor pointers, we use identity + // mapping. It may refer to a value in the old loop, but we will rewrite it + // later + IRMapping mapping; + for (unsigned i = 0, oldI = 0; oldI < op.getNumIterOperands(); + ++i, ++oldI) { + auto oldRegionIterArg = op.getRegionIterArg(oldI); + if (tt::isTensorPointerType(oldRegionIterArg.getType()) && + valueToRemove.count(oldIterOperands[oldI])) { + // Pass rewrited info inside + assert(rewritedInfo.count(oldIterOperands[oldI])); + auto info = rewritedInfo[oldIterOperands[oldI]]; + mapping.map(oldRegionIterArg, oldRegionIterArg); + for (unsigned j = 0; j < info.length(); ++j) + info.setOffset(j, newForOp.getRegionIterArg(i + j)); + rewritedInfo[oldRegionIterArg] = info; + i += info.length() - 1; + } else { + mapping.map(oldRegionIterArg, newForOp.getRegionIterArg(i)); + } + } + mapping.map(op.getInductionVar(), newForOp.getInductionVar()); + + // Clone body + builder.setInsertionPointToStart(newForOp.getBody()); + for (Operation &opInFor : *op.getBody()) { + Operation *newOp = builder.clone(opInFor, mapping); + for (unsigned i = 0; i < opInFor.getNumResults(); ++i) { + if (valueToRemove.count(opInFor.getResult(i))) + valueToRemove.insert(newOp->getResult(i)); + mapping.map(opInFor.getResult(i), newOp->getResult(i)); + } + } + + // supported nested scf.for ops + for (auto &[k, v] : mapping.getValueMap()) + if (valueToRemove.find(k) != valueToRemove.end()) + valueToRemove.insert(v); + + // Replace later usages + assert(op.getNumResults() == op.getNumIterOperands()); + for (unsigned i = 0, oldI = 0; oldI < op.getNumResults(); ++i, ++oldI) { + auto oldResult = op.getResult(oldI); + if (tt::isTensorPointerType(oldResult.getType()) && + valueToRemove.count(oldIterOperands[oldI])) { + // Pack new offsets into rewrited info + assert(rewritedInfo.count(oldIterOperands[oldI])); + auto info = rewritedInfo[oldIterOperands[oldI]]; + for (unsigned j = 0; j < info.length(); ++j) + info.setOffset(j, newForOp.getResult(i + j)); + i += info.length() - 1; + rewritedInfo[oldResult] = info; + } else { + oldResult.replaceAllUsesWith(newForOp.getResult(i)); + } + } + + // Erase later + eraser.push(op); + return newForOp; + } + + Operation *rewriteYieldOp(OpBuilder &builder, scf::YieldOp op, + std::stack &eraser, + const DenseSet &valueToRemove) { + // Replace tensor pointers with offsets + SmallVector newOperands = op->getOperands(); + for (unsigned i = 0, size = op.getNumOperands(); i < size; ++i) { + if (!tt::isTensorPointerType(newOperands[i].getType())) + continue; + if (!valueToRemove.count(newOperands[i])) + continue; + + assert(rewritedInfo.count(newOperands[i])); + auto info = rewritedInfo[newOperands[i]]; + newOperands = generateNewOperands(newOperands, i, info.getOffsets()); + i += info.length() - 1; + size += info.length() - 1; + } + op->setOperands(newOperands); + + // No need to erase + return nullptr; + } + + Operation *rewriteIfOp(OpBuilder &builder, scf::IfOp op, + std::stack &eraser, + DenseSet &valueToRemove) { + auto thenYieldOp = op.thenYield(); + assert(op.getNumResults() == thenYieldOp.getNumOperands()); + SmallVector results = thenYieldOp.getOperands(); + + // get new result types + SmallVector newRetTypes; + for (unsigned i = 0; i < results.size(); ++i) { + if (!tt::isTensorPointerType(results[i].getType()) || + !valueToRemove.count(results[i])) { + newRetTypes.push_back(results[i].getType()); + continue; + } + auto makeTensorPtrOp = getMakeTensorPtrOp(results[i]); + assert(rewritedInfo.count(makeTensorPtrOp.getResult())); + auto info = rewritedInfo[makeTensorPtrOp.getResult()]; + for (unsigned j = 0; j < info.length(); ++j) { + newRetTypes.push_back(builder.getI64Type()); + } + } + + // create and clone new IfOp + bool hasElse = !op.getElseRegion().empty(); + scf::IfOp newOp = builder.create(op.getLoc(), newRetTypes, + op.getCondition(), hasElse); + IRMapping mapping; + for (unsigned i = 0; i < op->getNumOperands(); ++i) { + mapping.map(op->getOperand(i), newOp->getOperand(i)); + } + auto rematerialize = [&](Block *block) { + for (Operation &opInIf : block->getOperations()) { + auto newOp = builder.clone(opInIf, mapping); + } + }; + builder.setInsertionPointToStart(newOp.thenBlock()); + rematerialize(op.thenBlock()); + if (hasElse) { + builder.setInsertionPointToStart(newOp.elseBlock()); + rematerialize(op.elseBlock()); + } + + // supported nested ops + for (auto &[k, v] : mapping.getValueMap()) + if (valueToRemove.find(k) != valueToRemove.end()) + valueToRemove.insert(v); + + // update rewritedInfo + unsigned oldResIdx = 0, newResIdx = 0; + while (oldResIdx < results.size()) { + if (!tt::isTensorPointerType(results[oldResIdx].getType()) || + !valueToRemove.count(results[oldResIdx])) { + oldResIdx++; + newResIdx++; + } else { + auto makeTensorPtrOp = getMakeTensorPtrOp(results[oldResIdx]); + assert(rewritedInfo.count(makeTensorPtrOp.getResult())); + auto info = rewritedInfo[makeTensorPtrOp.getResult()]; + for (unsigned j = 0; j < info.length(); ++j) { + info.setOffset(j, newOp->getResult(newResIdx++)); + } + rewritedInfo[op.getResult(oldResIdx)] = info; + oldResIdx++; + } + } + + eraser.push(op); + return newOp; + } + + Operation *rewriteOp(Operation *op, std::stack &eraser, + DenseSet &valueToRemove) { + OpBuilder builder(op); + + // Rewrite `make_tensor_ptr` and `advance` and make a tensor of pointers + // Rewriting functions return the next operation to visit, if there is no + // next one, simply return `nullptr` + std::pair rewrited; + if (auto makeTensorPtrOp = dyn_cast(op)) { + return rewriteMakeTensorPtrOp(builder, makeTensorPtrOp, eraser, + valueToRemove); + } else if (auto advanceOp = dyn_cast(op)) { + return rewriteAdvanceOp(builder, advanceOp, eraser, valueToRemove); + } else if (isa(op) || isa(op)) { + return rewriteLoadStoreOp(builder, op, eraser, valueToRemove); + } else if (op->getDialect()->getNamespace() == "scf" || + op->getDialect()->getNamespace() == "cf") { + if (!needRewrite(op, valueToRemove)) + return op; + + if (auto forOp = dyn_cast(op)) { + return rewriteForOp(builder, forOp, eraser, valueToRemove); + } else if (auto yieldOp = dyn_cast(op)) { + return rewriteYieldOp(builder, yieldOp, eraser, valueToRemove); + } else if (auto ifOp = dyn_cast(op)) { + return rewriteIfOp(builder, ifOp, eraser, valueToRemove); + } else { + llvm_unreachable("Currently we only support tensor pointer usages " + "inside a `scf::ForOp` or `scf::IfOp`, others such as " + "`scf::WhileOp`, `cf::BranchOp` or `cf::CondBranchOp` " + "are not supported yet"); + } + } + + // Otherwise return the original one + return op; + } + + void visitOperation(Operation *op, std::stack &eraser, + DenseSet &valueToRemove) { + for (auto ®ion : op->getRegions()) { + for (auto &block : region) { + // We need an extra copy because erasing operations may break the + // iterator behavior + SmallVector blockCopy; + for (auto &nestedOp : block) + blockCopy.push_back(&nestedOp); + + // Rewrite and recursively visit + for (auto &nestedOp : blockCopy) { + if (auto newOp = rewriteOp(nestedOp, eraser, valueToRemove)) + visitOperation(newOp, eraser, valueToRemove); + } + } + } + } + + void runOnOperation() override { + ModuleOp mod = getOperation(); + + DenseSet valueToRemove; + mod.walk([&valueToRemove, + computeCapability = this->computeCapability](Operation *op) { + if (auto makeTensorPtrOp = dyn_cast(op)) { + if (shouldRemove(makeTensorPtrOp, computeCapability)) + valueToRemove.insert(op->getResult(0)); + } + if (llvm::isa(op)) { + auto src = op->getOperand(0); + if (tt::isTensorPointerType(src.getType())) { + auto makeTensorPtrOp = getMakeTensorPtrOp(src); + if (shouldRemove(makeTensorPtrOp, computeCapability)) { + valueToRemove.insert(op->getResult(0)); + } + } + } + if (llvm::isa(op)) { + auto src = op->getOperand(0); + if (tt::isTensorPointerType(src.getType())) { + auto makeTensorPtrOp = getMakeTensorPtrOp(src); + if (shouldRemove(makeTensorPtrOp, computeCapability)) + valueToRemove.insert(src); + } + } + if (auto forOp = dyn_cast(op)) { + SmallVector iterOperands = forOp.getIterOperands(); + for (unsigned i = 0, size = forOp.getNumIterOperands(); i < size; ++i) { + if (tt::isTensorPointerType(iterOperands[i].getType())) { + auto makeTensorPtrOp = getMakeTensorPtrOp(iterOperands[i]); + if (shouldRemove(makeTensorPtrOp, computeCapability)) + valueToRemove.insert(iterOperands[i]); + } + } + } else if (auto yieldOp = dyn_cast(op)) { + SmallVector operands = yieldOp->getOperands(); + for (unsigned i = 0, size = yieldOp.getNumOperands(); i < size; ++i) { + if (tt::isTensorPointerType(operands[i].getType())) { + auto makeTensorPtrOp = getMakeTensorPtrOp(operands[i]); + if (shouldRemove(makeTensorPtrOp, computeCapability)) + valueToRemove.insert(operands[i]); + } + } + } + }); + + // NOTES(Chenggang): we don't use `ConversionPatternRewriter`, because + // MLIR does not support one-multiple value mapping. For example, if we use + // `ConversionPatternRewriter`, we can not make a type converter, which + // converts `ptr` into multiple types `ptr<>, int64, int64, ...` + // (containing the base/offsets/strides...). What we can do is to convert + // `ptr` into a single type `Tuple, int64, int64, ...>`. But + // in this way, we also have to define `PackTuple` and `UnpackTuple` + // operations and make a canonicalization pass to optimize, which is much + // So here we recursively build the IR, to be specific, we have to rewrite + // `tt.make_tensor_ptr`, `tt.advance`, `tt.load`, `tt.store`, + // `scf.for` (tensor pointer usages may be in a loop fashion) + std::stack eraser; + visitOperation(getOperation(), eraser, valueToRemove); + + // The operation could not be erased during visit, because they may have + // later usages, so we erase after visit + rewritedInfo.clear(); + valueToRemove.clear(); + while (!eraser.empty()) { + auto op = eraser.top(); + eraser.pop(); + op->erase(); + } + } +}; + +std::unique_ptr +mlir::createTritonGPURewriteTensorPointerPass(int computeCapability) { + return std::make_unique(computeCapability); +} diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp new file mode 100644 index 000000000000..d962aea0c02e --- /dev/null +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp @@ -0,0 +1,546 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" + +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/Support/Debug.h" +#include + +namespace mlir { + +namespace ttg = triton::gpu; + +namespace { + +bool knownSafeToIgnoreRegion(Operation *op) { + return isa(op); +} + +// Suppose the kernel has following structure: +// ``` +// scf.for(...) { +// compute_0(i) +// barrier(...) +// compute_1(i) +// } +// ``` +// Due to the barrier between compute_0(i) and compute_1(i), we +// can not pre-compute compute_0(i+1) before compute_1(i) semantically. +// In some case, it may be still functionally correct to pre-compute +// compute_0(i+1) while it's very hard to prove it at compile time. +// +// Here we use a simple strategy: skip auto wrap specialize those kernels that +// use global barriers. +// +// Another remaining question is how to detect barrier in a triton program. +// There is not a barrier op in triton yet. It's usually implemented using +// atomic_* ops. Hence we simply detect if there are some atomc_* ops. It may +// miss some auto-WS opportunities and we leave it for the future to improve it. +bool hasUnsafeBarrier(triton::FuncOp funcOp) { + return funcOp + ->walk([](Operation *op) { + if (isa(op)) + return WalkResult::interrupt(); + return WalkResult::advance(); + }) + .wasInterrupted(); +} + +// Assigns `dependentSet` and returns ok if the analysis is successful. +// We do not support dependency analysis across load/store, thus a failure will +// be returned if encountering such cases. +LogicalResult getDependentPointers(Value ptr, DenseSet &dependentSet, + DenseSet &processedSet) { + // early return if processed + if (!processedSet.insert(ptr).second) + return success(); + + if (auto blockArg = ptr.dyn_cast()) { + if (!blockArg.getOwner()->isEntryBlock()) + return failure(); + auto parentOp = blockArg.getOwner()->getParentOp(); + if (auto forOp = dyn_cast(parentOp)) { + if (blockArg.getArgNumber() >= forOp.getNumInductionVars()) { + if (failed(getDependentPointers( + forOp.getOpOperandForRegionIterArg(blockArg).get(), + dependentSet, processedSet))) + return failure(); + + unsigned operandIdx = + blockArg.getArgNumber() - forOp.getNumInductionVars(); + return getDependentPointers( + forOp.getBody()->getTerminator()->getOperand(operandIdx), + dependentSet, processedSet); + } + } else if (auto funcOp = dyn_cast(parentOp)) { + dependentSet.insert(ptr); + return success(); + } + // unknown ops, return failure for correctness. + return failure(); + } + + auto definingOp = ptr.getDefiningOp(); + assert(definingOp); + if (auto makeTensorPtrOp = ptr.getDefiningOp()) { + return getDependentPointers(makeTensorPtrOp.getBase(), dependentSet, + processedSet); + } else if (auto advanceOp = ptr.getDefiningOp()) { + return getDependentPointers(advanceOp.getPtr(), dependentSet, processedSet); + } else if (auto addPtrOp = ptr.getDefiningOp()) { + return getDependentPointers(addPtrOp.getPtr(), dependentSet, processedSet); + } else if (auto loadOp = ptr.getDefiningOp()) { + // not support load dependent ptr + return failure(); + } else if (auto forOp = ptr.getDefiningOp()) { + unsigned idx = ptr.cast().getResultNumber(); + return getDependentPointers( + forOp.getBody()->getTerminator()->getOperand(idx), dependentSet, + processedSet); + } else if (auto ifOp = ptr.getDefiningOp()) { + unsigned idx = ptr.cast().getResultNumber(); + if (ifOp.elseBlock() && + failed(getDependentPointers(ifOp.elseYield()->getOperand(idx), + dependentSet, processedSet))) + return failure(); + return getDependentPointers(ifOp.thenYield()->getOperand(idx), dependentSet, + processedSet); + } else if (!definingOp->getNumRegions() || + knownSafeToIgnoreRegion(definingOp)) { + for (Value operand : definingOp->getOperands()) + if (failed(getDependentPointers(operand, dependentSet, processedSet))) + return failure(); + return success(); + } + // unknown ops, return failure for correctness. + return failure(); +} + +// Suppose the kernel has following structure: +// ``` +// scf.for(...) { +// v(i) = load(ptr) +// new_v(i) = some_compute(v(i), ...) +// store(new_v(i), ptr) +// } +// ``` +// +// There is an implicit dependency between load(i+1) and store(i), which means +// we can not pre-compute load(i+1) before store(i). +// +// To avoid such load after store conflict, we simply disallow mixed load and +// store for the same buffer. It's a conservative strategy and can be relaxed in +// case necessary. +bool hasUnsafeLoadAfterStore(triton::FuncOp funcOp) { + // TODO: support CFG + if (funcOp.getBody().getBlocks().size() > 1) + return true; + + DenseMap ptrStoreMap; + DenseMap ptrLoadMap; + if (funcOp + ->walk([&](triton::LoadOp loadOp) { + DenseSet dependentSet, processedSet; + if (failed(getDependentPointers(loadOp.getPtr(), dependentSet, + processedSet))) + return WalkResult::interrupt(); + for (Value v : dependentSet) + ptrLoadMap[v] = true; + return WalkResult::advance(); + }) + .wasInterrupted()) + return false; + auto result = funcOp->walk([&](Operation *op) { + if (auto storeOp = dyn_cast(op)) { + DenseSet dependentSet, processedSet; + if (failed(getDependentPointers(storeOp.getPtr(), dependentSet, + processedSet))) + return WalkResult::interrupt(); + + for (Value v : dependentSet) + ptrStoreMap[v] = true; + + // TODO: relax the restriction in case necessary. + // If a store is inside a region, e.g. scf.while/for/if, its + // dependent ptrs are not allowed to be loaded. + if (op->getParentOp() != funcOp) { + for (Value v : dependentSet) + if (ptrLoadMap.find(v) != ptrLoadMap.end()) + return WalkResult::interrupt(); + } + } else if (auto loadOp = dyn_cast(op)) { + DenseSet dependentSet, processedSet; + if (failed(getDependentPointers(loadOp.getPtr(), dependentSet, + processedSet))) + return WalkResult::interrupt(); + for (Value v : dependentSet) + if (ptrStoreMap.find(v) != ptrStoreMap.end()) + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + + return result.wasInterrupted(); +} + +bool hasWSCandidateLoad(triton::FuncOp funcOp) { + SmallVector loadOps; + funcOp->walk([&](triton::LoadOp loadOp) { + if (isWSCandidateLoad(loadOp)) + loadOps.push_back(loadOp); + }); + if (loadOps.empty()) + return false; + + // All the candidate ops should be in the same block and have compatible + // types. + Block *block = loadOps[0]->getBlock(); + auto refTy = loadOps[0].getPtr().getType().dyn_cast(); + bool isPtrToTensor = refTy && refTy.getPointeeType().isa(); + for (auto loadOp : loadOps) { + if (loadOp->getBlock() != block) + return false; + // not support mixed ptr to tensor and tensor of ptr currently. + auto ty = loadOp.getPtr().getType().dyn_cast(); + if (isPtrToTensor != (ty && ty.getPointeeType().isa())) + return false; + } + + // S0 = dependent value set of all the candidate ops + // S1 = dependent value set of all the store ops + // S2 = S1 & S0 + // any value in S2 should not be the output of an op having regions. + // TODO: lift the limitation of WSPipeline pass to remove this check. + DenseSet loadDepSet; + DenseSet loadSet; + for (auto op : loadOps) { + if (failed(getDependentValues(op.getOperation(), loadDepSet))) + return false; + loadSet.insert(op->getResult(0)); + } + + DenseSet storeDepSet; + if (funcOp + ->walk([&](triton::StoreOp op) { + if (failed(getDependentValues(op.getOperation(), storeDepSet, + loadSet))) + return WalkResult::interrupt(); + return WalkResult::advance(); + }) + .wasInterrupted()) + return false; + + for (Value v : loadDepSet) + if (storeDepSet.find(v) != storeDepSet.end()) { + auto op = v.getDefiningOp(); + if (op && op->getNumRegions()) + return false; + } + + return true; +} + +} // namespace + +//===----------------------------------------------------------------------===// +// Helper functions for async agent +//===----------------------------------------------------------------------===// + +SmallVector getAgentIds(Operation *op) { + SmallVector agentIds; + if (auto attr = op->getAttrOfType("async_agent")) + for (AgentId agentId : attr.getValues()) + agentIds.push_back(agentId); + return agentIds; +} + +bool hasAgentId(Operation *op, AgentId agentId) { + for (AgentId candidate : getAgentIds(op)) + if (candidate == agentId) + return true; + return false; +} + +void setAgentIds(Operation *op, ArrayRef agentIds) { + SmallVector sortedAgentIds(agentIds.begin(), agentIds.end()); + sort(sortedAgentIds); + auto i32Ty = IntegerType::get(op->getContext(), 32); + auto size = static_cast(sortedAgentIds.size()); + auto vecTy = VectorType::get(size, i32Ty); + op->setAttr("async_agent", DenseIntElementsAttr::get(vecTy, sortedAgentIds)); +} + +SmallVector collectAgentIds(Operation *op) { + SetVector agentIds; + op->walk([&](Operation *curOp) { + for (AgentId agentId : getAgentIds(curOp)) + agentIds.insert(agentId); + }); + SmallVector res(agentIds.begin(), agentIds.end()); + llvm::sort(res); + return res; +} + +void addAgentIds(Operation *op, ArrayRef agents) { + auto agentsVec = getAgentIds(op); + DenseSet agentsSet(agentsVec.begin(), agentsVec.end()); + for (int a : agents) { + if (!agentsSet.contains(a)) { + agentsVec.push_back(a); + } + } + if (agentsVec.size() > 0) { + setAgentIds(op, agentsVec); + } +} + +SmallVector getMutexBarIds(Operation *op) { + SmallVector barIds; + if (auto attr = op->getAttrOfType("mutex.barId")) + for (int id : attr.getValues()) + barIds.push_back(id); + return barIds; +} + +SmallVector getMutexNumThreads(Operation *op) { + SmallVector numThreads; + if (auto attr = op->getAttrOfType("mutex.numThreads")) + for (int n : attr.getValues()) + numThreads.push_back(n); + return numThreads; +} + +//===----------------------------------------------------------------------===// +// Implementations for general auto WS +//===----------------------------------------------------------------------===// + +// Populates `depSet` with the values that `val` depends on and Returns success. +// Returns failure() if encountering any unsupported conditions. +LogicalResult getDependentValues(Value val, DenseSet &depSet, + const DenseSet &stopSet) { + auto tryInsertAndPropagate = [&](Value other) { + if (stopSet.find(other) == stopSet.end() && depSet.insert(other).second) + return getDependentValues(other, depSet, stopSet); + return success(); + }; + auto addControlOperandsForForOp = [&](scf::ForOp forOp) { + for (Value operand : + forOp->getOperands().take_front(forOp.getNumControlOperands())) + if (failed(tryInsertAndPropagate(operand))) + return failure(); + return success(); + }; + auto addControlOperandsForIfOp = [&](scf::IfOp ifOp) { + return tryInsertAndPropagate(ifOp.getCondition()); + }; + auto propagateParentOp = [&](Operation *op) { + while (Operation *parentOp = op->getParentOp()) { + if (auto forOp = dyn_cast(parentOp)) + return addControlOperandsForForOp(forOp); + else if (auto ifOp = dyn_cast(parentOp)) + return addControlOperandsForIfOp(ifOp); + else if (auto funcOp = dyn_cast(parentOp)) + return success(); + else + break; + op = parentOp; + } + // unknown ops, return failure for correctness. + return failure(); + }; + + if (auto blockArg = val.dyn_cast()) { + auto parentOp = blockArg.getOwner()->getParentOp(); + if (auto forOp = dyn_cast(parentOp)) { + // add control operands of forOp into dependent set + if (failed(addControlOperandsForForOp(forOp))) + return failure(); + if (blockArg.getArgNumber() >= forOp.getNumInductionVars()) { + Value operand = forOp.getOpOperandForRegionIterArg(blockArg).get(); + if (failed(tryInsertAndPropagate(operand))) + return failure(); + + unsigned operandIdx = + blockArg.getArgNumber() - forOp.getNumInductionVars(); + return tryInsertAndPropagate( + forOp.getBody()->getTerminator()->getOperand(operandIdx)); + } + return propagateParentOp(parentOp); + } else if (auto funcOp = dyn_cast(parentOp)) { + if (stopSet.find(val) == stopSet.end()) + depSet.insert(val); + return success(); + } else { + // unknown ops, return failure for correctness. + return failure(); + } + } + + auto definingOp = val.getDefiningOp(); + assert(definingOp); + if (auto forOp = val.getDefiningOp()) { + if (failed(addControlOperandsForForOp(forOp))) + return failure(); + unsigned idx = val.cast().getResultNumber(); + if (failed(tryInsertAndPropagate( + forOp->getOperand(idx + forOp.getNumControlOperands())))) + return failure(); + return tryInsertAndPropagate( + forOp.getBody()->getTerminator()->getOperand(idx)); + } else if (auto ifOp = val.getDefiningOp()) { + if (failed(addControlOperandsForIfOp(ifOp))) + return failure(); + unsigned idx = val.cast().getResultNumber(); + if (ifOp.elseBlock() && + failed(tryInsertAndPropagate(ifOp.elseYield()->getOperand(idx)))) + return failure(); + return tryInsertAndPropagate(ifOp.thenYield()->getOperand(idx)); + } else if (!definingOp->getNumRegions() || + knownSafeToIgnoreRegion(definingOp)) { + for (Value operand : definingOp->getOperands()) + if (failed(tryInsertAndPropagate(operand))) + return failure(); + return success(); + } else { + // unknown ops, return failure for correctness. + return failure(); + } + + return propagateParentOp(definingOp); +} + +LogicalResult getDependentValues(Operation *op, DenseSet &depSet, + const DenseSet &stopSet) { + if (op->getNumResults() > 0) { + for (Value result : op->getResults()) + if (failed(getDependentValues(result, depSet, stopSet))) + return failure(); + } else { + // Not support op with regions + if (op->getNumRegions() != 0) + return failure(); + for (Value operand : op->getOperands()) { + if (stopSet.find(operand) != stopSet.end()) + continue; + depSet.insert(operand); + if (failed(getDependentValues(operand, depSet, stopSet))) + return failure(); + } + } + return success(); +} + +DenseSet getDependentOps(DenseSet &depSet) { + DenseSet depOps; + for (Value val : depSet) { + Operation *op = val.getDefiningOp(); + if (auto blockArg = val.dyn_cast()) + op = blockArg.getOwner()->getParentOp(); + + while (op && !isa(op)) { + depOps.insert(op); + op = op->getParentOp(); + } + } + return depOps; +} + +bool isWSCandidateLoad(Operation *op) { + auto loadOp = dyn_cast(op); + if (!loadOp) + return false; + + Value result = loadOp->getResult(0); + auto resultTy = result.getType().cast(); + // Skip those tensors that are too small. + if (resultTy.getNumElements() <= 64) + return false; + // TODO: remove this limit once we refator ws pipeline pass. + if (resultTy.getNumElements() % 128 != 0) + return false; + // pattern match: load + convert_layout(blocked, shared) + if (!result.hasOneUse()) + return false; + auto cvtOp = dyn_cast(*result.getUsers().begin()); + if (!cvtOp) + return false; + auto encoding = + cvtOp.getResult().getType().cast().getEncoding(); + if (!encoding || !encoding.dyn_cast()) + return false; + + DenseSet depSet; + if (failed(getDependentValues(op->getResult(0), depSet))) + return false; + auto depOps = getDependentOps(depSet); + for (Operation *depOp : depOps) { + if (isa(depOp)) + return false; + } + return op->getParentOfType() || + op->getParentOfType(); +} + +bool isWSSupported(ModuleOp mod, int computeCapability) { + // Early return if the target device is not feasible. + if (computeCapability / 10 < 9) { + return false; + } + + // TODO: support function call. + triton::FuncOp funcOp; + if (mod->walk([&](triton::FuncOp op) { + if (funcOp) + return WalkResult::interrupt(); + funcOp = op; + return WalkResult::advance(); + }) + .wasInterrupted() || + !funcOp) + return false; + + // Triton programs with global barrier are much harder to do auto warp + // specialization. Here we do some conservative checks to skip the bad cases. + if (hasUnsafeBarrier(funcOp)) + return false; + + // load after store for the same buffer forces an implicit dependency, which + // may break auto WS. Here we do some conservative checks to skip the bad + // cases. + if (hasUnsafeLoadAfterStore(funcOp)) + return false; + + if (!hasWSCandidateLoad(funcOp)) + return false; + + return true; +} + +} // namespace mlir diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/WSDecomposing.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/WSDecomposing.cpp new file mode 100644 index 000000000000..867f0040400b --- /dev/null +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/WSDecomposing.cpp @@ -0,0 +1,260 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" + +#include "mlir/Analysis/SliceAnalysis.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" + +#define GEN_PASS_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" + +using namespace mlir; +namespace ttng = mlir::triton::nvidia_gpu; + +class Heuristics { + //===--------------------------Label rules--------------------------===// + // - Op without agent attr: opeations shared by all agents but will NOT + // be copied into each agent region + // - Op with one agent: exclusive for one agent + // - Op with agents: shared by agents and will be copied into each agent + // region + //===---------------------------------------------------------------===// +public: + Heuristics(MLIRContext *context, ModuleOp mod, const int &computeCapability) + : context(context), mod(mod), computeCapability(computeCapability), + builder(OpBuilder(context)) {} + virtual bool run() = 0; + +protected: + // Set agentId when condition is satisfied + virtual void + setAgentId_if(int agentId, + const std::function &condition) { + mod.walk([&](Operation *op) -> void { + if (condition(op)) { + setAgentIds(op, {agentId}); + } + }); + } + + static bool isTritonLoadOp(Operation *op) { return isa(op); } + + static bool isTritonDotOp(Operation *op) { return isa(op); } + + static bool isTritonStoreOp(Operation *op) { + return isa(op); + } + + /// Becuase we set some special filter rules in populateAgentRegion, + /// there may be unlabeled Ops, e.g. YieldOps, some definingOps of ForOps. + /// or Ops without relations to agentOps + virtual void populateUnlabledOpsAtLast(ArrayRef allAgents) { + // Label agents' parentOps + for (int i : allAgents) { + DenseSet agentParentOps; + getAllParentOps(agentParentOps, i); + for (auto op : agentParentOps) { + addAgentIds(op, {i}); + } + } + + // Get unlabeled Ops + DenseSet unlabeledOps; + mod.walk([&](Operation *op) -> void { + if (isa(op) || isa(op) || + isa(op)) { + return; + } + if (!op->hasAttr("async_agent")) { + unlabeledOps.insert(op); + } + }); + + // Label Ops using its parentOp + for (auto op : unlabeledOps) { + if (auto parent = op->getParentOp()) { + if (!isa(parent)) { + assert(parent->hasAttr("async_agent")); + auto agents = getAgentIds(parent); + setAgentIds(op, agents); + unlabeledOps.erase(op); + } + } + } + + // Label Ops using dependency + for (auto op : unlabeledOps) { + labelByUsers(op, allAgents); + unlabeledOps.erase(op); + } + assert(unlabeledOps.size() == 0); + } + + // Return all Ops that are marked with target agent + void getAgentOps(DenseSet &agentOps, int agentId) { + SmallVector tmpArray{agentId}; + auto agentAttr = builder.getI32VectorAttr(ArrayRef(tmpArray)); + mod.walk([&](Operation *op) -> void { + if (op->hasAttr("async_agent") && + op->getAttr("async_agent") == agentAttr) { + agentOps.insert(op); + } + }); + } + + void getAllParentOps(DenseSet &parentOps, int agentId) { + DenseSet targetOps; + getAgentOps(targetOps, agentId); + for (auto op : targetOps) { + getAllParentOps(parentOps, op); + } + } + + void getAllParentOps(DenseSet &parentOps, Operation *targetOp) { + auto op = targetOp; + while (auto parent = op->getParentOp()) { + if (!isa(parent) && !isa(parent)) { + parentOps.insert(parent); + op = parent; + } else { + break; + } + } + } + + void labelByUsers(Operation *op, ArrayRef allAgents) { + for (Value result : op->getResults()) { + for (Operation *userOp : result.getUsers()) { + if (!userOp->hasAttr("async_agent")) { + labelByUsers(userOp, allAgents); + } + addAgentIds(op, getAgentIds(userOp)); + } + } + if (!op->hasAttr("async_agent")) { + addAgentIds(op, allAgents); + } + } + +protected: + MLIRContext *context; + ModuleOp mod; + int computeCapability; + OpBuilder builder; +}; + +//===------------------------heuristics list------------------------===// +// List all heuristics here: +// - Heuristic_Load_MathStore: assign load and math+store to two +// different agents respectively. +//===---------------------------------------------------------------===// + +class Heuristic_Load_MathStore : public Heuristics { +public: + Heuristic_Load_MathStore(MLIRContext *context, ModuleOp mod, + const int &computeCapability) + : Heuristics(context, mod, computeCapability) {} + bool run() override { + constexpr int kLoadAgentId = 0; + constexpr int kStoreAgentId = 1; + constexpr int kNumAgents = 2; + + //===--------------------1. label key operations--------------------===// + setAgentId_if(kLoadAgentId, isWSCandidateLoad); + setAgentId_if(kStoreAgentId, isTritonStoreOp); + + //===--------------2. populate based on key operations--------------===// + // find the roots (outputs) of LoadAgent + DenseSet loadOps; + getAgentOps(loadOps, kLoadAgentId); + // find LoadAgent dependent ops + DenseSet loadValues; + DenseSet loadAgentDepValues; + for (Operation *op : loadOps) { + if (failed(getDependentValues(op, loadAgentDepValues))) + return false; + loadValues.insert(op->getResult(0)); + } + for (Operation *op : getDependentOps(loadAgentDepValues)) + addAgentIds(op, kLoadAgentId); + + // find the roots (outputs) of StoreAgent + DenseSet storeOps; + getAgentOps(storeOps, kStoreAgentId); + // find StoreAgent dependent ops + DenseSet storeAgentDepValues; + for (Operation *op : storeOps) + if (failed(getDependentValues(op, storeAgentDepValues, loadValues))) + return false; + for (Operation *op : getDependentOps(storeAgentDepValues)) + addAgentIds(op, kStoreAgentId); + + //===---------------------3. label unlabeld Ops---------------------===// + populateUnlabledOpsAtLast({kLoadAgentId, kDotAgentId}); + + // Erase labels of MakeTensorPtrOp and its definingOps, + // because we don't want them to be copied in each agent + SetVector backwardSlice; + mod.walk([&](triton::MakeTensorPtrOp op) -> void { + assert(isa(op->getParentOp())); + getBackwardSlice(op.getOperation(), &backwardSlice); + op->removeAttr("async_agent"); + }); + for (auto op : backwardSlice) { + op->removeAttr("async_agent"); + } + // Set num-agents for wsmaterialization pass + mod->setAttr("async.num-agents", builder.getI32IntegerAttr(kNumAgents)); + return true; + } +}; + +class TritonGPUWSDecomposingPass + : public TritonGPUWSDecomposingBase { +public: + TritonGPUWSDecomposingPass() = default; + TritonGPUWSDecomposingPass(int computeCapability) { + this->computeCapability = computeCapability; + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + if (!ttng::TritonNvidiaGPUDialect::getWSSupportedAttr(mod)) + return signalPassFailure(); + + // Build Heuristics + Heuristic_Load_MathStore hLoadMathBasic(context, mod, computeCapability); + if (!(hLoadMathBasic.run())) { + return signalPassFailure(); + } + } +}; + +std::unique_ptr +mlir::createTritonNvidiaGPUWSDecomposingPass(int computeCapability) { + return std::make_unique(computeCapability); +} diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/WSFeasibilityChecking.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/WSFeasibilityChecking.cpp new file mode 100644 index 000000000000..15f19b889113 --- /dev/null +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/WSFeasibilityChecking.cpp @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" + +#define GEN_PASS_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" + +namespace mlir { + +namespace ttng = triton::nvidia_gpu; + +namespace { + +class TritonGPUWSFeasibilityCheckingPass + : public TritonGPUWSFeasibilityCheckingBase< + TritonGPUWSFeasibilityCheckingPass> { +public: + TritonGPUWSFeasibilityCheckingPass() = default; + TritonGPUWSFeasibilityCheckingPass(int computeCapability) { + this->computeCapability = computeCapability; + } + + void runOnOperation() override { + ModuleOp mod = getOperation(); + int wsSupported = isWSSupported(mod, this->computeCapability); + auto i32_ty = IntegerType::get(mod->getContext(), 32); + mod->setAttr(ttng::TritonNvidiaGPUDialect::getWSSupportedAttrName(), + IntegerAttr::get(i32_ty, llvm::APInt(32, wsSupported))); + } +}; + +} // namespace + +std::unique_ptr +createTritonNvidiaGPUWSFeasibilityCheckingPass(int computeCapability) { + return std::make_unique( + computeCapability); +} + +} // namespace mlir diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/WSFixupMissingAttrs.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/WSFixupMissingAttrs.cpp new file mode 100644 index 000000000000..6cd6df7484c9 --- /dev/null +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/WSFixupMissingAttrs.cpp @@ -0,0 +1,69 @@ +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" + +#define GEN_PASS_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" + +namespace mlir { + +namespace ttng = triton::nvidia_gpu; + +namespace { + +class TritonGPUWSFixupMissingAttrsPass + : public TritonGPUWSFixupMissingAttrsBase< + TritonGPUWSFixupMissingAttrsPass> { +public: + TritonGPUWSFixupMissingAttrsPass() = default; + + void runOnOperation() override { + ModuleOp mod = getOperation(); + if (!ttng::TritonNvidiaGPUDialect::getWSSupportedAttr(mod)) + return; + OpBuilder builder(mod); + mod->walk([&](mlir::triton::FuncOp funcOp) { + for (Operation &op : funcOp.getBody().front().getOperations()) { + if (!isa(&op)) + continue; + auto agentIds = getAgentIds(&op); + if (agentIds.size() != 1) + continue; + Block *roleIdBlock = nullptr; + op.walk([&](Operation *subOp) { + setAgentIds(subOp, agentIds); + // Find the outter most common block that has roleId. + // The below implementation assumes that: + // - all lock/unlock ops are in the same block (denoted as B). + // - there is always one scf.if op in the front of `B` which has + // role id attached. + // The above assumptions are maintained by WSMutex pass currently. + if (!roleIdBlock && isa(subOp) && getWSRoleId(subOp)) + roleIdBlock = subOp->getBlock(); + }); + if (!roleIdBlock) + continue; + int roleId = 0; + for (Operation &roleOp : roleIdBlock->getOperations()) { + auto optionalRoleId = getWSRoleId(&roleOp); + if (!optionalRoleId) { + setRoleId(&roleOp, roleId); + } else { + roleId = *optionalRoleId; + } + roleOp.walk([&](Operation *subOp) { setRoleId(subOp, roleId); }); + } + } + }); + } +}; + +} // namespace + +std::unique_ptr createTritonNvidiaGPUWSFixupMissingAttrs() { + return std::make_unique(); +} + +} // namespace mlir diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/WSMaterialization.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/WSMaterialization.cpp new file mode 100644 index 000000000000..9ebc78497cbb --- /dev/null +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/WSMaterialization.cpp @@ -0,0 +1,742 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" + +#include "mlir/IR/OperationSupport.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" + +#include + +using namespace mlir; +namespace ttg = triton::gpu; +namespace ttng = triton::nvidia_gpu; + +#define GEN_PASS_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" + +namespace { + +enum class LoadType { + Uninitialized, + InsertSliceAsyncOp, + InsertSliceAsyncV2Op, + MultiKinds, +}; + +// This helper function returns the real threadId while ttng::GetThreadIdOp is +// actually threadId % 128 when warp specialization is enabled +Value getThreadId(OpBuilder &builder, Location loc) { + Value threadId = builder.create<::mlir::gpu::ThreadIdOp>( + loc, builder.getIndexType(), ::mlir::gpu::Dimension::x); + auto cast = builder.create( + loc, TypeRange{builder.getIntegerType(32)}, ValueRange{threadId}); + return cast.getResult(0); +} + +//===----------------------------------------------------------------------===// +// Materialize GetAgentIdOp +//===----------------------------------------------------------------------===// + +void materializeGetAgentIdOp(Operation *parentOp) { + parentOp->walk([](ttng::GetAgentIdOp op) { + // In Hopper, each agent is a warpgroup consisting with 4 warps. + auto loc = op.getLoc(); + OpBuilder builder(op); + + Value _128 = builder.create(loc, 128, 32); + Value threadId = getThreadId(builder, loc); + Value agentId = builder.create(loc, threadId, _128); + op.getResult().replaceAllUsesWith(agentId); + op->erase(); + + // Update agent condition and insert "agent.num-warps" + auto agentIdOp = agentId.getDefiningOp(); + builder.setInsertionPoint(agentIdOp); + Value globalRoleId = builder.create(loc, 0, 32); + int globalNumWarps = 0; + for (auto cmpOp : agentIdOp->getUsers()) { + assert(isa(cmpOp)); + for (auto u : cmpOp->getUsers()) { + if (isa(u) && isa(u->getParentOp()) && + u->hasAttr("async_agent") && getAgentIds(u).size() == 1) { + loc = u->getLoc(); + builder.setInsertionPoint(u); + int numRoles = 1; + if (u->hasAttr("agent.num-roles")) { + numRoles = + u->getAttrOfType("agent.num-roles").getInt(); + // TODO: more flexible ways to get numWarps. + auto numWarps = builder.getI32IntegerAttr(4 * numRoles); + auto numWarpsBase = builder.getI32IntegerAttr(globalNumWarps); + u->setAttr("agent.num-warps", numWarps); + u->walk([&](ttng::GetMutexRoleIdOp roleIdOp) { + roleIdOp->setAttr("agent.num-warps", numWarps); + roleIdOp->setAttr("agent.num-warps-base", numWarpsBase); + }); + } + globalNumWarps += numRoles * 4; + Value offset = + builder.create(loc, numRoles, 32); + Value lowerBound = builder.create( + loc, arith::CmpIPredicate::uge, agentId, globalRoleId); + globalRoleId = + builder.create(loc, globalRoleId, offset); + Value upperBound = builder.create( + loc, arith::CmpIPredicate::ult, agentId, globalRoleId); + Value cond = + builder.create(loc, lowerBound, upperBound); + cmpOp->getResult(0).replaceAllUsesWith(cond); + cmpOp->erase(); + break; + } + } + } + }); +} + +//===----------------------------------------------------------------------===// +// Materialize token operations +//===----------------------------------------------------------------------===// + +LoadType scanLoadTypes(ttng::CreateTokenOp createTokenOp) { + // TODO: Attach information of binded tensors to CreateTokenOp + std::set loadTypes; + createTokenOp->getBlock()->walk([&](Operation *op) { + if (auto insertOp = dyn_cast(op)) { + if (triton::isTensorPointerType(insertOp.getSrc().getType())) + loadTypes.insert(LoadType::InsertSliceAsyncV2Op); + else + loadTypes.insert(LoadType::InsertSliceAsyncOp); + } else if (isa(op)) { + loadTypes.insert(LoadType::InsertSliceAsyncOp); + } else if (isa(op)) { + loadTypes.insert(LoadType::InsertSliceAsyncV2Op); + } + }); + assert(loadTypes.size() > 0 && "InsertSliceOp not found"); + assert(loadTypes.size() == 1 && + "Multiple kinds of load types are not supported"); + return *loadTypes.begin(); +} + +Value getMBarrierPhaseBit(OpBuilder &builder, Operation *op, + bool skipFirstWait) { + // TODO: currently we only support one loop, no nested loop, while or + // condition. + auto loc = op->getLoc(); + auto forOp = op->getParentOfType(); + if (!forOp) { + return builder.create(loc, skipFirstWait, 1); + } + + auto defOp = op->getOperand(0).getDefiningOp(); + assert(isa(defOp) && + "mbarrier's definingOp is not createTokenOp"); + ttng::CreateTokenOp createTokenOp = dyn_cast(defOp); + Value numStage = + builder.create(loc, createTokenOp.getNum(), 32); + Value curStep = forOp.getBody()->getArguments().back(); + if (curStep.getType() == builder.getIndexType()) { + curStep = + builder.create(loc, numStage.getType(), curStep); + } + Value curPhase = builder.create(loc, curStep, numStage); + if (skipFirstWait) { + // If skipFirstWait, it waits for phaseBit 1 + Value _1 = builder.create(loc, 1, 32); + curPhase = builder.create(loc, curPhase, _1); + } + Value _2 = builder.create(loc, 2, 32); + // TODO: May use alternative methods of phaseBit calculation to avoid high + // overhead of RemOp + Value phaseBit = builder.create(loc, curPhase, _2); + Value _0 = builder.create(loc, 0, 32); + return builder.create(loc, arith::CmpIPredicate::ne, phaseBit, + _0); +} + +int getTxBytes(ttng::InsertSliceAsyncV2Op load) { + // Both support ptr of tensor and tensor of ptr. + RankedTensorType srcTensorType; + if (auto srcType = dyn_cast(load.getSrc().getType())) { + srcTensorType = srcType; + } else if (auto srcType = + dyn_cast(load.getSrc().getType())) { + srcTensorType = dyn_cast(srcType.getPointeeType()); + } else { + llvm_unreachable("Unexpected src type"); + } + auto shapePerCTA = ttg::getShapePerCTA(srcTensorType); + auto elemTy = + dyn_cast(load.getDst().getType()).getElementType(); + int bytesPerElem = elemTy.getIntOrFloatBitWidth() / 8; + return product(shapePerCTA) * bytesPerElem; +} + +int applyCommit(OpBuilder &builder, ttng::ProducerCommitOp &op, + Value mbarrier) { + // TODO: currently it only handles loads in ProducerCommitOp's nearest parent + // block. Neither support multiple ProducerCommitOp, e.g. fused attention, + // epilogue fusion. + int txCnt = 0; + SmallVector deprecatedOps; + auto agentIds = getAgentIds(op); + // Materialize InsertSliceOp + for (auto &ItrOp : op->getBlock()->getOperations()) { + // Check operations before ProducerCommitOp + if (OperationEquivalence::isEquivalentTo(&ItrOp, op.getOperation(), + OperationEquivalence::None)) { + break; + } + if (auto insertOp = dyn_cast(ItrOp)) { + deprecatedOps.push_back(&ItrOp); + builder.setInsertionPoint(insertOp); + if (!::mlir::triton::isTensorPointerType(insertOp.getSrc().getType())) { + // Transform to InsertSliceAsyncOp + auto newSliceOp = builder.create( + /*loc=*/insertOp.getLoc(), /*result=*/insertOp.getDst().getType(), + /*src=*/insertOp.getSrc(), /*dst=*/insertOp.getDst(), + /*index=*/insertOp.getIndex(), + /*mask=*/insertOp.getMask(), /*other=*/insertOp.getOther(), + /*cache=*/insertOp.getCache(), /*evict=*/insertOp.getEvict(), + /*isVolatile=*/insertOp.getIsVolatile(), + /*axis=*/insertOp.getAxis()); + insertOp.getResult().replaceAllUsesWith(newSliceOp.getResult()); + setAgentIds(newSliceOp, agentIds); + } else { + // Transform to InsertSliceAsyncV2Op + auto extractBarrierOp = dyn_cast( + builder.clone(*(mbarrier.getDefiningOp()))); + auto newSliceOp = builder.create( + /*loc=*/insertOp.getLoc(), /*result=*/insertOp.getDst().getType(), + /*src=*/insertOp.getSrc(), /*dst=*/insertOp.getDst(), + /*index=*/insertOp.getIndex(), + /*mbar*/ extractBarrierOp.getResult(), /*mask=*/insertOp.getMask(), + /*other=*/insertOp.getOther(), + /*cache=*/insertOp.getCache(), /*evict=*/insertOp.getEvict(), + /*isVolatile=*/insertOp.getIsVolatile(), + /*axis=*/insertOp.getAxis()); + insertOp.getResult().replaceAllUsesWith(newSliceOp.getResult()); + setAgentIds(newSliceOp, agentIds); + txCnt += getTxBytes(newSliceOp); + } + } + } + builder.setInsertionPoint(op); + for (auto d : deprecatedOps) { + d->erase(); + } + + return txCnt; +} + +void processProducerAcquireOp(OpBuilder &builder, ttng::ProducerAcquireOp op, + Value bufferEmpty) { + auto loc = op.getLoc(); + // The first producer_aquire should be met immediately, so initailly producer + // skips the fisrt wait + Value phase = getMBarrierPhaseBit(builder, op, 1); + auto waitOp = builder.create(loc, bufferEmpty, phase); + assert(op.getOperation()->hasAttr("async_agent")); + setAgentIds(waitOp, getAgentIds(op.getOperation())); +} + +void processProducerCommitOp(OpBuilder &builder, ttng::ProducerCommitOp op, + Value bufferFull, LoadType loadType) { + auto loc = op.getLoc(); + int txCnt = applyCommit(builder, op, bufferFull); + ttng::MBarrierArriveOp arriveOp; + + if (loadType == LoadType::InsertSliceAsyncOp) { + // Each thread arrives + Value pred = builder.create(loc, 1, 1); + arriveOp = builder.create( + loc, bufferFull, pred, /*remoteCTAId*/ nullptr, /*trackAsyncOp*/ true, + txCnt); + } else { + // Only thread 0 arrives + Value _0 = builder.create(loc, 0, 32); + Value threadId = getThreadId(builder, loc); + Value pred = builder.create(loc, arith::CmpIPredicate::eq, + threadId, _0); + arriveOp = builder.create( + loc, bufferFull, pred, /*remoteCTAId*/ nullptr, /*trackAsyncOp*/ false, + txCnt); + } + + assert(op.getOperation()->hasAttr("async_agent")); + setAgentIds(arriveOp, getAgentIds(op.getOperation())); +} + +void processConsumerWaitOp(OpBuilder &builder, ttng::ConsumerWaitOp op, + Value bufferFull) { + auto loc = op.getLoc(); + Value phase = getMBarrierPhaseBit(builder, op, 0); + auto waitOp = builder.create(loc, bufferFull, phase); + assert(op.getOperation()->hasAttr("async_agent")); + setAgentIds(waitOp, getAgentIds(op.getOperation())); +} + +void processConsumerReleaseOp(OpBuilder &builder, ttng::ConsumerReleaseOp op, + Value bufferEmpty, int numCTAs) { + auto loc = op.getLoc(); + + // Constants + Value _0 = builder.create(loc, 0, 32); + Value _4 = builder.create(loc, 4, 32); + Value _8 = builder.create(loc, 8, 32); + Value _32 = builder.create(loc, 32, 32); + Value _128 = builder.create(loc, 128, 32); + + // threadId = threadId % 128 + Value threadId = + builder.create(loc, getThreadId(builder, loc), _128); + + // k = threadId / 8 + Value k = builder.create(loc, threadId, _8); + + // row = k / 4 + Value row = builder.create(loc, k, _4); + + // col = k % 4 + Value col = builder.create(loc, k, _4); + + // remoteCTAId = (col ^ row) * 4 + col + Value remoteCTAId = builder.create( + loc, + Value{builder.create( + loc, Value{builder.create(loc, col, row)}, _4)}, + col); + + // pred0 = threadId % 8 == 0 + Value pred0 = builder.create( + loc, arith::CmpIPredicate::eq, + builder.create(loc, threadId, _8), _0); + + // pred1 = remoteCTAId < numCTAs + Value pred1 = builder.create( + loc, arith::CmpIPredicate::ult, remoteCTAId, + builder.create(loc, numCTAs, 32)); + + // pred = pred0 & pred1 + Value pred = builder.create(loc, pred0, pred1); + + // bufferEmpty arrive + auto arriveOp = builder.create(loc, bufferEmpty, pred, + remoteCTAId, false, 0); + + assert(op.getOperation()->hasAttr("async_agent")); + setAgentIds(arriveOp, getAgentIds(op.getOperation())); +} + +void materializeTokenOperations(Operation *parentOp, int numCTAs) { + SmallVector deprecatedOps; + parentOp->walk([&](ttng::CreateTokenOp createTokenOp) { + // Scan load type + LoadType loadType = scanLoadTypes(createTokenOp); + + // mBarrierTy + MLIRContext *context = createTokenOp.getContext(); + auto i64Ty = IntegerType::get(context, 64); + auto mBarrierTy = triton::PointerType::get(i64Ty, 3); + + // mBarriersTy + auto CTALayout = ttg::CTALayoutAttr::get(context, {1}, {1}, {0}); + auto sharedLayout = + ttg::SharedEncodingAttr::get(context, 1, 1, 1, {0}, CTALayout, false); + auto mBarriersTy = + RankedTensorType::get({createTokenOp.getNum()}, i64Ty, sharedLayout); + + // Process CreateTokenOp + OpBuilder builder(createTokenOp); + auto tokenLoc = createTokenOp.getLoc(); + unsigned bufferFullCount = + loadType == LoadType::InsertSliceAsyncV2Op ? 1 : 128; + Value bufferFullArray = builder.create( + tokenLoc, mBarriersTy, bufferFullCount); + Value bufferEmptyArray = + builder.create(tokenLoc, mBarriersTy, numCTAs); + + if (numCTAs == 1) { + builder.create(tokenLoc); + } else { + // Make sure that MBarriers are initialized in all CTAs + builder.create(tokenLoc, false); + builder.create(tokenLoc); + } + + // Helper function for extracting bufferFull + auto extractBufferFull = [&](Location loc, Value idx) -> Value { + return builder.create(loc, mBarrierTy, + bufferFullArray, idx); + }; + + // Helper function for extracting bufferEmpty + auto extractBufferEmpty = [&](Location loc, Value idx) -> Value { + return builder.create(loc, mBarrierTy, + bufferEmptyArray, idx); + }; + + // Process token users + for (Operation *user : createTokenOp.getResult().getUsers()) { + auto loc = user->getLoc(); + builder.setInsertionPoint(user); + if (auto op = dyn_cast(user)) { + Value bufferEmpty = extractBufferEmpty(loc, op.getIdx()); + assert(user->hasAttr("async_agent")); + setAgentIds(bufferEmpty.getDefiningOp(), getAgentIds(user)); + processProducerAcquireOp(builder, op, bufferEmpty); + } else if (auto op = dyn_cast(user)) { + Value bufferFull = extractBufferFull(loc, op.getIdx()); + assert(user->hasAttr("async_agent")); + setAgentIds(bufferFull.getDefiningOp(), getAgentIds(user)); + processProducerCommitOp(builder, op, bufferFull, loadType); + } else if (auto op = dyn_cast(user)) { + Value bufferFull = extractBufferFull(loc, op.getIdx()); + assert(user->hasAttr("async_agent")); + setAgentIds(bufferFull.getDefiningOp(), getAgentIds(user)); + processConsumerWaitOp(builder, op, bufferFull); + } else if (auto op = dyn_cast(user)) { + Value bufferEmpty = extractBufferEmpty(loc, op.getIdx()); + assert(user->hasAttr("async_agent")); + setAgentIds(bufferEmpty.getDefiningOp(), getAgentIds(user)); + processConsumerReleaseOp(builder, op, bufferEmpty, numCTAs); + } else { + llvm_unreachable("Unexpected user of token"); + } + deprecatedOps.push_back(user); + } + + deprecatedOps.push_back(createTokenOp); + }); + for (auto op : deprecatedOps) { + op->erase(); + } + + // Insert a cluster barrier before the kernel exits. Without this barrier, + // mbarrier_remote_arrive will fail if the remote CTA already exits. + if (numCTAs > 1) { + parentOp->walk([&](triton::FuncOp funcOp) { + Block *block = &funcOp.getBody().front(); + auto returnOp = llvm::cast(block->getTerminator()); + OpBuilder builder(returnOp); + auto loc = returnOp.getLoc(); + builder.create(loc, false); + builder.create(loc); + }); + } +} + +//===----------------------------------------------------------------------===// +// Materialize mutex operations +//===----------------------------------------------------------------------===// + +void mutexSyncPingPang(Operation *parentOp, int numAgents, int &nameBarrierId, + int &globalNumRoles) { + // ping-pang mutex sync: using named barrier and only suitable for two roles. + // Take mutex syncronization between dot and store as an example: + // * For dot loop: + // * role 0 waits for named barrier 15 (loop enter), arrives named barrier + // 14 (loop leave) + // * role 1 waits for named barrier 14 (loop enter), arrives named barrier + // 15 (loop leave) + // * For store: + // * role 0 waits for named barrier 13 (store enter), arrives named barrier + // 12 (store leave) + // * role 1 waits for named barrier 12 (store enter), arrives named barrier + // 13 (store leave) + // As number of named barriers is limited (16), theoretically this mechanism + // only support few roles and agents. + int numRoles = 2, times = 0; + globalNumRoles += numRoles; + Value roleId; + parentOp->walk([&](ttng::GetMutexRoleIdOp getMutexRoleIdOp) { + // GetMutexRoleIdOp only occures once. + assert(times == 0); + OpBuilder builder(getMutexRoleIdOp); + numRoles = getMutexRoleIdOp.getNum(); + auto loc = getMutexRoleIdOp->getLoc(); + Value threadId = getThreadId(builder, loc); + assert(getMutexRoleIdOp->hasAttr("agent.num-warps")); + int numThreads = + 32 * getMutexRoleIdOp->getAttrOfType("agent.num-warps") + .getInt(); + int numThreadsBase = + 32 * + getMutexRoleIdOp->getAttrOfType("agent.num-warps-base") + .getInt(); + assert(numThreads % numRoles == 0); + // TODO: more flexible ways to determine numWarps of each agent. + Value numThreadsValue = + builder.create(loc, numThreads, 32); + Value numRolesValue = + builder.create(loc, numRoles, 32); + Value numThreadsBaseValue = + builder.create(loc, numThreadsBase, 32); + Value numThreadsPerRole = + builder.create(loc, numThreadsValue, numRolesValue); + Value numRemThreads = + builder.create(loc, threadId, numThreadsBaseValue); + roleId = + builder.create(loc, numRemThreads, numThreadsPerRole); + getMutexRoleIdOp.getResult().replaceAllUsesWith(roleId); + getMutexRoleIdOp->erase(); + times++; + }); + + parentOp->walk([&](ttng::CreateMutexOp createMutexOp) { + // Currently, inner-agent sync counts from barId 1 (see membar.cpp, bar 0 + // is used for whole block sync). + // We need to guarantee mutex sync won't use bars of inner-agent sync. + assert(nameBarrierId > globalNumRoles); + // Process CreateMutexOp + OpBuilder builder(createMutexOp); + // TODO: change the hard code of numThreads + auto loc = createMutexOp->getLoc(); + Value numThreads = builder.create(loc, 256, 32); + Value _0 = builder.create(loc, 0, 32); + Value isRole0 = builder.create(loc, arith::CmpIPredicate::eq, + roleId, _0); + assert(nameBarrierId < nameBarrierIdEnd && + nameBarrierId - 1 >= nameBarrierIdBegin); + Value namedBarrierId0 = + builder.create(loc, nameBarrierId, 32); + Value namedBarrierId1 = + builder.create(loc, nameBarrierId - 1, 32); + // Process mutex users + int numUsers = 0; + for (Operation *user : createMutexOp.getResult().getUsers()) { + numUsers++; + assert(numUsers <= 2); + auto loc = user->getLoc(); + builder.setInsertionPoint(user); + if (auto op = dyn_cast(user)) { + Value barEnter = builder.create( + loc, isRole0, namedBarrierId0, namedBarrierId1); + builder.create(loc, barEnter, numThreads); + } else if (auto op = dyn_cast(user)) { + Value barLeave = builder.create( + loc, isRole0, namedBarrierId1, namedBarrierId0); + builder.create(loc, barLeave, numThreads); + } else + llvm_unreachable("Unexpected user of mutex"); + user->erase(); + } + nameBarrierId -= 2; + nameBarrierIdEnd -= 2; + createMutexOp.erase(); + }); +} + +void processLockOp(OpBuilder &builder, ttng::LockOp op) { + auto loc = op.getLoc(); + assert(op->hasAttr("mutex.barId") && op->hasAttr("mutex.numThreads")); + auto barIds = getMutexBarIds(op); + auto threads = getMutexNumThreads(op); + assert(barIds.size() > 0 && barIds.size() == threads.size()); + for (int i = 0; i < barIds.size(); ++i) { + Value numThreads = + builder.create(loc, threads[i], 32); + Value barrier = builder.create(loc, barIds[i], 32); + builder.create(loc, barrier, numThreads); + } +} + +void processUnlockOp(OpBuilder &builder, ttng::UnlockOp op) { + auto loc = op.getLoc(); + assert(op->hasAttr("mutex.barId") && op->hasAttr("mutex.numThreads")); + auto barIds = getMutexBarIds(op); + auto threads = getMutexNumThreads(op); + assert(barIds.size() > 0 && barIds.size() == threads.size()); + for (int i = 0; i < barIds.size(); ++i) { + Value numThreads = + builder.create(loc, threads[i], 32); + Value barrier = builder.create(loc, barIds[i], 32); + builder.create(loc, barrier, numThreads); + } +} + +void materializeMutexOperationsOthers(ModuleOp parentOp) { + parentOp->walk([](ttng::CreateMutexOp createMutexOp) { + // Process CreateMutexOp + OpBuilder builder(createMutexOp); + + // Process mutex users + for (Operation *user : createMutexOp.getResult().getUsers()) { + auto loc = user->getLoc(); + builder.setInsertionPoint(user); + if (auto op = dyn_cast(user)) + processLockOp(builder, op); + else if (auto op = dyn_cast(user)) + processUnlockOp(builder, op); + else + llvm_unreachable("Unexpected user of mutex"); + user->erase(); + } + + createMutexOp.erase(); + }); +} + +void materializeMutexOperations(ModuleOp parentOp) { + nameBarrierIdEnd = 16; + int nameBarrierId = 15; + int globalNumRoles = 0; + // Materialize mutex operations from WSMutex, i.e. auto-mutex + parentOp->walk([&](scf::IfOp IfOp) { + int numRoles = 0; + if (IfOp->hasAttr("agent.num-roles")) { + assert(parentOp->hasAttr("async.num-agents")); + int numAgents = + parentOp->getAttrOfType("async.num-agents").getInt(); + numRoles = IfOp->getAttrOfType("agent.num-roles").getInt(); + // TODO: To support arbitrary number of roles, use mbarrier. + assert(numRoles == 2); + mutexSyncPingPang(IfOp, numAgents, nameBarrierId, globalNumRoles); + } + }); + // Materialize mutex operations for remaining cases. + // User needs to guarantee correctness of synchronization. + materializeMutexOperationsOthers(parentOp); +} + +// TODO: may also not support 8-warp kernel. +void tryRegisterRealloc(ModuleOp mod) { + constexpr int LoadRegisterRequirement = 40; + constexpr int MmaRegisterRequirement = 232; + OpBuilderWithAgentIds builder(mod.getContext()); + + auto isLoadAgent = [](scf::IfOp ifOp) -> bool { + return ifOp + ->walk([](Operation *op) { + if (isa(op)) + return WalkResult::interrupt(); + return WalkResult::advance(); + }) + .wasInterrupted(); + }; + + auto isMmaAgent = [](scf::IfOp ifOp) -> bool { + return ifOp + ->walk([](Operation *op) { + if (isa(op)) + return WalkResult::interrupt(); + return WalkResult::advance(); + }) + .wasInterrupted(); + }; + + // TODO: we need to make agent info more handy + SmallVector agentOps; + mod->walk([&agentOps](triton::FuncOp funcOp) { + Block *block = &funcOp.getBody().front(); + for (Operation &op : block->getOperations()) { + if (auto ifOp = dyn_cast(&op)) { + if (getAgentIds(ifOp).size() == 1) { + agentOps.push_back(ifOp); + } + } + } + }); + for (auto ifOp : agentOps) { + builder.setInsertionPointToStart(&(ifOp.getThenRegion().front())); + builder.setAgentIdsFromOp(ifOp); + auto loc = ifOp.getLoc(); + Type i32_ty = builder.getIntegerType(32); + // If an agent has both mma and load, do nothing. + if (isMmaAgent(ifOp) && isLoadAgent(ifOp)) + continue; + if (isMmaAgent(ifOp)) { + builder.createWithAgentIds( + loc, builder.getIntegerAttr(i32_ty, MmaRegisterRequirement)); + } else if (isLoadAgent(ifOp)) { + builder.createWithAgentIds( + loc, builder.getIntegerAttr(i32_ty, LoadRegisterRequirement)); + } + } +} + +//===----------------------------------------------------------------------===// +// WSMaterializationPass +//===----------------------------------------------------------------------===// + +struct WSMaterializationPass + : public TritonGPUWSMaterializationBase { + WSMaterializationPass() = default; + WSMaterializationPass(int computeCapability) { + this->computeCapability = computeCapability; + } + + void runOnOperation() override { + ModuleOp mod = getOperation(); + if (!ttng::TritonNvidiaGPUDialect::getWSSupportedAttr(mod)) + return signalPassFailure(); + + if (computeCapability / 10 < 9) { + llvm_unreachable("WSMaterialization pass only supports sm_9x as of now."); + signalPassFailure(); + } + + int numCTAs = ttg::TritonGPUDialect::getNumCTAs(mod); + + materializeGetAgentIdOp(mod); + materializeTokenOperations(mod, numCTAs); + materializeMutexOperations(mod); + tryRegisterRealloc(mod); + + // TODO: More flexible way to set num-warps + // One dma, one math warp group, set num-warps = 8 + auto i32_ty = IntegerType::get(mod->getContext(), 32); + mod->setAttr("triton_gpu.num-warps", + IntegerAttr::get(i32_ty, llvm::APInt(32, 8))); + + WalkResult result = mod->walk([&](scf::IfOp ifOp) { + if (ifOp->hasAttr("agent.num-roles")) { + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + if (result.wasInterrupted()) { + mod->setAttr("triton_gpu.num-warps", + IntegerAttr::get(i32_ty, llvm::APInt(32, 12))); + } + mod->removeAttr("async.num-agents"); + } +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// createTritonNvidiaGPUWSMaterializationPass +//===----------------------------------------------------------------------===// + +std::unique_ptr +mlir::createTritonNvidiaGPUWSMaterializationPass(int computeCapability) { + return std::make_unique(computeCapability); +} diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/WSMutex.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/WSMutex.cpp new file mode 100644 index 000000000000..a7bab9ff1366 --- /dev/null +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/WSMutex.cpp @@ -0,0 +1,316 @@ +#include + +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" + +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" + +#include "mlir/Analysis/SliceAnalysis.h" + +using namespace mlir; +namespace ttg = triton::gpu; +namespace ttng = triton::nvidia_gpu; + +#define GEN_PASS_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" + +namespace { + +// Target operations: dot, load, store. Add more when necessary. +#define KEY_TYPES triton::DotOp, ttg::InsertSliceOp, triton::StoreOp + +template +void getKeyTypeId(Operation *op, int &id, bool &found) { + if (isa(op)) + found = true; + + if (!found) { + id++; + if constexpr (sizeof...(Tails) > 0) + getKeyTypeId(op, id, found); + } +} + +template int getKeyTypeIdWrapper(Operation *op) { + bool found = false; + int id = 0; + getKeyTypeId(op, id, found); + return found ? id : -1; +} + +bool isEligible(Operation *agent, + DenseMap> &keyTypeOpMap, + scf::ForOp &persistentForOp) { + // metrics: + // 1. Have more than one key type of operation. + // 2. persistent (all key operations are in one forOp) + DenseSet keyTypes; + DenseSet keyOperations; + agent->walk([&](Operation *op) { + auto typeId = getKeyTypeIdWrapper(op); + if (typeId >= 0 && op != agent) { + keyTypes.insert(typeId); + keyOperations.insert(op); + keyTypeOpMap[typeId].insert(op); + } + }); + + if (keyTypes.size() <= 1) { + return false; + } + + auto getPersistentFor = [&](DenseSet keyOps, + scf::ForOp &innerMostForOp) -> bool { + DenseSet commonForOps0, commonForOps1; + DenseSet *commonForOpsPre = &commonForOps0, + *commonForOpsPro = &commonForOps1; + assert(keyOps.size() > 1); + SmallVector forOps; + agent->walk( + [&](scf::ForOp forOp) { forOps.push_back(forOp); }); + + bool hasCommon = false; + for (auto &f : forOps) { + bool isCommon = true; + for (auto &k : keyOps) { + if (!f->isAncestor(k)) { + isCommon = false; + break; + } + } + if (isCommon) { + innerMostForOp = f; + hasCommon = true; + } + } + return hasCommon; + }; + + // Persistent agents with more than one key types are eligible. + return getPersistentFor(keyOperations, persistentForOp); +} + +void mutexSync(ModuleOp &mod, scf::IfOp &ifOp, scf::ForOp &persistentForOp, + DenseMap> &keyTypeOpMap) { + // Modify keyTypeOpMap: DenseMap> --> DenseMap. Conservetively, assign each key operation one mutex. + // =======================detail description (TODO: to be + // deleted)========================== because it's hard to check if two + // operations with same typeid can share same mutex, we assign each key + // operation one mutex. To illustrate the hardness of this analysis, say we + // have two operations with same typeid: a and b, if there is another + // operation (say c) of different typeid between a and b, and their locations + // are a -- c -- b, then if the dependency is: + // * b depends on c, then a and b can NOT share the same mutex. + // * otherwise, a and b can share after move b before c. + // It would be more complicated when there are more types and operations. + DenseMap ProxyKeyTypeOpMap; + for (auto &[id, ops] : keyTypeOpMap) { + for (auto itr = ops.begin(); itr != ops.end(); ++itr) { + auto op = *itr; + ProxyKeyTypeOpMap[ProxyKeyTypeOpMap.size()] = op; + } + } + + int numRoles = ProxyKeyTypeOpMap.size(); + auto loc = ifOp.getLoc(); + OpBuilderWithAgentIds builder(ifOp.getContext()); + // Set num-roles for wsmaterialization pass + ifOp->setAttr("agent.num-roles", builder.getI32IntegerAttr(numRoles)); + builder.setAgentIdsFromOp(ifOp); + builder.setInsertionPointToStart(&(ifOp.getThenRegion().front())); + Value _0 = builder.create(loc, 0, 32); + Value curRoleId = + builder.createWithAgentIds(loc, numRoles); + Value isNotRole0 = builder.create( + loc, arith::CmpIPredicate::ne, curRoleId, _0); + + SmallVector mutexBarriers; + for (int i = 0; i < numRoles; ++i) { + auto v = builder.createWithAgentIds(loc); + mutexBarriers.push_back(v); + } + + // Update lower bound, step and pipelineIdx of persistentForOp + builder.setInsertionPoint(persistentForOp); + Value start = builder.createWithAgentIds( + loc, persistentForOp.getStep(), curRoleId); + Value oldLB = persistentForOp.getLowerBound(); + Value pipelineIdx = + persistentForOp->getOperand(persistentForOp->getNumOperands() - 1); + + start = builder.createWithAgentIds(loc, oldLB, start); + persistentForOp.setLowerBound(start); + + Value numRolesValue = + builder.createWithAgentIds(loc, numRoles, 32); + Value step = builder.createWithAgentIds( + loc, persistentForOp.getStep(), numRolesValue); + persistentForOp.setStep(step); + + Value newIdx = + builder.createWithAgentIds(loc, pipelineIdx, curRoleId); + persistentForOp.setIterArg(persistentForOp.getNumIterOperands() - 1, newIdx); + auto yield = + llvm::cast(persistentForOp.getBody()->getTerminator()); + auto idxPlusOneOp = + yield->getOperand(yield->getNumOperands() - 1).getDefiningOp(); + assert(isa(idxPlusOneOp)); + assert(idxPlusOneOp->getOperand(0) == + persistentForOp.getBody()->getArgument( + persistentForOp.getBody()->getNumArguments() - 1)); + idxPlusOneOp->setOperand(1, numRolesValue); + + // Add operations at the start of persistentForOp + builder.setInsertionPointToStart(persistentForOp.getBody()); + // If( role != 0 || !is_first_tile ) + Value isNotTileId0 = builder.create( + loc, arith::CmpIPredicate::ne, persistentForOp.getBody()->getArgument(0), + oldLB); + Value cond = builder.create(loc, isNotTileId0, isNotRole0); + + // Determine boundaries: get the largest exclusive op for each key op. + DenseMap lockLocs, unlockLocs; + DenseMap> parentOps; + for (int i = 0; i < numRoles; ++i) { + auto op = ProxyKeyTypeOpMap[i]->getParentOp(); + while (op != persistentForOp->getParentOp()) { + parentOps[i].push_back(op); + op = op->getParentOp(); + } + } + + std::map, std::pair::iterator, + SmallVector::iterator>> + rangeMap; + for (auto &[i, opsI] : parentOps) { + // Check exlusiveness + auto op = ProxyKeyTypeOpMap[i]; + for (auto &[j, opsJ] : parentOps) { + if (i == j) + continue; + auto pair = std::pair(i, j); + auto pairConj = std::pair(j, i); + auto end0 = rangeMap.count(pair) ? rangeMap[pair].first : opsI.end(); + auto end1 = rangeMap.count(pair) ? rangeMap[pair].second : opsJ.end(); + for (auto m = opsI.begin(); m != end0; ++m) { + auto itr = std::find(opsJ.begin(), end1, *m); + if (itr == end1) { + op = *m; + rangeMap[pair] = std::make_pair(m, itr); + rangeMap[pairConj] = rangeMap[pair]; + } else + goto exit; + } + } + exit:; + lockLocs[i] = op; + unlockLocs[i] = op; + } + + // Only cases where all lock/unlock locations are in same level make sense. + for (int i = 1; i < numRoles; ++i) { + if (lockLocs[i]->getParentOp() != lockLocs[i - 1]->getParentOp() || + unlockLocs[i]->getParentOp() != unlockLocs[i - 1]->getParentOp()) { + llvm_unreachable("Only cases where all locl/unlock locations are in same " + "level make sense"); + } + } + + // Extend boundaries: wait and release as early as possible + DenseMap prevTypeIds; + int prevId = -1; + persistentForOp->walk([&](Operation *op) { + for (int i = 0; i < numRoles; ++i) { + if (lockLocs[i] == op) { + prevTypeIds[i] = prevId; + prevId = i; + break; + } + } + }); + + // Update lockLocs + for (int i = 0; i < numRoles; ++i) { + if (prevTypeIds[i] == -1) + lockLocs[i] = cond.getDefiningOp(); + else + lockLocs[i] = unlockLocs[prevTypeIds[i]]; + } + // lock + for (int i = 0; i < numRoles; ++i) { + builder.setInsertionPointAfter(lockLocs[i]); + auto waitIfOp = builder.create(loc, cond); + builder.setInsertionPointToStart(&(waitIfOp.getThenRegion().front())); + builder.create(loc, mutexBarriers[i]); + } + + // unlock + for (int i = 0; i < numRoles; ++i) { + builder.setInsertionPointAfter(unlockLocs[i]); + builder.create(loc, mutexBarriers[i]); + } + + // Add attr "agent.mutex_role" for barrier analysis + int roleId = -1; + for (Operation &bodyOp : lockLocs[0]->getBlock()->getOperations()) { + Operation *op = &bodyOp; + if (roleId != -1) + op->walk([&](Operation *subOp) { + if (!isa(op) && !isa(op) && + !isa(op)) + subOp->setAttr("agent.mutex_role", builder.getI32IntegerAttr(roleId)); + }); + for (int i = 0; i < numRoles; ++i) { + if (lockLocs[i] == op) { + if (roleId != -1) + op->setAttr("agent.mutex_role", builder.getI32IntegerAttr(roleId)); + roleId = i; + break; + } + } + } +} + +//===----------------------------------------------------------------------===// +// WSMaterializationPass +//===----------------------------------------------------------------------===// + +struct WSMutexPass : public TritonGPUWSMutexBase { +public: + WSMutexPass() = default; + WSMutexPass(int computeCapability) { + this->computeCapability = computeCapability; + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + mod.walk([&](triton::FuncOp funcOp) { + for (Operation &bodyOp : funcOp.getBody().front().getOperations()) { + Operation *op = &bodyOp; + scf::ForOp persistentForOp; + // premise: agent region is encapsulated with scf.if + if (isa(op) && getAgentIds(op).size() == 1) { + DenseMap> keyTypeOpMap; + if (isEligible(op, keyTypeOpMap, persistentForOp)) { + auto ifOp = cast(op); + mutexSync(mod, ifOp, persistentForOp, keyTypeOpMap); + } + } + } + }); + } +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// createTritonNvidiaGPUWSMutexPass +//===----------------------------------------------------------------------===// + +std::unique_ptr +mlir::createTritonNvidiaGPUWSMutexPass(int computeCapability) { + return std::make_unique(computeCapability); +} diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/WSPipeline.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/WSPipeline.cpp new file mode 100644 index 000000000000..459eff719fb7 --- /dev/null +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/WSPipeline.cpp @@ -0,0 +1,957 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" + +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" + +#include "mlir/Analysis/SliceAnalysis.h" + +#include +#include + +using namespace mlir; +namespace ttg = triton::gpu; +namespace ttng = triton::nvidia_gpu; + +#define GEN_PASS_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" + +namespace { +struct Channel { +public: + using Relation = std::pair; + + Channel(int producer, int consumer, Operation *src, Operation *dst) + : relation(producer, consumer), srcOp(src), dstOp(dst) {} + + bool operator==(const Channel &c) { + return relation == c.relation && srcOp == c.srcOp && dstOp == c.dstOp; + } + + Relation relation; + Operation *srcOp; + Operation *dstOp; +}; + +//===----------------------------------------------------------------------===// +// createToken +//===----------------------------------------------------------------------===// + +DenseMap +createToken(const DenseMap> &map, + triton::FuncOp funcOp, int numStages) { + DenseMap ret; + OpBuilder builder(funcOp); + builder.setInsertionPointToStart(&(funcOp.getBody().front())); + for (auto it = map.begin(); it != map.end(); ++it) { + Value v; + if (it->second.front()->srcOp->getParentOfType()) { + v = builder.create(funcOp.getLoc(), numStages); + } else { + // No need to pipeline + v = builder.create(funcOp.getLoc(), 1); + } + for (auto &c : it->second) { + ret[c] = v; + } + } + return ret; +} + +//===----------------------------------------------------------------------===// +// createBuffer +//===----------------------------------------------------------------------===// + +DenseMap createBuffer(const SmallVector &channels, + triton::FuncOp funcOp, int numStages) { + DenseMap bufferMap; + MLIRContext *context = funcOp.getContext(); + OpBuilder builder(funcOp); + builder.setInsertionPointToStart(&(funcOp.getBody().front())); + for (const auto &c : channels) { + auto loadOp = dyn_cast(c->srcOp); + Value loadResult = loadOp.getResult(); + if (auto tensorType = loadResult.getType().dyn_cast()) { + // Get basic information from tensorType + auto order = ttg::getOrder(tensorType.getEncoding()); + auto CTALayout = ttg::getCTALayout(tensorType.getEncoding()); + auto elemType = tensorType.getElementType(); + + // Get shape, layout and type of a slice + auto sliceShape = tensorType.getShape(); + auto sharedLayout = ttg::SharedEncodingAttr::get( + context, sliceShape, order, CTALayout, elemType); + auto sliceType = + RankedTensorType::get(sliceShape, elemType, sharedLayout); + + // Get shape, layout and type of the complete buffer + SmallVector bufferShape(sliceShape.begin(), sliceShape.end()); + if (loadOp->getParentOfType()) { + bufferShape.insert(bufferShape.begin(), numStages); + } else { + // No need to pipeline + bufferShape.insert(bufferShape.begin(), 1); + } + auto bufferType = + RankedTensorType::get(bufferShape, elemType, sharedLayout); + Value buffer = + builder.create(funcOp.getLoc(), bufferType); + bufferMap[c] = buffer; + } else { + llvm_unreachable("Unexpected result type"); + } + } + return bufferMap; +} + +//===----------------------------------------------------------------------===// +// appendPipelineIdxToLoopArgs +//===----------------------------------------------------------------------===// + +scf::ForOp appendPipelineIdxToLoopArgs(scf::ForOp forOp, int numStages, + scf::ForOp &parentForOp) { + auto loc = forOp.getLoc(); + Block *body = forOp.getBody(); + + // The agentId set of pipelineIdx is the union of agentId sets of all ops in + // the for loop + OpBuilderWithAgentIds builder(forOp.getContext()); + builder.setAgentIdsFromArray(collectAgentIds(forOp)); + + builder.setInsertionPoint(forOp); + Value numStagesVal = + builder.createWithAgentIds(loc, numStages, 32); + // Append pipelineIdx to block arguments + Value pipelineIdx = + body->insertArgument(body->getNumArguments(), builder.getI32Type(), loc); + + // pipelineIdx = (pipelineIdx + 1) % numStages + auto yieldOp = llvm::cast(body->getTerminator()); + builder.setInsertionPoint(yieldOp); + Value one = builder.createWithAgentIds(loc, 1, 32); + + Value pipelineIdxPlusOne = + builder.createWithAgentIds(loc, pipelineIdx, one); + + // Append pipelineIdx to yield operands + yieldOp->insertOperands(yieldOp.getNumOperands(), {pipelineIdxPlusOne}); + + // Copy iter operands of forOp + SmallVector newLoopArgs; + for (auto operand : forOp.getIterOperands()) + newLoopArgs.push_back(operand); + + // Append initial value of pipelineIdx to newLoopArgs + builder.setInsertionPoint(forOp); + Value initValue; + if (parentForOp) { + // Make sure prior pipelineIdx is inserted in the end of parentForOp + initValue = parentForOp.getBody()->getArguments().back(); + Value numSteps = builder.createWithAgentIds( + loc, forOp.getUpperBound(), forOp.getLowerBound()); + auto one = builder.createWithAgentIds(loc, 1, 32); + numSteps = builder.createWithAgentIds(loc, numSteps, + forOp.getStep()); + numSteps = builder.createWithAgentIds(loc, numSteps, one); + numSteps = builder.createWithAgentIds(loc, numSteps, + forOp.getStep()); + initValue = + builder.createWithAgentIds(loc, initValue, numSteps); + } else { + initValue = builder.createWithAgentIds(loc, 0, 32); + } + newLoopArgs.push_back(initValue); + + // Create newForOp and take the region of forOp + auto newForOp = builder.createWithAgentIds( + loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(), + newLoopArgs); + newForOp.getRegion().takeBody(forOp.getRegion()); + + // Replace forOp with newForOp + for (unsigned i = 0; i < forOp.getNumResults(); ++i) + forOp.getResult(i).replaceAllUsesWith(newForOp.getResult(i)); + forOp.erase(); + + return newForOp; +} + +//===----------------------------------------------------------------------===// +// appendPipelineIdxArgs +//===----------------------------------------------------------------------===// + +void appendPipelineIdxArgs(SmallVector &backbone, int numStages) { + + SmallVector orderedForOps; + for (auto &op : backbone) { + op->walk([&](Operation *subOp) { + if (auto forOp = dyn_cast(subOp)) { + orderedForOps.push_back(forOp); + } + }); + } + + for (auto &op : orderedForOps) { + scf::ForOp parentForOp = op->getParentOfType(); + auto newForOp = appendPipelineIdxToLoopArgs(op, numStages, parentForOp); + auto backboneForItr = + std::find(backbone.begin(), backbone.end(), op.getOperation()); + if (backboneForItr != backbone.end()) { + // Update backbone + *backboneForItr = newForOp.getOperation(); + } + } +} + +//===----------------------------------------------------------------------===// +// checkDependencyAndCollectUsedArgs +//===----------------------------------------------------------------------===// + +SmallVector checkDependencyAndCollectUsedArgs( + scf::ForOp forOp, AgentId agentId, + DenseMap &blockArgToYieldOperand) { + + std::unordered_set visited; + SetVector argSet; + + // DFS + std::function dfs = [&](Operation *op) { + if (visited.find(op) != visited.end()) + return; + visited.insert(op); + for (Value operand : op->getOperands()) { + if (auto blockArg = operand.dyn_cast()) { + if (!blockArgToYieldOperand[blockArg]) + continue; + argSet.insert(blockArg.getArgNumber() - forOp.getNumInductionVars()); + operand = blockArgToYieldOperand[blockArg]; + } + Operation *depOp = operand.getDefiningOp(); + assert(depOp && "Unexpected Value with no defining op"); + if (depOp->getBlock() != forOp.getBody()) + continue; + assert(hasAgentId(depOp, agentId) && "Dependency error"); + dfs(depOp); + } + }; + + // Start from operations that are marked with this agentId explicitly and + // check dependency with DFS traversal + forOp.walk([&](Operation *op) { + if (hasAgentId(op, agentId) && !isa(op)) + dfs(op); + }); + + // Collect used block args + SmallVector args(argSet.begin(), argSet.end()); + llvm::sort(args); + return args; +} + +//===----------------------------------------------------------------------===// +// createForOpsForEachAgentId +//===----------------------------------------------------------------------===// + +DenseMap createForOpsForEachAgentId(scf::ForOp forOp) { + // Collect operation list for each agentId + DenseMap> opList; + for (Operation &op : forOp.getBody()->without_terminator()) + for (AgentId agentId : getAgentIds(&op)) + opList[agentId].push_back(&op); + + // Prepare blockArgToYieldOperand mapping + DenseMap blockArgToYieldOperand; + auto yieldOp = llvm::cast(forOp.getBody()->getTerminator()); + assert(yieldOp.getNumOperands() == forOp.getNumRegionIterArgs()); + for (unsigned i = 0; i < forOp.getNumRegionIterArgs(); ++i) + blockArgToYieldOperand[forOp.getRegionIterArg(i)] = yieldOp.getOperand(i); + + auto loc = forOp.getLoc(); + OpBuilderWithAgentIds builder(forOp.getContext()); + DenseMap agentsToForOp; + + // Create newForOp for each agent + for (AgentId agentId : collectAgentIds(forOp)) { + auto usedArgs = checkDependencyAndCollectUsedArgs(forOp, agentId, + blockArgToYieldOperand); + + // Prepare newLoopArgs + SmallVector newLoopArgs; + for (unsigned argNumber : usedArgs) + newLoopArgs.push_back(forOp.getIterOperands()[argNumber]); + + // Create newForOp + builder.setAgentIdsFromArray({agentId}); + builder.setInsertionPoint(forOp); + auto newForOp = builder.createWithAgentIds( + loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(), + newLoopArgs); + + // Initialize Value mapping from forOp to newForOp + IRMapping mapping; + mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); + for (unsigned i = 0; i < usedArgs.size(); ++i) { + auto oldArg = forOp.getRegionIterArgs()[usedArgs[i]]; + auto newArg = newForOp.getRegionIterArgs()[i]; + mapping.map(oldArg, newArg); + } + + // Clone all operations with this agentId to newForOp + builder.setInsertionPointToStart(newForOp.getBody()); + for (Operation *op : opList[agentId]) { + Operation *newOp = builder.clone(*op, mapping); + setAgentIds(newOp, {agentId}); + for (unsigned i = 0; i < op->getNumResults(); ++i) + mapping.map(op->getResult(i), newOp->getResult(i)); + } + + // Create YieldOp for newForOp + SmallVector newYieldOperands; + for (unsigned i : usedArgs) + newYieldOperands.push_back(mapping.lookup(yieldOp.getOperand(i))); + auto newYieldOp = + builder.create(yieldOp.getLoc(), newYieldOperands); + setAgentIds(newYieldOp, {agentId}); + + // Replace results of forOp with results of newForOp + for (unsigned i = 0; i < usedArgs.size(); ++i) { + auto oldResult = forOp.getResult(usedArgs[i]); + auto newResult = newForOp.getResult(i); + oldResult.replaceAllUsesWith(newResult); + } + + agentsToForOp[agentId] = newForOp; + } + + return agentsToForOp; +} + +//===----------------------------------------------------------------------===// +// createIfOpsForEachAgentId +//===----------------------------------------------------------------------===// + +DenseMap createIfOpsForEachAgentId(scf::IfOp ifOp) { + // TODO: to be implemented + OpBuilderWithAgentIds builder(ifOp.getContext()); + DenseMap agentsToIfOp; + return agentsToIfOp; +} + +//===----------------------------------------------------------------------===// +// SpecializeAgentRegion +//===----------------------------------------------------------------------===// + +DenseMap SpecializeAgentRegion(triton::FuncOp funcOp) { + MLIRContext *context = funcOp.getContext(); + OpBuilder builder(context); + auto loc = funcOp.getLoc(); + + // Get block from funcOp + Block *block = &funcOp.getBody().front(); + auto returnOp = llvm::cast(block->getTerminator()); + + // Collect original operations + SmallVector opList; + for (Operation &op : block->getOperations()) + opList.push_back(&op); + + // Get curAgentId + builder.setInsertionPoint(returnOp); + Value curAgentId = builder.create(loc); + + // Resources for each agentId + DenseMap> agentsToBuilders; + DenseMap agentsToIfOp; + DenseMap agentsToIRMappings; + + for (AgentId agentId : collectAgentIds(funcOp)) { + // Create IfOp for each agentId + Value cond = builder.create( + loc, arith::CmpIPredicate::eq, curAgentId, + builder.create(loc, agentId, 32)); + + auto ifOp = builder.create(loc, cond); + agentsToIfOp[agentId] = ifOp; + setAgentIds(ifOp, {agentId}); + + // Create OpBuilderWithAgentIds for each agent + auto agentBuilder = std::make_shared(context); + agentsToBuilders[agentId] = agentBuilder; + agentBuilder->setAgentIdsFromArray({agentId}); + + // Set insertion point before yieldOp + auto yieldOp = ifOp.thenYield(); + setAgentIds(yieldOp, {agentId}); + agentBuilder->setInsertionPoint(yieldOp); + } + + // Clone all operations into corresponding if blocks + SmallVector cloned; + for (Operation *op : opList) { + auto agentIds = getAgentIds(op); + if (!agentIds.empty()) { + cloned.push_back(op); + for (AgentId agentId : getAgentIds(op)) { + IRMapping &mapping = agentsToIRMappings[agentId]; + Operation *newOp = agentsToBuilders[agentId]->clone(*op, mapping); + for (unsigned i = 0; i < op->getNumResults(); ++i) + mapping.map(op->getResult(i), newOp->getResult(i)); + } + } + } + + // Remove original operations that have been cloned in reverse order + for (auto it = cloned.rbegin(); it != cloned.rend(); ++it) { + Operation *op = *it; + op->erase(); + } + + return agentsToIfOp; +} + +//===----------------------------------------------------------------------===// +// collectAsyncChannels +//===----------------------------------------------------------------------===// + +void collectAsyncChannels(SmallVector> &channels, + triton::FuncOp &funcOp) { + funcOp.walk([&](Operation *op) { + for (auto result : op->getResults()) { + if (result.use_empty() || !op->hasAttr("async_agent")) { + continue; + } + auto producerAgent = + op->getAttrOfType("async_agent"); + if (producerAgent.getValues().size() > 1) { + continue; + } + for (Operation *userOp : result.getUsers()) { + if (!userOp->hasAttr("async_agent") || + userOp->getAttrOfType("async_agent") + .getValues() + .size() > 1) { + continue; + } + auto consumerAgentId = + userOp->getAttrOfType("async_agent") + .getValues()[0]; + auto producerAgentId = producerAgent.getValues()[0]; + if (producerAgentId != consumerAgentId) { + channels.push_back(std::make_unique( + producerAgentId, consumerAgentId, op, userOp)); + } + } + } + }); +} + +//===----------------------------------------------------------------------===// +// reduceChannels +//===----------------------------------------------------------------------===// + +void reduceChannels(SmallVector &channels, + + DenseMap> &map) { + // If producers or their consumers has the same convergent comsumer, + // and those producers, producers' consumers and the convergent comsumer are + // in the same block, They share the same token. + auto checkConverge = [](Operation *op1, Operation *op2) -> Operation * { + // Only check level-0 and level-1 convergence, e.g. + // producer: load0 load1 + // | | + // consumer: convertLayout0 convertLayout1 + // \ / + // consumer: dot + // The example above is level-1 convergence. + // If convertLayoutOps converge in deeper depth, this function will + // fail to detect. + // TODO: implement general level-N convergence. + if (op1 == op2) { + return op1; + } + if (op1->getBlock() == op2->getBlock() && op1->hasOneUse() && + op2->hasOneUse() && + *(op1->getUsers().begin()) == *(op2->getUsers().begin()) && + (*(op1->getUsers().begin()))->getBlock() == op1->getBlock()) { + return *(op1->getUsers().begin()); + } + return nullptr; + }; + assert(channels.size() > 0 && "channel size is zero"); + // Compare with existing channels in map + for (auto c0 = channels.begin(); c0 != channels.end(); ++c0) { + bool isConvergent = false; + for (auto &kv : map) { + if (kv.second.size() > 0 && + (*c0)->srcOp->getBlock() == kv.second.front()->srcOp->getBlock()) { + if (auto cvg = checkConverge((*c0)->dstOp, kv.second.front()->dstOp)) { + kv.second.push_back(*c0); + isConvergent = true; + break; + } + } + } + if (!isConvergent) { + map[(*c0)->dstOp].push_back(*c0); + } + } + + // Reorder channels and maps based on locations of producers + for (auto &kv : map) { + if (kv.second.size() > 1) { + auto &allOps = kv.second.front()->srcOp->getBlock()->getOperations(); + std::sort( + kv.second.begin(), kv.second.end(), [&](Channel *a, Channel *b) { + auto itrA = + std::find_if(allOps.begin(), allOps.end(), [&](Operation &op) { + Operation *opPointer = &op; + return opPointer == a->srcOp; + }); + auto itrB = + std::find_if(allOps.begin(), allOps.end(), [&](Operation &op) { + Operation *opPointer = &op; + return opPointer == b->srcOp; + }); + assert(itrA != allOps.end() && itrB != allOps.end()); + return std::distance(itrA, itrB) < 0; + }); + } + } +} + +//===----------------------------------------------------------------------===// +// getBackbone +//===----------------------------------------------------------------------===// + +SmallVector getBackbone(triton::FuncOp funcOp, + const SmallVector &channels) { + // Backbone: outermost Ops with regions in funcOp which contain at least one + // relation between producer and consumer. It assumes producer-consumer + // relation going across two outermost Ops in funcOp is forbidden. For + // example, In the example of runOnOperation(), only the outermost ForOp is + // backbone, the inner ForOp is not. + SmallVector backboneOps; + auto isBackbone = [&](Operation *backbone) -> bool { + for (auto c : channels) { + Operation *producer = c->srcOp, *consumer = c->dstOp; + while (producer && !isa(producer->getParentOp())) { + producer = producer->getParentOp(); + } + while (consumer && !isa(consumer->getParentOp())) { + consumer = consumer->getParentOp(); + } + if (producer == backbone && consumer == backbone) { + return true; + } + assert((producer != backbone || + isa(producer->getParentOp())) && + (consumer != backbone || + isa(consumer->getParentOp())) && + "Error: producer and consumer belongs to different backboneOps"); + } + return false; + }; + Operation *op; + for (Operation &bodyOp : funcOp.getBody().front().getOperations()) { + op = &bodyOp; + if (op->getNumRegions() > 0) { + // If this op as a whole is a producer or consumer, continue + if (getAgentIds(op).size() == 1) { + continue; + } + if (isBackbone(op)) { + backboneOps.push_back(op); + } + } + } + return backboneOps; +} + +//===----------------------------------------------------------------------===// +// buildAsyncComm +//===----------------------------------------------------------------------===// + +void buildAsyncComm(const DenseMap> &map, + const DenseMap &tokenMap, + const DenseMap &bufferMap, + int numStages) { + + auto getSameLevelOp = [](Operation *p, Operation *c) -> Operation * { + while (!isa(c)) { + if (c->getParentOp() == p->getParentOp()) { + return c; + } + c = c->getParentOp(); + } + llvm_unreachable("Falied to find consumer's same level Op with producer"); + }; + + auto consumerReleaseHeutistic = [&](Operation *p, + Operation *c) -> Operation * { + if (c->getBlock() == p->getBlock()) { + auto consumerAgentId = + c->getAttrOfType("async_agent") + .getValues()[0]; + for (auto it = c->getBlock()->rbegin(); it != c->getBlock()->rend(); + ++it) { + if (!it->hasAttr("async_agent")) { + continue; + } + auto asyncAttr = it->getAttrOfType("async_agent") + .getValues(); + if (asyncAttr.size() == 1 && asyncAttr[0] == consumerAgentId) { + return &(*it); + } + } + return nullptr; + } else { + return getSameLevelOp(p, c); + } + }; + + auto getAgents = [&](Operation *p, Operation *c, SmallVector &agentP, + SmallVector &agentC, + SmallVector &agentsPC) -> void { + agentP = collectAgentIds(p); + agentC = collectAgentIds(c); + agentsPC.reserve(agentP.size() + agentC.size()); + agentsPC.insert(agentsPC.end(), agentP.begin(), agentP.end()); + agentsPC.insert(agentsPC.end(), agentC.begin(), agentC.end()); + }; + // TODO: try to optimize locations of arriving and waiting token + // for fused-attention + for (auto kv : map) { + /*****************Token related*****************/ + auto headProducer = kv.second.front()->srcOp; + auto tailProducer = kv.second.back()->srcOp; + auto headConsumer = kv.second.front()->dstOp; + auto tailConsumer = kv.second.back()->dstOp; + auto token = tokenMap.find(kv.second.front())->second; + SmallVector agentP, agentC, agentsPC; + getAgents(headProducer, headConsumer, agentP, agentC, agentsPC); + OpBuilderWithAgentIds builder(headProducer->getContext()); + + if (auto funcOp = dyn_cast(headProducer->getParentOp())) { + builder.setInsertionPointToStart(&(funcOp.getBody().front())); + } else { + builder.setInsertionPoint(headProducer->getParentOp()); + } + builder.setAgentIdsFromArray(agentsPC); + Value pipelineIdx; + Value numStagesVal = builder.createWithAgentIds( + headProducer->getLoc(), numStages, 32); + if (auto forOp = headProducer->getParentOfType()) { + pipelineIdx = forOp.getBody()->getArguments().back(); + } else { + // existing"); + pipelineIdx = builder.createWithAgentIds( + headProducer->getLoc(), 0, 32); + } + + // insert ProducerAcquireOp + builder.setInsertionPoint(headProducer); + if (headProducer->getParentOfType()) { + pipelineIdx = builder.createWithAgentIds( + headProducer->getLoc(), pipelineIdx, numStagesVal); + } + builder.setAgentIdsFromArray(agentP); + builder.createWithAgentIds(headProducer->getLoc(), + token, pipelineIdx); + + // insert ProducerCommitOp + builder.setInsertionPointAfter(tailProducer); + builder.createWithAgentIds(tailProducer->getLoc(), + token, pipelineIdx); + + builder.setAgentIdsFromArray(agentC); + // insert ConsumerWaitOp + auto consumerWaitPoint = getSameLevelOp(headProducer, headConsumer); + builder.setInsertionPoint(consumerWaitPoint); + builder.createWithAgentIds(headConsumer->getLoc(), + token, pipelineIdx); + + // insert ConsumerReleaseOp + auto consumerReleasePoint = + consumerReleaseHeutistic(tailProducer, tailConsumer); + builder.setInsertionPointAfter(consumerReleasePoint); + builder.createWithAgentIds( + consumerReleasePoint->getLoc(), token, pipelineIdx); + + /*****************Buffer related*****************/ + /// splitLoadsInForLoop + for (auto &c : kv.second) { + assert(isa(c->srcOp) && "prodcuerOp is not tt.load"); + auto loadOp = cast(c->srcOp); + auto buffer = bufferMap.find(c)->second; + MLIRContext *context = loadOp->getContext(); + OpBuilderWithAgentIds builder(context); + builder.setInsertionPoint(loadOp->getParentOp()); + builder.setAgentIdsFromArray(agentsPC); + + builder.setInsertionPoint(loadOp); + Value loadResult = loadOp.getResult(); + if (auto tensorType = loadResult.getType().dyn_cast()) { + // Get basic information from tensorType + auto order = ttg::getOrder(tensorType.getEncoding()); + auto CTALayout = ttg::getCTALayout(tensorType.getEncoding()); + auto elemType = tensorType.getElementType(); + + // Get shape, layout and type of a slice + auto sliceShape = tensorType.getShape(); + auto sharedLayout = ttg::SharedEncodingAttr::get( + context, sliceShape, order, CTALayout, elemType); + auto sliceType = + RankedTensorType::get(sliceShape, elemType, sharedLayout); + + // Get shape, layout and type of the complete buffer + SmallVector bufferShape(sliceShape.begin(), sliceShape.end()); + if (loadOp->getParentOfType()) { + bufferShape.insert(bufferShape.begin(), numStages); + } else { + bufferShape.insert(bufferShape.begin(), 1); + } + auto bufferType = + RankedTensorType::get(bufferShape, elemType, sharedLayout); + + // Create InsertSliceOp + builder.setAgentIdsFromOp(loadOp); + builder.setInsertionPointAfter(loadOp); + auto insertSliceOp = builder.createWithAgentIds( + /*loc=*/loadOp.getLoc(), /*result=*/bufferType, + /*src=*/loadOp.getPtr(), /*dst=*/buffer, /*index=*/pipelineIdx, + /*mask=*/loadOp.getMask(), /*other=*/loadOp.getOther(), + /*cache=*/loadOp.getCache(), /*evict=*/loadOp.getEvict(), + /*isVolatile=*/loadOp.getIsVolatile(), /*axis=*/0); + + // Create ExtractSliceOp + auto attr = [&](int val) { return builder.getI64IntegerAttr(val); }; + SmallVector offsets = {pipelineIdx, attr(0), attr(0)}; + SmallVector sizes = {attr(1), attr(sliceShape[0]), + attr(sliceShape[1])}; + SmallVector strides = {attr(1), attr(1), attr(1)}; + builder.setAgentIdsFromValueUsers(loadResult); + builder.setInsertionPoint(c->dstOp); + auto extractSliceOp = builder.createWithAgentIds( + loadOp.getLoc(), sliceType, buffer, offsets, sizes, strides); + + // Replace all uses of loadResult + loadResult.replaceAllUsesWith(extractSliceOp.getResult()); + loadOp.erase(); + } + } + } +} + +//===----------------------------------------------------------------------===// +// agentDivision +//===----------------------------------------------------------------------===// + +DenseMap agentDivision(Operation *backbone) { + // A general agent division in backbone could be: + // * If opWithRegion has results, e.g. scf.for, this opWithRegion will be + // splitted into several new operations, each agent has one, which + // has the part of results related to this agent. One agent could own + // all orginal results or none of them, but one result must belong to + // one and only one agent. + // * if opWithRegions doesn't have result. Simply split for every agent. + // * So does operands of opWithRegions + // However, current backbones are all ForOps and IfOps. So we customize + // the implementation. + DenseMap agentBackbone; + backbone->walk([&](Operation *op) { + auto ids = getAgentIds(op); + if (op->getNumRegions() > 0 && ids.size() > 1) { + // ForOp: change iterArgs and yield results + if (auto forOp = dyn_cast(op)) { + auto forOps = createForOpsForEachAgentId(forOp); + if (op == backbone) { + for (auto kv : forOps) { + auto f = kv.second; + auto id = getAgentIds(f.getOperation()); + assert(id.size() == 1 && + "generated ForOp doesn't have one and only one agentId"); + agentBackbone[id.front()] = f.getOperation(); + } + } + forOp.erase(); + } else if (auto ifOp = dyn_cast(op)) { + // TODO: to be implemented + llvm_unreachable("If Op is unsupported"); + auto ifOps = createIfOpsForEachAgentId(ifOp); + assert(ifOps.size() > 0); + if (op == backbone) { + for (auto kv : ifOps) { + auto i = kv.second; + auto id = getAgentIds(i.getOperation()); + assert(id.size() == 1 && + "generated IfOp doesn't have one and only one agentId"); + agentBackbone[id.front()] = i.getOperation(); + } + } + } else { + llvm_unreachable("Unexpected Op with regions"); + } + } + }); + assert(agentBackbone.size() > 0 && "Agent division failed"); + return agentBackbone; +} + +//===----------------------------------------------------------------------===// +// cloneBackboneForEachAgentId +//===----------------------------------------------------------------------===// + +void cloneBackboneForEachAgentId(SmallVector &backbone) { + SmallVector newBackBone; + + for (Operation *op : backbone) { + auto loc = op->getLoc(); + OpBuilderWithAgentIds builder(op->getContext()); + builder.setInsertionPoint(op); + // First, agent division + DenseMap agentBackbone = agentDivision(op); + + // Second, remove irrelavant Ops + for (auto kv : agentBackbone) { + SmallVector deleteOps; + AgentId targetId = kv.first; + Operation *newBackbone = kv.second; + newBackbone->walk([&](Operation *subOp) { + auto ids = getAgentIds(subOp); + if (std::find(ids.begin(), ids.end(), targetId) == ids.end()) { + deleteOps.push_back(subOp); + } + }); + for (auto it = deleteOps.rbegin(); it != deleteOps.rend(); ++it) { + (*it)->erase(); + } + } + } +} + +//===----------------------------------------------------------------------===// +// WSPipelinePass +//===----------------------------------------------------------------------===// + +struct WSPipelinePass : public TritonGPUWSPipelineBase { + WSPipelinePass() = default; + WSPipelinePass(int numStages, int numWarps, int computeCapability) { + this->numStages = numStages; + this->numWarps = numWarps; + this->computeCapability = computeCapability; + } + + void runOnOperation() override { + auto mod = getOperation(); + if (!ttng::TritonNvidiaGPUDialect::getWSSupportedAttr(mod)) + return signalPassFailure(); + + mod.walk([&](triton::FuncOp funcOp) { + assert(funcOp.getBody().hasOneBlock() && + "FuncOp with more than one blocks is not supported"); + // Maintain all structures between funcOp and producer/consumer Op, for + // example: + /* +-----------------------------------+ + * | scf.for: | + * | A = tt.load {agentId = 0} | + * | scf.for: | + * | B = tt.load {agentId = 0} | + * | C = tt.dot A, B {agentId = 1} | + * +-----------------------------------+ + * || + * \||/ + * \/ + * +-----------------------------------------+ + * | token0 = create_token() | + * | token1 = create_token() | + * | buffer0 = alloc_buffer() | + * | buffer1 = alloc_buffer() | + * | if agent0: | + * | scf.for: | + * | producer_aquire token0 | + * | buffer0 = tt.load (load A)| + * | producer_commit token0 | + * | scf.for: | + * | producer_aquire token1 | + * | buffer1 = tt.load (load B)| + * | producer_commit token1 | + * | if agent1: | + * | scf.for: | + * | consumer_wait token0 | + * | scf.for: | + * | consumer_wait token1 | + * | A = extract_slice buffer0 | + * | B = extract_slice buffer1 | + * | C = tt.dot A, B | + * | consumer_arrive token1 | + * | consumer_arrive token0 | + * +-----------------------------------------+ + */ + + // First step: collect channels + SmallVector> channelsOrigin; + collectAsyncChannels(channelsOrigin, funcOp); + SmallVector channels; + for (const auto &c : channelsOrigin) { + channels.push_back(c.get()); + } + + // cvgOp-channels map + DenseMap> map; + reduceChannels(channels, map); + + // Prepare phase, getBackbone, appendPipelineIdxArgs + SmallVector backbone = getBackbone(funcOp, channels); + appendPipelineIdxArgs(backbone, numStages); + + // Create token, buffer and data tranfer between async agents + DenseMap tokenMap = createToken(map, funcOp, numStages); + DenseMap bufferMap = + createBuffer(channels, funcOp, numStages); + buildAsyncComm(map, tokenMap, bufferMap, numStages); + + // Clone backbone, remove irrelevant blockArgument for {forOp, ifOp} + cloneBackboneForEachAgentId(backbone); + + // Specialize agent region + SpecializeAgentRegion(funcOp); + }); + } +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// createTritonNvidiaGPUWSPipelinePass +//===----------------------------------------------------------------------===// + +std::unique_ptr +mlir::createTritonNvidiaGPUWSPipelinePass(int numStages, int numWarps, + int computeCapability) { + return std::make_unique(numStages, numWarps, + computeCapability); +} diff --git a/lib/Target/HSACO/HSACOTranslation.cpp b/lib/Target/HSACO/HSACOTranslation.cpp index d2c318b9628c..45f793b3534d 100644 --- a/lib/Target/HSACO/HSACOTranslation.cpp +++ b/lib/Target/HSACO/HSACOTranslation.cpp @@ -33,8 +33,10 @@ #include "llvm/Transforms/Utils/Cloning.h" #include #include +#include #include #include +#include namespace { @@ -171,11 +173,34 @@ std::string generate_hsaco(llvm::Module *module, const std::string &triple, std::filesystem::path hsaco(kernel_name + ".hsaco"); std::string hsaco_path = (kernel_dir / hsaco).string(); std::string error_message; - std::string lld_path = "/opt/rocm/llvm/bin/ld.lld"; - int lld_result = llvm::sys::ExecuteAndWait( - lld_path, - {lld_path, "-flavor", "gnu", "-shared", "-o", hsaco_path, isabin_path}, - std::nullopt, {}, 0, 0, &error_message); + + // Check in triton/third_party/rocm/llvm/bin first. For whls this will be the + // correct location. If not found, go back to using ROCM_PATH or /opt/rocm + static const auto this_library_path = [] { + Dl_info fileinfo; + if (dladdr(reinterpret_cast(generate_hsaco), &fileinfo) == 0) { + return std::filesystem::path(); + } + return std::filesystem::path(fileinfo.dli_fname); + }(); + + static const auto compiletime_path = this_library_path.parent_path() + .parent_path() + .parent_path() / + "triton" / "third_party" / + "rocm" / "llvm" / "bin" / "ld.lld"; + std::string lld_path = compiletime_path.string(); + if (!std::filesystem::exists(lld_path)) { + std::string rocm_path = ::triton::tools::getenv("ROCM_PATH"); + lld_path = (rocm_path.empty()) ? ROCM_DEFAULT_DIR : rocm_path; + lld_path += "/llvm/bin/ld.lld"; + } + + int lld_result = + llvm::sys::ExecuteAndWait(lld_path, + {lld_path, "-flavor", "gnu", + "-shared", "-o", hsaco_path, isabin_path}, + std::nullopt, {}, 0, 0, &error_message); if (lld_result) { llvm::errs() << "ld.lld execute fail: " << '\n' << error_message << "Code: " diff --git a/lib/Target/LLVMIR/CMakeLists.txt b/lib/Target/LLVMIR/CMakeLists.txt index 5a687f84bc6a..fbaefe68375c 100644 --- a/lib/Target/LLVMIR/CMakeLists.txt +++ b/lib/Target/LLVMIR/CMakeLists.txt @@ -25,3 +25,8 @@ add_mlir_translation_library(TritonLLVMIR MLIRTargetLLVMIRExport TritonGPUToLLVM ) + +set_source_files_properties( + LLVMIRTranslation.cpp + PROPERTIES + COMPILE_FLAGS "-D__BUILD_DIR__=\\\"${CMAKE_BINARY_DIR}\\\"") diff --git a/lib/Target/LLVMIR/LLVMIRTranslation.cpp b/lib/Target/LLVMIR/LLVMIRTranslation.cpp index 30a238e773cb..45f0fedaf898 100644 --- a/lib/Target/LLVMIR/LLVMIRTranslation.cpp +++ b/lib/Target/LLVMIR/LLVMIRTranslation.cpp @@ -1,6 +1,9 @@ #include "triton/Target/LLVMIR/LLVMIRTranslation.h" #include "mlir/Conversion/Passes.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/Transforms/Passes.h" #include "mlir/ExecutionEngine/ExecutionEngine.h" @@ -15,9 +18,11 @@ #include "mlir/Target/LLVMIR/Export.h" #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h" #include "mlir/Transforms/Passes.h" +#include "triton/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.h" #include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Target/LLVMIR/Passes.h" +#include "triton/Target/PTX/TmaMetadata.h" #include "triton/Tools/Sys/GetEnv.hpp" #include "llvm/IR/CallingConv.h" #include "llvm/ADT/APInt.h" @@ -54,7 +59,8 @@ struct NVVMMetadata { // Add the nvvm related metadata to LLVM IR. static void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata, - bool isROCM, const int threadsPerCTA) { + Target target, const int threadsPerCTA, + const int wavesPerEU) { auto *module = func->getParent(); auto &ctx = func->getContext(); @@ -84,19 +90,24 @@ static void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata, } if (metadata.isKernel) { - if (isROCM) { - func->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL); - func->addFnAttr("amdgpu-flat-work-group-size", - "1, " + std::to_string(threadsPerCTA)); - func->addFnAttr("denormal-fp-math-f32", "preserve-sign"); - func->addFnAttr("amdgpu-unsafe-fp-atomics", "true"); - } else { + switch (target) { + case Target::NVVM: { llvm::Metadata *mdArgs[] = { llvm::ValueAsMetadata::get(func), llvm::MDString::get(ctx, "kernel"), llvm::ValueAsMetadata::get( llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1))}; module->getOrInsertNamedMetadata("nvvm.annotations") ->addOperand(llvm::MDNode::get(ctx, mdArgs)); + } break; + case Target::ROCDL: { + func->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL); + func->addFnAttr("amdgpu-flat-work-group-size", + "1, " + std::to_string(threadsPerCTA)); + if (wavesPerEU > 0) + func->addFnAttr("amdgpu-waves-per-eu", std::to_string(wavesPerEU)); + func->addFnAttr("denormal-fp-math-f32", "preserve-sign"); + func->addFnAttr("amdgpu-unsafe-fp-atomics", "true"); + } break; } } } @@ -241,8 +252,8 @@ static void linkLibdevice(llvm::Module &module) { module.addModuleFlag(reflect); } -static bool linkExternLib(llvm::Module &module, llvm::StringRef name, - llvm::StringRef path, bool isROCM) { +bool linkExternLib(llvm::Module &module, llvm::StringRef name, + llvm::StringRef path, Target target) { llvm::SMDiagnostic err; auto &ctx = module.getContext(); @@ -261,8 +272,7 @@ static bool linkExternLib(llvm::Module &module, llvm::StringRef name, return true; } - // check if ROCM - if (!isROCM) { + if (target == Target::NVVM) { if (name == "libdevice") { linkLibdevice(module); } @@ -276,12 +286,13 @@ static bool linkExternLib(llvm::Module &module, llvm::StringRef name, std::unique_ptr translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module, - bool isROCM) { + Target target, int wavesPerEU) { DialectRegistry registry; mlir::registerBuiltinDialectTranslation(registry); mlir::registerLLVMDialectTranslation(registry); mlir::registerROCDLDialectTranslation(registry); mlir::registerNVVMDialectTranslation(registry); + module->getContext()->appendDialectRegistry(registry); llvm::DenseMap nvvmMetadata; @@ -303,7 +314,7 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module, // dead code. auto externLibs = getExternLibs(module); for (auto &lib : externLibs) { - if (linkExternLib(*llvmModule, lib.first, lib.second, isROCM)) + if (linkExternLib(*llvmModule, lib.first, lib.second, target)) return nullptr; } @@ -323,7 +334,7 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module, for (auto &func : llvmModule->functions()) { auto it = nvvmMetadata.find(func.getName()); if (it != nvvmMetadata.end()) - amendLLVMFunc(&func, it->second, isROCM, threadsPerCTA); + amendLLVMFunc(&func, it->second, target, threadsPerCTA, wavesPerEU); } return llvmModule; @@ -332,7 +343,8 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module, std::unique_ptr translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module, int computeCapability, - bool isROCM) { + mlir::triton::gpu::TMAMetadataTy &tmaInfos, + Target target, int wavesPerEU) { mlir::PassManager pm(module->getContext()); mlir::registerPassManagerCLOptions(); if (failed(applyPassManagerCLOptions(pm))) { @@ -354,7 +366,11 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext, pm.addPass(mlir::createConvertSCFToCFPass()); pm.addPass(mlir::createConvertIndexToLLVMPass()); - pm.addPass(createConvertTritonGPUToLLVMPass(computeCapability, isROCM)); + pm.addPass( + createConvertTritonGPUToLLVMPass({computeCapability, &tmaInfos, target})); +#ifndef USE_ROCM + pm.addPass(createConvertNVGPUToLLVMPass()); +#endif pm.addPass(mlir::createArithToLLVMConversionPass()); pm.addPass(mlir::createCanonicalizerPass()); // Simplify the IR @@ -372,7 +388,7 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext, return nullptr; } - auto llvmIR = translateLLVMToLLVMIR(llvmContext, module, isROCM); + auto llvmIR = translateLLVMToLLVMIR(llvmContext, module, target, wavesPerEU); if (!llvmIR) { llvm::errs() << "Translate to LLVM IR failed"; return nullptr; diff --git a/lib/Target/PTX/CMakeLists.txt b/lib/Target/PTX/CMakeLists.txt index 69aa5710cdd1..c03ccca74ff5 100644 --- a/lib/Target/PTX/CMakeLists.txt +++ b/lib/Target/PTX/CMakeLists.txt @@ -7,3 +7,8 @@ add_mlir_translation_library(TritonPTX LINK_LIBS PUBLIC TritonLLVMIR ) + +set_source_files_properties( + PTXTranslation.cpp + PROPERTIES + COMPILE_FLAGS "-D__BUILD_DIR__=\\\"${CMAKE_BINARY_DIR}\\\"") diff --git a/lib/Target/PTX/PTXTranslation.cpp b/lib/Target/PTX/PTXTranslation.cpp index f089f1b16874..fe8841997c35 100644 --- a/lib/Target/PTX/PTXTranslation.cpp +++ b/lib/Target/PTX/PTXTranslation.cpp @@ -10,6 +10,11 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/TargetSelect.h" #include "llvm/Target/TargetMachine.h" +#include "llvm/Transforms/IPO/AlwaysInliner.h" +#include + +#include +#include #include #include @@ -60,9 +65,14 @@ std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) { std::string layout = ""; std::string features = ""; // std::string features = "+ptx" + std::to_string(maxPTX); + for (llvm::Function &f : module.functions()) { + if (!f.hasFnAttribute(llvm::Attribute::NoInline)) + f.addFnAttr(llvm::Attribute::AlwaysInline); + } initLLVM(); // verify and store llvm llvm::legacy::PassManager pm; + pm.add(llvm::createAlwaysInlinerLegacyPass()); pm.add(llvm::createVerifierPass()); pm.run(module); // module->print(llvm::outs(), nullptr); diff --git a/python/pyproject.toml b/python/pyproject.toml index 6430c0c154dc..8bd8093e720d 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -1,6 +1,6 @@ [build-system] -requires = ["setuptools>=40.8.0", "wheel", "cmake>=3.18"] +requires = ["setuptools>=40.8.0", "wheel", "cmake>=3.18", "ninja>=1.11.1"] [tool.autopep8] aggressive = 1 diff --git a/python/setup.py b/python/setup.py index 69479ceb13bb..ff0d19930cb8 100644 --- a/python/setup.py +++ b/python/setup.py @@ -68,12 +68,12 @@ def get_pybind11_package_info(): def get_llvm_package_info(): # added statement for Apple Silicon system = platform.system() - arch = 'x86_64' + arch = platform.machine() + if arch == 'aarch64': + arch = 'arm64' if system == "Darwin": system_suffix = "apple-darwin" - cpu_type = os.popen('sysctl machdep.cpu.brand_string').read() - if "apple" in cpu_type.lower(): - arch = 'arm64' + arch = platform.machine() elif system == "Linux": vglibc = tuple(map(int, platform.libc_ver()[1].split('.'))) vglibc = vglibc[0] * 100 + vglibc[1] @@ -86,6 +86,9 @@ def get_llvm_package_info(): name = f'llvm+mlir-17.0.0-{arch}-{system_suffix}-{release_suffix}' version = "llvm-17.0.0-c5dede880d17" url = f"https://github.com/ptillet/triton-llvm-releases/releases/download/{version}/{name}.tar.xz" + # FIXME: remove the following once github.com/ptillet/triton-llvm-releases has arm64 llvm releases + if arch == 'arm64' and 'linux' in system_suffix: + url = f"https://github.com/acollins3/triton-llvm-releases/releases/download/{version}/{name}.tar.xz" return Package("llvm", name, url, "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH") @@ -126,7 +129,10 @@ def download_and_copy_ptxas(): base_dir = os.path.dirname(__file__) src_path = "bin/ptxas" version = "12.1.105" - url = f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-64/cuda-nvcc-{version}-0.tar.bz2" + arch = platform.machine() + if arch == "x86_64": + arch = "64" + url = f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2" dst_prefix = os.path.join(base_dir, "triton") dst_suffix = os.path.join("third_party", "cuda", src_path) dst_path = os.path.join(dst_prefix, dst_suffix) @@ -167,11 +173,15 @@ def __init__(self, name, path, sourcedir=""): class CMakeBuild(build_ext): - user_options = build_ext.user_options + [('base-dir=', None, 'base directory of Triton')] + user_options = build_ext.user_options + \ + [('base-dir=', None, 'base directory of Triton')] def initialize_options(self): build_ext.initialize_options(self) - self.base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) + self.base_dir = os.path.abspath( + os.path.join( + os.path.dirname(__file__), + os.pardir)) def finalize_options(self): build_ext.finalize_options(self) @@ -180,9 +190,7 @@ def run(self): try: out = subprocess.check_output(["cmake", "--version"]) except OSError: - raise RuntimeError( - "CMake must be installed to build the following extensions: " + ", ".join(e.name for e in self.extensions) - ) + raise RuntimeError("CMake must be installed to build the following extensions: " + ", ".join(e.name for e in self.extensions)) match = re.search(r"version\s*(?P\d+)\.(?P\d+)([\d.]+)?", out.decode()) cmake_major, cmake_minor = int(match.group("major")), int(match.group("minor")) @@ -202,6 +210,7 @@ def get_cmake_dir(self): def build_extension(self, ext): lit_dir = shutil.which('lit') + ninja_dir = shutil.which('ninja') user_home = os.getenv("HOME") or os.getenv("USERPROFILE") or \ os.getenv("HOMEPATH") or None if not user_home: @@ -216,6 +225,8 @@ def build_extension(self, ext): # python directories python_include_dir = sysconfig.get_path("platinclude") cmake_args = [ + "-G", "Ninja", # Ninja is much faster than make + "-DCMAKE_MAKE_PROGRAM=" + ninja_dir, # Pass explicit path to ninja otherwise cmake may cache a temporary path "-DCMAKE_EXPORT_COMPILE_COMMANDS=ON", "-DLLVM_ENABLE_WERROR=ON", "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir, @@ -287,8 +298,9 @@ def build_extension(self, ext): "triton/third_party", "triton/tools", ], + long_description_content_type="text/markdown", install_requires=[ - "filelock", + "filelock" ], include_package_data=True, ext_modules=[CMakeExtension("triton", "triton/_C/")], @@ -311,7 +323,7 @@ def build_extension(self, ext): test_suite="tests", extras_require={ "build": [ - "cmake>=3.18", + "cmake>=3.20", "lit", ], "tests": [ @@ -321,11 +333,13 @@ def build_extension(self, ext): "numpy", "pytest", "scipy>=1.7.1", + "torch", ], "tutorials": [ "matplotlib", "pandas", "tabulate", + "torch", ], }, ) diff --git a/python/src/extra/cuda.ll b/python/src/extra/cuda.ll deleted file mode 100644 index 0ab2f6896bdd..000000000000 --- a/python/src/extra/cuda.ll +++ /dev/null @@ -1,17 +0,0 @@ -; ~/.triton/llvm/llvm+mlir-17.0.0-x86_64-linux-gnu-ubuntu-18.04-release/bin/llvm-as ./src/extra/cuda.ll -o ./triton/language/extra/cuda.bc - -target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128" -target triple = "nvptx64-nvidia-cuda" - - -define i64 @globaltimer() #0 { - %1 = call i64 asm sideeffect "mov.u64 $0, %globaltimer;", "=l"() nounwind - ret i64 %1 -} - -define i32 @smid() #0 { - %1 = call i32 asm "mov.u32 $0, %smid;", "=r"() nounwind - ret i32 %1 -} - -attributes #0 = { alwaysinline nounwind } diff --git a/python/src/triton.cc b/python/src/triton.cc index 0d8a140feef0..29276493e8cd 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1,4 +1,8 @@ -#include "mlir/IR/Builders.h" +#include +#include +#include + +#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Verifier.h" @@ -19,16 +23,21 @@ #include "mlir/Dialect/Index/IR/IndexOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "triton/Analysis/Allocation.h" +#include "triton/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.h" #include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h" #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h" +#include "triton/Dialect/NVGPU/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" #include "triton/Dialect/Triton/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" #include "triton/Target/HSACO/HSACOTranslation.h" #include "triton/Target/LLVMIR/LLVMIRTranslation.h" #include "triton/Target/PTX/PTXTranslation.h" #include "triton/Target/HSACO/HSACOTranslation.h" +#include "triton/Target/PTX/TmaMetadata.h" #include "triton/Tools/Sys/GetEnv.hpp" #include "triton/Tools/Sys/GetPlatform.hpp" @@ -46,6 +55,7 @@ #include #include #include +#include #include #include #include @@ -58,6 +68,8 @@ namespace py = pybind11; +PYBIND11_MAKE_OPAQUE(mlir::triton::gpu::TMAMetadataTy); + enum backend_t { HOST, CUDA, @@ -71,6 +83,11 @@ void init_triton_runtime(py::module &&m) { .value("CUDA", CUDA) .value("ROCM", ROCM) .export_values(); + + py::enum_(m, "TARGET") + .value("NVVM", mlir::triton::NVVM) + .value("ROCDL", mlir::triton::ROCDL) + .export_values(); } // A custom op builder that keeps track of the last location @@ -157,6 +174,30 @@ class TritonOpBuilder { bool lineInfoEnabled = !triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO"); }; +static std::string locationToString(mlir::Location loc) { + std::string str; + llvm::raw_string_ostream os(str); + loc.print(os); + os.flush(); // Make sure all the content is dumped into the 'str' string + return str; +} + +static void outputWarning(mlir::Location loc, const std::string &msg) { + std::string locStr = locationToString(loc); + + py::exec( + R"( +import warnings + +def custom_showwarning(message, category, filename, lineno, file=None, line=None): + print(f"UserWarning: {message}") + +warnings.showwarning = custom_showwarning +warnings.warn(f"{loc}: {msg}") +)", + py::globals(), py::dict(py::arg("loc") = locStr, py::arg("msg") = msg)); +} + /*****************************************************************************/ /* Python bindings for triton::ir */ /*****************************************************************************/ @@ -515,9 +556,11 @@ void init_triton_ir(py::module &&m) { mlir::DialectRegistry registry; registry.insert< mlir::triton::TritonDialect, mlir::triton::gpu::TritonGPUDialect, - mlir::math::MathDialect, mlir::arith::ArithDialect, - mlir::index::IndexDialect, mlir::scf::SCFDialect, - mlir::cf::ControlFlowDialect, mlir::LLVM::LLVMDialect>(); + mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect, + mlir::triton::nvgpu::NVGPUDialect, mlir::math::MathDialect, + mlir::arith::ArithDialect, mlir::index::IndexDialect, + mlir::scf::SCFDialect, mlir::cf::ControlFlowDialect, + mlir::LLVM::LLVMDialect>(); context.appendDialectRegistry(registry); context.loadAllAvailableDialects(); @@ -580,6 +623,17 @@ void init_triton_ir(py::module &&m) { newBlock->erase(); } }); + // 2. Check if the result of tl.advance is used + self.walk([&](mlir::Operation *op) { + if (mlir::isa(op) && + op->getResult(0).use_empty()) + outputWarning(op->getLoc(), "The result of tl.advance is not " + "being used. Note that tl.advance " + "does not have any side effects. " + "To move the block pointer, you " + "need to assign the result of " + "tl.advance to a variable."); + }); }) .def_property_readonly("type", &mlir::triton::FuncOp::getFunctionType) .def("reset_type", &mlir::triton::FuncOp::setType); @@ -753,7 +807,7 @@ void init_triton_ir(py::module &&m) { [](TritonOpBuilder &self) -> mlir::Type { return self.getBuilder().getI64Type(); }) - .def("get_fp8e4_ty", + .def("get_fp8e4nv_ty", [](TritonOpBuilder &self) -> mlir::Type { return self.getBuilder().getType(); }) @@ -763,6 +817,12 @@ void init_triton_ir(py::module &&m) { // have a float-like type compatible with float only native ops return self.getBuilder().getType(); }) + .def("get_fp8e4b15x4_ty", + [](TritonOpBuilder &self) -> mlir::Type { + // TODO: upstream FP8E4B15 into MLIR, or find a way to externally + // have a float-like type compatible with float only native ops + return self.getBuilder().getType(); + }) .def("get_fp8e5_ty", [](TritonOpBuilder &self) -> mlir::Type { return self.getBuilder().getType(); @@ -1155,6 +1215,36 @@ void init_triton_ir(py::module &&m) { return mlir::Value(self.create(lhs, rhs)); #endif }) + .def("create_minsi", + [](TritonOpBuilder &self, mlir::Value &lhs, + mlir::Value &rhs) -> mlir::Value { + return mlir::Value(self.create(lhs, rhs)); + }) + .def("create_minui", + [](TritonOpBuilder &self, mlir::Value &lhs, + mlir::Value &rhs) -> mlir::Value { + return mlir::Value(self.create(lhs, rhs)); + }) + .def("create_minf", + [](TritonOpBuilder &self, mlir::Value &lhs, + mlir::Value &rhs) -> mlir::Value { + return mlir::Value(self.create(lhs, rhs)); + }) + .def("create_maxsi", + [](TritonOpBuilder &self, mlir::Value &lhs, + mlir::Value &rhs) -> mlir::Value { + return mlir::Value(self.create(lhs, rhs)); + }) + .def("create_maxui", + [](TritonOpBuilder &self, mlir::Value &lhs, + mlir::Value &rhs) -> mlir::Value { + return mlir::Value(self.create(lhs, rhs)); + }) + .def("create_maxf", + [](TritonOpBuilder &self, mlir::Value &lhs, + mlir::Value &rhs) -> mlir::Value { + return mlir::Value(self.create(lhs, rhs)); + }) // AddPtr (similar to GEP) .def("create_addptr", [](TritonOpBuilder &self, mlir::Value &ptr, @@ -1472,12 +1562,8 @@ void init_triton_ir(py::module &&m) { const std::string &libPath, const std::string &symbol, std::vector &argList, mlir::Type retType, bool isPure) -> mlir::Value { - if (isPure) - return self.create( - retType, argList, libName, libPath, symbol); - else - return self.create( - retType, argList, libName, libPath, symbol); + return self.create( + retType, argList, libName, libPath, symbol, isPure); }) // Built-in instruction .def("create_get_program_id", @@ -1572,6 +1658,14 @@ void init_triton_ir(py::module &&m) { return self.create(condition, trueValue, falseValue); }) + .def("create_inline_asm", + [](TritonOpBuilder &self, const std::string &inlineAsm, + const std::string &constraints, + const std::vector &values, mlir::Type &type, + bool isPure, int pack) -> mlir::Value { + return self.create( + type, inlineAsm, constraints, isPure, pack, values); + }) .def("create_print", [](TritonOpBuilder &self, const std::string &prefix, const std::vector &values) -> void { @@ -1652,6 +1746,11 @@ void init_triton_ir(py::module &&m) { .def( "add_sccp_pass", [](mlir::PassManager &self) { self.addPass(mlir::createSCCPPass()); }) + .def("add_plan_cta_pass", + [](mlir::PassManager &self, + mlir::triton::nvidia_gpu::ClusterInfo &clusterInfo) { + self.addPass(mlir::createTritonNvidiaGPUPlanCTAPass(&clusterInfo)); + }) .def("add_tritongpu_coalesce_pass", [](mlir::PassManager &self) { self.addPass(mlir::createTritonGPUCoalescePass()); @@ -1687,16 +1786,55 @@ void init_triton_ir(py::module &&m) { self.addPass(mlir::triton::createRewriteTensorPointerPass( computeCapability, isROCM)); }) + .def("add_tritongpu_ws_feasibility_checking_pass", + [](mlir::PassManager &self, int computeCapability) { + self.addPass(mlir::createTritonNvidiaGPUWSFeasibilityCheckingPass( + computeCapability)); + }) + .def("add_tritongpu_wsdecomposing_pass", + [](mlir::PassManager &self, int computeCapability) { + self.addPass(mlir::createTritonNvidiaGPUWSDecomposingPass( + computeCapability)); + }) + .def("add_tritongpu_wspipeline_pass", + [](mlir::PassManager &self, int numStages, int numWarps, + int computeCapability) { + self.addPass(mlir::createTritonNvidiaGPUWSPipelinePass( + numStages, numWarps, computeCapability)); + }) + .def("add_tritongpu_wsmutex_pass", + [](mlir::PassManager &self, int computeCapability) { + self.addPass( + mlir::createTritonNvidiaGPUWSMutexPass(computeCapability)); + }) + .def("add_tritongpu_wsmaterialization_pass", + [](mlir::PassManager &self, int computeCapability) { + self.addPass(mlir::createTritonNvidiaGPUWSMaterializationPass( + computeCapability)); + }) + .def("add_tritongpu_ws_fixup_missing_attrs_pass", + [](mlir::PassManager &self) { + self.addPass(mlir::createTritonNvidiaGPUWSFixupMissingAttrs()); + }) .def( "add_convert_triton_to_tritongpu_pass", - [](mlir::PassManager &self, int numWarps, int threadsPerWarp) { + [](mlir::PassManager &self, int numWarps, int threadsPerWarp, + int numCTAs, int computeCapability) { self.addPass(mlir::triton::createConvertTritonToTritonGPUPass( - numWarps, threadsPerWarp)); + numWarps, threadsPerWarp, numCTAs, computeCapability)); }, - py::arg("numWarps") = 4, py::arg("threadsPerWarp") = 32) + py::arg("numWarps") = 4, py::arg("threadsPerWarp") = 32, + py::arg("numCTAs") = 1, py::arg("computeCapability") = 80) .def("add_tritongpu_pipeline_pass", - [](mlir::PassManager &self, int numStages) { - self.addPass(mlir::createTritonGPUPipelinePass(numStages)); + [](mlir::PassManager &self, int numStages, int numWarps, int numCTAs, + int computeCapability) { + self.addPass(mlir::createTritonGPUPipelinePass( + numStages, numWarps, numCTAs, computeCapability)); + }) + .def("add_tritongpu_materialize_load_store_pass", + [](mlir::PassManager &self, int numWarps, int computeCapability) { + self.addPass(mlir::createTritonNvidiaGPUMaterializeLoadStorePass( + numWarps, computeCapability)); }) .def("add_tritongpu_stream_pipeline_pass", [](mlir::PassManager &self) { @@ -1723,26 +1861,96 @@ void init_triton_ir(py::module &&m) { [](mlir::PassManager &self) { self.addPass(mlir::createTritonGPUReorderInstructionsPass()); }) + .def("add_tritongpu_rewrite_tensor_pointer_pass", + [](mlir::PassManager &self, int computeCapability) { + self.addPass(mlir::createTritonGPURewriteTensorPointerPass( + computeCapability)); + }) .def("add_tritongpu_decompose_conversions_pass", [](mlir::PassManager &self) { self.addPass(mlir::createTritonGPUDecomposeConversionsPass()); }) + .def("add_tritongpu_fence_insertion_pass", + [](mlir::PassManager &self) { + self.addPass(mlir::createTritonNvidiaGPUFenceInsertionPass()); + }) .def("add_triton_gpu_to_llvm", [](mlir::PassManager &self) { self.addPass(mlir::triton::createConvertTritonGPUToLLVMPass()); }) + .def("add_nv_gpu_to_llvm", + [](mlir::PassManager &self) { + self.addPass(mlir::triton::createConvertNVGPUToLLVMPass()); + }) .def("add_scf_to_cfg", [](mlir::PassManager &self) { self.addPass(mlir::createConvertSCFToCFPass()); }); + + m.def("is_ws_supported", [](mlir::ModuleOp &mod) -> bool { + return mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect::getWSSupportedAttr( + mod); + }); +} + +void init_triton_env_vars(py::module &m) { + m.def("get_env_vars", []() -> std::map { + std::map envVars; + for (const auto &envVar : triton::ENV_VARS) { + envVars[envVar] = triton::tools::getBoolEnv(envVar); + } + return envVars; + }); } void init_triton_translation(py::module &m) { using ret = py::return_value_policy; + py::class_(m, "ClusterInfo") + .def(py::init<>()) + .def_readwrite("clusterDimX", + &mlir::triton::nvidia_gpu::ClusterInfo::clusterDimX) + .def_readwrite("clusterDimY", + &mlir::triton::nvidia_gpu::ClusterInfo::clusterDimY) + .def_readwrite("clusterDimZ", + &mlir::triton::nvidia_gpu::ClusterInfo::clusterDimZ) + .def("__repr__", [](mlir::triton::nvidia_gpu::ClusterInfo &self) { + std::ostringstream oss; + oss << "(" << self.clusterDimX << ", " << self.clusterDimY << ", " + << self.clusterDimZ << ")"; + return oss.str(); + }); + + py::class_(m, "TMAInfo") + .def(py::init<>()) + .def_readwrite("tensorDataType", + &mlir::triton::gpu::TMAInfo::tensorDataType) + .def_readwrite("tensorRank", &mlir::triton::gpu::TMAInfo::tensorRank) + .def_readwrite("globalAddressArgIdx", + &mlir::triton::gpu::TMAInfo::globalAddressArgIdx) + .def_readwrite("globalStridesArgIdx", + &mlir::triton::gpu::TMAInfo::globalStridesArgIdx) + .def_readwrite("globalDimsArgIdx", + &mlir::triton::gpu::TMAInfo::globalDimsArgIdx) + .def_readwrite("boxDims", &mlir::triton::gpu::TMAInfo::boxDims) + .def_readwrite("elementStrides", + &mlir::triton::gpu::TMAInfo::elementStrides) + .def_readwrite("interleave", &mlir::triton::gpu::TMAInfo::interleave) + .def_readwrite("swizzle", &mlir::triton::gpu::TMAInfo::swizzle) + .def_readwrite("l2Promotion", &mlir::triton::gpu::TMAInfo::l2Promotion) + .def_readwrite("oobFill", &mlir::triton::gpu::TMAInfo::oobFill) + .def_readwrite("TMADescArgIdx", + &mlir::triton::gpu::TMAInfo::TMADescArgIdx); + py::bind_vector>(m, "TMAInfos"); + m.def("get_shared_memory_size", [](mlir::ModuleOp mod) { auto shared = mod->getAttrOfType("triton_gpu.shared"); return shared.getInt(); }); + m.def("get_num_warps", [](mlir::ModuleOp mod) { + auto shared = mod->getAttrOfType("triton_gpu.num-warps"); + assert(shared); + return shared.getInt(); + }); m.def( "get_arch_info", @@ -1760,11 +1968,13 @@ void init_triton_translation(py::module &m) { m.def( "translate_triton_gpu_to_llvmir", - [](mlir::ModuleOp op, int computeCapability, bool isROCM) { + [](mlir::ModuleOp op, int computeCapability, + mlir::triton::gpu::TMAMetadataTy &tmaInfos, + mlir::triton::Target target, int wavesPerEU) { py::gil_scoped_release allow_threads; llvm::LLVMContext llvmContext; auto llvmModule = ::mlir::triton::translateTritonGPUToLLVMIR( - &llvmContext, op, computeCapability, isROCM); + &llvmContext, op, computeCapability, tmaInfos, target, wavesPerEU); if (!llvmModule) llvm::report_fatal_error("Failed to translate TritonGPU to LLVM IR."); @@ -1890,6 +2100,7 @@ void init_triton_translation(py::module &m) { void init_triton(py::module &m) { py::module subm = m.def_submodule("triton"); + init_triton_env_vars(subm); // init_triton_codegen(subm.def_submodule("code_gen")); init_triton_runtime(subm.def_submodule("runtime")); init_triton_ir(subm.def_submodule("ir")); diff --git a/python/test/kernel_comparison/kernels.yml b/python/test/kernel_comparison/kernels.yml index 4cc732b63803..d557e6c6691b 100644 --- a/python/test/kernel_comparison/kernels.yml +++ b/python/test/kernel_comparison/kernels.yml @@ -1,31 +1,33 @@ name_and_extension: - - name: _kernel_0d1d2d34567c89c1011c + - name: _kernel_0d1d2d3de4de5de6c7de8de9c10de11c extension: ptx - - name: _kernel_0d1d2d3d4d5d6d7c8d9c10d11c + - name: _kernel_0d1d2d3de4de5de6de7c8de9c10de11c extension: ptx - - name: _kernel_0d1d2d3d4d5d6d7c8c9d10d11c + - name: _kernel_0d1d2d345de6c789c1011c extension: ptx - name: _kernel_0d1d2d3456c789c1011c extension: ptx - - name: _kernel_0d1d2d345d6d7c8c9d1011c + - name: _kernel_0d1d2d3de4de5de6c7de8c9de10de11c extension: ptx - name: _kernel_0d1d2d34567c8c91011c extension: ptx - name: _kernel_0d1d2d3456c78c91011c extension: ptx - - name: _kernel_0d1d2d345d6c78c9d1011c + - name: _kernel_0d1d2d3de4de5de6de7c8c9de10de11c + extension: ptx + - name: _kernel_0d1d2d34567c89c1011c extension: ptx - - name: _kernel_0d1d2d345d6c789c1011c + - name: _kernel_0d1d2d345de6de7c89c1011c extension: ptx - - name: _kernel_0d1d2d3d4d5d6c7d8d9c10d11c + - name: _kernel_0d1d2d345de6de7c8c9de1011c extension: ptx - - name: _kernel_0d1d2d3d4d5d6c7d8c9d10d11c + - name: kernel_0d1d2de extension: ptx - - name: _kernel_0d1d2d345d6d7c89c1011c + - name: _kernel_0d1d2d345de6c78c9de1011c extension: ptx - - name: _bwd_kernel_0d1d2d34d5d6d7d8d9d10d11d12d13d14d15d16c17d18d19d20c21d22d23d24c2526d27d + - name: _bwd_kernel_0d1d2d34d5d6d7d8d9d10d11de12de13de14de15c16de17de18de19c20de21de22de23c2425de26de extension: ptx - - name: _fwd_kernel_0d1d2d34d5d6d7d8d9d10c11d12d13d14c15d16d17d18c19d20d21d22c2324d25d + - name: _fwd_kernel_0d1d2d34d5d6de7de8de9c10de11de12de13c14de15de16de17c18de19de20de21c2223de24de extension: ptx - - name: _bwd_preprocess_0d1d2d3d4d + - name: _bwd_preprocess_0d1d2d extension: ptx diff --git a/python/test/regression/test_performance.py b/python/test/regression/test_performance.py index 2f6005f79da0..b22fea3e53a9 100644 --- a/python/test/regression/test_performance.py +++ b/python/test/regression/test_performance.py @@ -56,7 +56,7 @@ def nvsmi(attrs): (4096, 64, 4096): {'float16': 0.179, 'float32': 0.214, 'int8': 0.102}, (8192, 64, 8192): {'float16': 0.278, 'float32': 0.000, 'int8': 0.177}, # test EVEN_K==False - (8192, 8192, 8176): {'float16': 0.786, 'float32': 0.696, 'int8': 0.51}, + (8192, 8192, 8176): {'float16': 0.786, 'float32': 0.743, 'int8': 0.51}, } } @@ -64,7 +64,7 @@ def nvsmi(attrs): @pytest.mark.parametrize('M, N, K, dtype_str', [(M, N, K, dtype_str) for M, N, K in matmul_data[DEVICE_NAME].keys() - for dtype_str in ['float16', 'float32']]) + for dtype_str in ['float16']]) def test_matmul(M, N, K, dtype_str): stream = torch.cuda.Stream() torch.cuda.set_stream(stream) @@ -225,3 +225,59 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, seq_par, causal, mode, dtype_str): ref_gpu_util = flash_attention_data[DEVICE_NAME][(Z, H, N_CTX, D_HEAD, seq_par, causal, mode, dtype_str)] print_perf(ms, cur_gpu_util, ref_gpu_util) triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.02, rtol=0.01) + + +####################### +# Reduction +####################### + + +@triton.jit +def _sum(x_ptr, y_ptr, output_ptr, n_elements, + BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + # run in a loop to only to make it compute bound. + for i in range(100): + x = tl.sum(x, axis=0) + y + + tl.store(output_ptr + offsets, x, mask=mask) + + +reduction_data = { + 'a100': { + 1024 * 16384: {'float16': 0.016, 'float32': 0.031, 'int16': 0.015, 'int32': 0.031}, + 1024 * 65536: {'float16': 0.016, 'float32': 0.032, 'int16': 0.015, 'int32': 0.032}, + } +} + + +@pytest.mark.parametrize('N', reduction_data[DEVICE_NAME].keys()) +@pytest.mark.parametrize("dtype_str", ['float16', 'float32', 'int16', 'int32']) +def test_reductions(N, dtype_str): + stream = torch.cuda.Stream() + torch.cuda.set_stream(stream) + torch.manual_seed(0) + dtype = {'float16': torch.float16, 'float32': torch.float32, 'int16': torch.int16, 'int32': torch.int32}[dtype_str] + ref_gpu_util = reduction_data[DEVICE_NAME][N][dtype_str] + cur_sm_clock = nvsmi(['clocks.current.sm'])[0] + max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3) + z = torch.empty((N, ), dtype=dtype, device='cuda') + if dtype == torch.float16 or dtype == torch.float32: + x = torch.randn_like(z) + y = torch.randn_like(z) + else: + info = torch.iinfo(dtype) + x = torch.randint(info.min, info.max, (N,), dtype=dtype, device='cuda') + y = torch.randint(info.min, info.max, (N,), dtype=dtype, device='cuda') + grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), ) + fn = lambda: _sum[grid](x, y, z, N, BLOCK_SIZE=1024) + ms = triton.testing.do_bench_cudagraph(fn) + cur_gpu_perf = 100. * 2. * N / ms * 1e-9 + cur_gpu_util = cur_gpu_perf / max_gpu_perf + print_perf(ms, cur_gpu_util, ref_gpu_util) + triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.02, rtol=0.01) diff --git a/python/test/tools/compare_files.py b/python/test/tools/compare_files.py index fe3688e91e04..1c8de084dcf9 100644 --- a/python/test/tools/compare_files.py +++ b/python/test/tools/compare_files.py @@ -9,9 +9,8 @@ class ComparisonResult: - def __init__(self, name: str, extension: str, numComparisons: int, diffs: List[str] = None, errors: List[str] = None): + def __init__(self, name: str, numComparisons: int, diffs: List[str] = None, errors: List[str] = None): self.name = name - self.extension = extension self.numComparisons = numComparisons self.diffs = [] if diffs is None else diffs self.errors = [] if errors is None else errors @@ -20,7 +19,7 @@ def isSuccess(self) -> bool: return len(self.diffs) == 0 and len(self.errors) == 0 def __str__(self) -> str: - return f"name={self.name}, extension={self.extension}, numComparisons={self.numComparisons}, success={self.isSuccess()}" + return f"name={self.name}, numComparisons={self.numComparisons}, success={self.isSuccess()}" def listFilesWithExtension(path: str, extension: str) -> List[str]: @@ -143,9 +142,9 @@ def doFilesMatch(path1: str, path2: str) -> bool: return True -def compareMatchingFiles(name: str, extension: str, nameToHashes1: Dict[str, List[str]], nameToHashes2: Dict[str, List[str]], args) -> ComparisonResult: +def compareMatchingFiles(name: str, nameToHashes1: Dict[str, List[str]], nameToHashes2: Dict[str, List[str]], args) -> ComparisonResult: """ - Compare files with the given name/extension in all hashes in both paths + Compare files with the given name in all hashes in both paths Return the first mismatching files as a tuple (file1, file2), otherwise, return an empty tuple """ hashes1 = nameToHashes1.get(name, []) @@ -164,14 +163,14 @@ def compareMatchingFiles(name: str, extension: str, nameToHashes1: Dict[str, Lis if not doFilesMatch(path1, path2): continue numComparisons += 1 - extFile1 = listFilesWithExtension(path1, extension)[0] - extFile2 = listFilesWithExtension(path2, extension)[0] + extFile1 = listFilesWithExtension(path1, "ptx")[0] + extFile2 = listFilesWithExtension(path2, "ptx")[0] diff = diffFiles(extFile1, extFile2) if len(diff) > 0: diffs.append(diffFiles(extFile2, extFile1)) if numComparisons == 0: errors.append(f"Did not find any matching files for {name}") - return ComparisonResult(name=name, extension=extension, numComparisons=numComparisons, diffs=diffs, errors=errors) + return ComparisonResult(name=name, numComparisons=numComparisons, diffs=diffs, errors=errors) def dumpResults(results: List[ComparisonResult], fileName: str): @@ -203,20 +202,15 @@ def main(args) -> bool: nameToHashes1 = getNameToHashesDict(args.path1) nameToHashes2 = getNameToHashesDict(args.path2) - yamlFilePath = args.kernels - if not os.path.exists(yamlFilePath): - print(f"Path {yamlFilePath} does not exist!") - sys.exit(2) - nameAndExtension = loadYamlFile(yamlFilePath)["name_and_extension"] + # Get all kernels that need to be checked + kernelNames = set(nameToHashes1.keys()).union(set(nameToHashes2.keys())) results = [] # iterate over the kernels that need to be checked - for d in nameAndExtension: - name = d["name"] # kernel name - extension = d["extension"] # extension of the file to be compared (e.g. ptx) + for name in kernelNames: # Compare all hashes on path 1 with all hashes on path 2 # result is either the mismatching (file1, file2) with "extension" or empty tuple if no mismatch - result = compareMatchingFiles(name, extension, nameToHashes1, nameToHashes2, args) + result = compareMatchingFiles(name, nameToHashes1, nameToHashes2, args) print(result) # Otherwise, add it to the mismatches results.append(result) @@ -250,12 +244,5 @@ def main(args) -> bool: required=True, help=("Path to second cache directory"), ) - parser.add_argument( - "--kernels", - type=str, - default=None, - required=True, - help=("Path to kernels yaml file"), - ) args = parser.parse_args() main(args) diff --git a/python/test/unit/hopper/__init__.py b/python/test/unit/hopper/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/test/unit/hopper/test_flashattention.py b/python/test/unit/hopper/test_flashattention.py new file mode 100644 index 000000000000..e46e1c1f2c59 --- /dev/null +++ b/python/test/unit/hopper/test_flashattention.py @@ -0,0 +1,480 @@ +# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files +# (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, +# publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +""" +Fused Attention +=============== +This is a Triton implementation of the Flash Attention algorithm +(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf) +""" + +# import numpy as np +import pytest +import torch + +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel( + Q, K, V, sm_scale, + L, M, + Out, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vk, stride_vn, + stride_oz, stride_oh, stride_om, stride_on, + Z, H, N_CTX, D0, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + + # TODO: may replace with TMA store without range offset + # initialize offsets for store + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l + m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_prev = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + stride_qh_2d = stride_qh // stride_qm // stride_qk + + q_tile_ptr = tl.make_block_ptr(base=Q, + shape=(D0, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=( + off_hz * stride_qh_2d + start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0)) + k_tile_ptr = tl.make_block_ptr(base=K, + shape=(D0, BLOCK_DMODEL), + strides=(stride_kn, stride_kk), + offsets=(off_hz * stride_qh_2d, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0)) + v_tile_ptr = tl.make_block_ptr(base=V, + shape=(D0, BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(off_hz * stride_qh_2d, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0)) + out_tile_ptr = tl.make_block_ptr(base=Out, + shape=(D0, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(off_hz * stride_qh_2d + start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0)) + # load q: it will stay in SRAM throughout + q = tl.load(q_tile_ptr) + + # loop over k, v and update accumulators + for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N): + # -- compute qk ---- + k = tl.load(k_tile_ptr, boundary_check=(0, 1)) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, tl.trans(k)) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= ( + start_n + offs_n[None, :]), qk, float("-inf")) + # compute new m + m_curr = tl.maximum(tl.max(qk, 1), m_prev) + # correct old l + l_prev *= tl.exp(m_prev - m_curr) + # attention weights + p = tl.exp(qk - m_curr[:, None]) + l_curr = tl.sum(p, 1) + l_prev + # rescale operands of matmuls + l_rcp = 1. / l_curr + p *= l_rcp[:, None] + acc *= (l_prev * l_rcp)[:, None] + # update acc + p = p.to(tl.float16) + v = tl.load(v_tile_ptr, boundary_check=(0, 1)) + acc += tl.dot(p, v) + # update m_i and l_i + l_prev = l_curr + m_prev = m_curr + # update pointers + k_tile_ptr = tl.advance(k_tile_ptr, [BLOCK_N, 0]) + v_tile_ptr = tl.advance(v_tile_ptr, [BLOCK_N, 0]) + # rematerialize offsets to save registers + start_m = tl.program_id(0) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # write back l and m + l_ptrs = L + off_hz * N_CTX + offs_m + m_ptrs = M + off_hz * N_CTX + offs_m + tl.store(l_ptrs, l_prev) + tl.store(m_ptrs, m_prev) + + acc = acc.to(tl.float16) + tl.store(out_tile_ptr, acc, boundary_check=(0, 1)) + + +@triton.jit +def _bwd_preprocess( + Out, DO, L, + NewDO, Delta, + BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr, +): + off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + off_n = tl.arange(0, D_HEAD) + # load + o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + denom = tl.load(L + off_m).to(tl.float32) + # compute + do = do / denom[:, None] + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do) + tl.store(Delta + off_m, delta) + + +@triton.jit +def _bwd_kernel( + Q, K, V, sm_scale, Out, DO, + DQ, DK, DV, + L, M, + D, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vk, stride_vn, + Z, H, N_CTX, D0, + num_block, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + off_hz = tl.program_id(0) + off_z = off_hz // H + off_h = off_hz % H + # init tile_ptr + stride_qz_2d = stride_qz // stride_qm // stride_qk + stride_qh_2d = stride_qh // stride_qm // stride_qk + + q_tile_ptr = tl.make_block_ptr(base=Q, + shape=(D0, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=( + off_z * stride_qz_2d + off_h * stride_qh_2d, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0)) + k_tile_ptr = tl.make_block_ptr(base=K, + shape=(D0, BLOCK_DMODEL), + strides=(stride_kn, stride_kk), + offsets=( + off_z * stride_qz_2d + off_h * stride_qh_2d, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0)) + v_tile_ptr = tl.make_block_ptr(base=V, + shape=(D0, BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=( + off_z * stride_qz_2d + off_h * stride_qh_2d, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0)) + do_tile_ptr = tl.make_block_ptr(base=DO, + shape=(D0, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=( + off_z * stride_qz_2d + off_h * stride_qh_2d, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0)) + dq_tile_ptr = tl.make_block_ptr(base=DQ, + shape=(D0, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=( + off_z * stride_qz_2d + off_h * stride_qh_2d, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0)) + dk_tile_ptr = tl.make_block_ptr(base=DK, + shape=(D0, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=( + off_z * stride_qz_2d + off_h * stride_qh_2d, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0)) + dv_tile_ptr = tl.make_block_ptr(base=DV, + shape=(D0, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=( + off_z * stride_qz_2d + off_h * stride_qh_2d, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0)) + # offset pointers for batch/head + DQ += off_z * stride_qz + off_h * stride_qh + for start_n in range(0, num_block): + lo = start_n * BLOCK_M + # initialize row/col offsets + offs_qm = lo + tl.arange(0, BLOCK_M) + offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M) + offs_m = tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_DMODEL) + # initialize pointers to value-like data + dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + # pointer to row-wise quantities in value-like data + D_ptrs = D + off_hz * N_CTX + m_ptrs = M + off_hz * N_CTX + # initialize dv amd dk + dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # k and v stay in SRAM throughout + k = tl.load(k_tile_ptr, boundary_check=(0, 1)) + v = tl.load(v_tile_ptr, boundary_check=(0, 1)) + # loop over rows + for start_m in range(lo, num_block * BLOCK_M, BLOCK_M): + offs_m_curr = start_m + offs_m + # load q, k, v, do on-chip + q = tl.load(q_tile_ptr, boundary_check=(0, 1)) + # recompute p = softmax(qk, dim=-1).T + # NOTE: `do` is pre-divided by `l`; no normalization here + qk = tl.dot(q, tl.trans(k)) + qk = tl.where(offs_m_curr[:, None] >= ( + offs_n[None, :]), qk, float("-inf")) + m = tl.load(m_ptrs + offs_m_curr) + p = tl.exp(qk * sm_scale - m[:, None]) + # compute dv + do = tl.load(do_tile_ptr, boundary_check=(0, 1)) + dv += tl.dot(tl.trans(p.to(tl.float16)), do) + # compute dp = dot(v, do) + Di = tl.load(D_ptrs + offs_m_curr) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] + dp += tl.dot(do, tl.trans(v)) + # compute ds = p * (dp - delta[:, None]) + ds = p * dp * sm_scale + # compute dk = dot(ds.T, q) + dk += tl.dot(tl.trans(ds.to(tl.float16)), q) + # compute dq + dq = tl.load(dq_tile_ptr) + dq += tl.dot(ds.to(tl.float16), k) + tl.store(dq_tile_ptr, dq) + # increment pointers + dq_ptrs += BLOCK_M * stride_qm + q_tile_ptr = tl.advance(q_tile_ptr, [BLOCK_M, 0]) + do_tile_ptr = tl.advance(do_tile_ptr, [BLOCK_M, 0]) + dq_tile_ptr = tl.advance(dq_tile_ptr, [BLOCK_M, 0]) + q_tile_ptr = tl.advance(q_tile_ptr, [lo + (1 - num_block) * BLOCK_M, 0]) + do_tile_ptr = tl.advance(do_tile_ptr, [lo + (1 - num_block) * BLOCK_M, 0]) + dq_tile_ptr = tl.advance(dq_tile_ptr, [lo + (1 - num_block) * BLOCK_M, 0]) + # increment tile pointers + k_tile_ptr = tl.advance(k_tile_ptr, [BLOCK_M, 0]) + v_tile_ptr = tl.advance(v_tile_ptr, [BLOCK_M, 0]) + # write-back + tl.store(dv_tile_ptr, dv.to(tl.float16), boundary_check=(0, 1)) + tl.store(dk_tile_ptr, dk.to(tl.float16), boundary_check=(0, 1)) + dv_tile_ptr = tl.advance(dv_tile_ptr, [BLOCK_M, 0]) + dk_tile_ptr = tl.advance(dk_tile_ptr, [BLOCK_M, 0]) + + +empty = torch.empty(128, device="cuda") + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, sm_scale): + BLOCK = 128 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + o = torch.empty_like(q) + grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1) + L = torch.empty( + (q.shape[0] * q.shape[1], q.shape[2]), + device=q.device, + dtype=torch.float32) + m = torch.empty( + (q.shape[0] * q.shape[1], q.shape[2]), + device=q.device, + dtype=torch.float32) + num_warps = 4 if Lk <= 64 else 8 + D0 = q.shape[0] * q.shape[1] * q.shape[2] + _fwd_kernel[grid]( + q, k, v, sm_scale, + L, m, + o, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + o.stride(0), o.stride(1), o.stride(2), o.stride(3), + q.shape[0], q.shape[1], q.shape[2], D0, + BLOCK_M=BLOCK, BLOCK_N=BLOCK, + BLOCK_DMODEL=Lk, num_warps=num_warps, + num_stages=2, + ) + + ctx.save_for_backward(q, k, v, o, L, m) + ctx.grid = grid + ctx.sm_scale = sm_scale + ctx.BLOCK_DMODEL = Lk + return o + + @staticmethod + def backward(ctx, do): + BLOCK = 128 + q, k, v, o, l, m = ctx.saved_tensors + do = do.contiguous() + dq = torch.zeros_like(q, dtype=torch.float32) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + do_scaled = torch.empty_like(do) + delta = torch.empty_like(l) + D0 = q.shape[0] * q.shape[1] * q.shape[2] + _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )]( + o, do, l, + do_scaled, delta, + BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL, + ) + _bwd_kernel[(ctx.grid[1],)]( + q, k, v, ctx.sm_scale, + o, do_scaled, + dq, dk, dv, + l, m, + delta, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + q.shape[0], q.shape[1], q.shape[2], D0, + ctx.grid[0], + BLOCK_M=BLOCK, BLOCK_N=BLOCK, + BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8, + num_stages=1, + ) + return dq, dk, dv, None + + +attention = _attention.apply + + +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 128, 64), + # (4, 48, 256, 64), + # (4, 48, 512, 64), + # (4, 48, 1024, 64), + # (4, 48, 2048, 64), + # (4, 48, 4096, 64), + # (4, 48, 8192, 64), out of memory + ]) +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="requires arch 9+") +def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16): + torch.manual_seed(20) + q = torch.empty( + (Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_( + mean=0.1, std=0.2).requires_grad_() + k = torch.empty( + (Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_( + mean=0.4, std=0.2).requires_grad_() + v = torch.empty( + (Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_( + mean=0.3, std=0.2).requires_grad_() + sm_scale = 0.2 + dout = torch.randn_like(q) + # reference implementation + M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + for z in range(Z): + for h in range(H): + p[:, :, M == 0] = float("-inf") + p = torch.softmax(p.float(), dim=-1).half() + # p = torch.exp(p) + ref_out = torch.matmul(p, v) + ref_out.backward(dout) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + # triton implementation + tri_out = attention(q, k, v, sm_scale) + # print(ref_out) + # print(tri_out) + tri_out.backward(dout) + tri_dv, v.grad = v.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dq, q.grad = q.grad.clone(), None + # compare + torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0) + torch.testing.assert_close(ref_dq, tri_dq, atol=1e-2, rtol=0) + torch.testing.assert_close(ref_dv, tri_dv, atol=1e-2, rtol=0) + torch.testing.assert_close(ref_dk, tri_dk, atol=1e-2, rtol=0) + + +try: + from flash_attn.flash_attn_interface import flash_attn_func + HAS_FLASH = True +except BaseException: + HAS_FLASH = False + +BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 +# vary seq length for fixed head and batch=4 +configs = [triton.testing.Benchmark( + x_names=['N_CTX'], + x_vals=[2**i for i in range(10, 14)], + line_arg='provider', + line_vals=['triton'] + (['flash'] if HAS_FLASH else []), + line_names=['Triton'] + (['Flash'] if HAS_FLASH else []), + styles=[('red', '-'), ('blue', '-')], + ylabel='ms', + plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}', + args={ + 'H': N_HEADS, + 'BATCH': BATCH, + 'D_HEAD': D_HEAD, + 'dtype': torch.float16, + 'mode': mode} +) for mode in ['fwd', 'bwd']] + + +@triton.testing.perf_report(configs) +def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"): + assert mode in ['fwd', 'bwd'] + warmup = 25 + rep = 100 + if provider == "triton": + q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + sm_scale = 1.3 + fn = lambda: attention(q, k, v, sm_scale) + if mode == 'bwd': + o = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + return ms + if provider == "flash": + lengths = torch.full((BATCH,), fill_value=N_CTX, device=device) + cu_seqlens = torch.zeros( + (BATCH + 1,), device=device, dtype=torch.int32) + cu_seqlens[1:] = lengths.cumsum(0) + qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True) + fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True) + if mode == 'bwd': + o = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + return ms + + +# only works on post-Ampere GPUs right now +# bench_flash_attention.run(save_path='.', print_data=True) diff --git a/python/test/unit/hopper/test_gemm.py b/python/test/unit/hopper/test_gemm.py new file mode 100644 index 000000000000..af236d0de3d8 --- /dev/null +++ b/python/test/unit/hopper/test_gemm.py @@ -0,0 +1,450 @@ +# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files +# (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, +# publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import itertools +import os +import re + +import pytest +import torch +from torch.testing import assert_close + +import triton +import triton.language as tl + + +@triton.jit +def matmul_no_scf_kernel( + a_ptr, b_ptr, c_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + FLOAT16_OUTPUT: tl.constexpr, USE_TMA_EPILOGUE: tl.constexpr +): + a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), + offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) + b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), + offsets=(0, 0), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1)) + a = tl.load(a_block_ptr) + b = tl.load(b_block_ptr) + + c = tl.dot(a, b) + + if FLOAT16_OUTPUT: + c = c.to(tl.float16) + + if USE_TMA_EPILOGUE: + c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), + offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0)) + tl.store(c_block_ptr, c) + else: + offs_m = tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn + tl.store(c_ptrs, c) + + +@pytest.mark.parametrize('M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE,ENABLE_WS', + itertools.chain( + *[ + [ + # numCTAs = 1, no TMA multicast: + [64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS], + [64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS], + [64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS], + [64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], + [64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], + [64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], + [128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS], + [128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], + # static mask, cluster 4x1 + [256, 64, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS], + [256, 64, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], + # dynamic mask, cluster 2x2 + [128, 128, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS], + [128, 128, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], + # small M, N + [16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], + [16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], + [32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], + [32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], + ] for USE_TMA_EPILOGUE in [True, False] + for ENABLE_WS in [False, True] + ])) +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") +def test_gemm_no_scf(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE, ENABLE_WS): + if (TRANS_A): + a = torch.randn((K, M), device='cuda', dtype=torch.float16).T + else: + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + if (TRANS_B): + b = torch.randn((N, K), device='cuda', dtype=torch.float16).T + else: + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + if OUTPUT_TYPE == "float16": + c = torch.empty((M, N), device=a.device, dtype=torch.float16) + else: + c = torch.empty((M, N), device=a.device, dtype=torch.float32) + + matmul_no_scf_kernel[(1, 1)](a_ptr=a, b_ptr=b, c_ptr=c, + M=M, N=N, K=K, + stride_am=a.stride(0), stride_ak=a.stride(1), + stride_bk=b.stride(0), stride_bn=b.stride(1), + stride_cm=c.stride(0), stride_cn=c.stride(1), + BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, + num_warps=NUM_WARPS, + num_ctas=NUM_CTAS, + FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"), + USE_TMA_EPILOGUE=USE_TMA_EPILOGUE, + enable_warp_specialization=ENABLE_WS) + a_f32 = a.to(torch.float32) + b_f32 = b.to(torch.float32) + golden = torch.matmul(a_f32, b_f32) + torch.set_printoptions(profile="full") + assert_close( + c, + golden, + rtol=1e-2, + atol=1e-3, + check_dtype=False) + + +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_wm, stride_wn, + stride_zm, stride_zn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, + out_dtype: tl.constexpr, USE_TMA_STORE: tl.constexpr, + ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, + DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr, + A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr, + B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr, + W_ORDER_0: tl.constexpr, W_ORDER_1: tl.constexpr, + Z_ORDER_0: tl.constexpr, Z_ORDER_1: tl.constexpr +): + pid = tl.program_id(axis=0) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + block_offset_m = pid_m * BLOCK_M + block_offset_n = pid_n * BLOCK_N + + a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), + offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(A_ORDER_0, A_ORDER_1)) + b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), + offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(B_ORDER_0, B_ORDER_1)) + # for chain-dot, BLOCK_N must always be equal to N, and each program loads the whole W matrix + w_tile_ptr = tl.make_block_ptr(base=w_ptr, shape=(N, N), strides=(stride_wm, stride_wn), + offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_N), order=(W_ORDER_0, W_ORDER_1)) + z = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + offs_m = block_offset_m + tl.arange(0, BLOCK_M) + offs_n = block_offset_n + tl.arange(0, BLOCK_N) + z_ptrs = z_ptr + offs_m[:, None] * stride_zm + offs_n[None, :] * stride_zn + bias_ptrs = bias_ptr + offs_m[:, None] * stride_zm + offs_n[None, :] * stride_zn + mask = (offs_m < M)[:, None] & (offs_n < N)[None, :] + + for k in range(0, K, BLOCK_K): + a = tl.load(a_tile_ptr, boundary_check=(0, 1)) + b = tl.load(b_tile_ptr, boundary_check=(0, 1)) + z += tl.dot(a, b) + a_tile_ptr = tl.advance(a_tile_ptr, [0, BLOCK_K]) + b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_K, 0]) + + z = z.to(out_dtype) + + if ADD_MATRIX: + z += tl.load(bias_ptrs, mask=mask) + if ADD_ROWS: + ZRs = bias_ptr + offs_m * stride_zm + z += tl.load(ZRs)[:, None] + if ADD_COLS: + ZCs = bias_ptr + offs_n * stride_zn + z += tl.load(ZCs)[None, :] + if DO_SOFTMAX: + max = tl.max(z, 1) + z = z - max[:, None] + num = tl.exp(z.to(tl.float32)).to(max.dtype) + den = tl.sum(num, 1) + z = num / den[:, None] + if CHAIN_DOT: + w = tl.load(w_tile_ptr) + z = tl.dot(z.to(w.dtype), w) + z = z.to(out_dtype) + + if USE_TMA_STORE: + z_block_ptr = tl.make_block_ptr(base=z_ptr, shape=(M, N), strides=(stride_zm, stride_zn), + offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_M, BLOCK_N), order=(Z_ORDER_0, Z_ORDER_1)) + tl.store(z_block_ptr, z, boundary_check=(0, 1)) + else: + tl.store(z_ptrs, z, mask=mask) + + +@pytest.mark.parametrize('BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES,ENABLE_WS', + [ + # corner shapes + (128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3, enable_ws) + for shape_w_c in [ + [4096, 1, 1024, False, False, True], + [2048, 204, 1000, True, False, True], + [4096, 1, 1024, False, False, False], + [2048, 204, 1000, True, False, False], + ] + for out_dtype in ['float16', 'float32'] + for use_tma_store in [False, True] + for enable_ws in [False, True] + ] + [ + # softmax epilogue + (*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages, enable_ws) + for shape_w_c in [ + [64, 64, 16, 4, 1, 64, 64, 64], + [128, 128, 64, 4, 1, None, None, None], + [16, 16, 64, 4, 1, 16, 16, 64], + [64, 64, 32, 8, 1, 64, 64, 64], + [128, 128, 64, 4, 1, 128, 128, 128], + ] + for epilogue in ['softmax'] + for out_dtype in ['float16', 'float32'] + for use_tma_store in [False, True] + for trans_a in [False,] + for trans_b in [True,] + for trans_output in [False,] + for num_stages in [3] + for enable_ws in [False, True] + ] + [ + # loop over epilogues besides of softmax + (*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages, enable_ws) + for shape_w_c in [ + [64, 64, 16, 4, 1, 128, 128, 64], + *[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], + # for chain-dot + [128, 128, 64, 4, 1, None, None, None], + [64, 64, 16, 4, 1, None, None, None], + # small BLOCK_M and BLOCK_K + [16, 16, 64, 4, 1, 128, 128, 64], + *[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], + # repeat + [64, 64, 32, 8, 1, 128, 256, 64], + [64, 64, 16, 8, 2, 128, 128, 64], + # irregular shape + [128, 128, 64, 4, 1, 500, 200, 128], + [128, 128, 64, 4, 2, 513, 193, 192], + ] + for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'] + for out_dtype in ['float16', 'float32'] + for use_tma_store in [False, True] + for trans_a in [False,] + for trans_b in [True,] + for trans_output in [False,] + for num_stages in [3] + for enable_ws in [False, True] + if not (epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) + ] + [ + # loop over tile shapes and transpose combinations + (*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages, enable_ws) + for shape_w_c in [ + [64, 64, 32, 4, 1, 128, 256, 64], + [128, 128, 16, 4, 4, 512, 256, 64], + [128, 256, 32, 4, 8, 256, 256, 192], + [512, 256, 32, 4, 8, 1024, 256, 192], + # BLOCK_K >= 128 + [64, 128, 128, 4, 1, 512, 256, 256], + [128, 128, 128, 4, 1, 256, 256, 192], + [128, 128, 128, 4, 2, 256, 256, 192], + # small BLOCK_M and BLOCK_K + [16, 32, 32, 4, 1, 128, 256, 64], + [32, 32, 16, 4, 1, 256, 256, 192], + [16, 32, 64, 4, 4, 512, 256, 64], + ] + for out_dtype in ['float32',] + for use_tma_store in [False,] + for trans_a in [False, True] + for trans_b in [False, True] + for trans_output in [False, True] + for num_stages in [3] + for enable_ws in [False, True] + ] + [ + # loop over instr shapes & pipeline stages + (64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages, enable_ws) + for n in [16, 32, 64, 128, 256] + for trans_output in [False,] + for out_dtype in ['float32',] + for use_tma_store in [False,] + for num_stages in [2, 4, 5, 7] + for enable_ws in [False, True] + ] + [ + # irregular shapes + (*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages, enable_ws) + for shape_w_c in [ + [128, 128, 64, 4, 1], + [256, 128, 64, 4, 2], + [128, 128, 128, 4, 2], + ] + for shape in [ + [512, 360, 1024], + [360, 4096, 512], + ] + for trans_output in [False,] + for out_dtype in ['float32',] + for use_tma_store in [False, True] + for num_stages in [3, 4] + for enable_ws in [False, True] + ]) +@pytest.mark.skipif(torch.cuda.get_device_capability() + [0] < 9, reason="Requires compute capability >= 9") +def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS): + if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ + '16-32-64-4-4-512-256-64-True-False', + '16-32-64-4-4-512-256-64-True-True', + '16-32-64-4-4-512-256-64-False-False', + '16-32-64-4-4-512-256-64-False-True', + ]: + pytest.skip('shapePerCTA[1] < 16 not supported') + + # with ENABLE_TMA=0 and ENABLE_MMA_V3=0 + if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ + '16-32-64-4-1-256-256-256-False', + '16-32-64-4-2-256-256-256-False', + '16-32-64-4-2-256-256-256-True', + '16-32-64-8-2-256-256-256-False', + '16-32-64-8-2-256-256-256-True', + ]: + pytest.skip('Known legacy issue, ldmatrix can only support x4') + + M = BLOCK_M if M is None else M + N = BLOCK_N if N is None else N + K = BLOCK_K if K is None else K + + if (TRANS_A): + a = torch.randn((K, M), device='cuda', dtype=torch.float16).T + a_order = [0, 1] + else: + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + a_order = [1, 0] + + if (TRANS_B): + b = torch.randn((N, K), device='cuda', dtype=torch.float16).T + b_order = [0, 1] + else: + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + b_order = [1, 0] + + if out_dtype == 'float16' and epilogue != 'softmax': + # TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will + # fail with the following error: 'llvm.fmul' op requires the same type + # for all operands and results + out_dtype = tl.float16 + torch_out_dtype = torch.float16 + else: + out_dtype = tl.float32 + torch_out_dtype = torch.float32 + + # avoid out of memory + if epilogue in ['add-matrix', 'add-rows', 'add-cols']: + if (TRANS_OUTPUT): + bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T + else: + bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) + else: + bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) + + # for chain-dot only + w = torch.randn((N, N), device='cuda', dtype=torch.float16).T + w_order = [0, 1] + + if (TRANS_OUTPUT): + z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T + z_order = [0, 1] + else: + z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) + z_order = [1, 0] + + # torch result + a_f32 = a.to(torch.float32) + b_f32 = b.to(torch.float32) + dot = torch.matmul(a_f32, b_f32) + + def process_epilogue(d, bias, w, epilogue): + if epilogue == 'add-matrix': + ref = d + bias + elif epilogue == 'add-rows': + ref = d + bias[:, 0][:, None] + elif epilogue == 'add-cols': + ref = d + bias[0, :][None, :] + elif epilogue == 'softmax': + num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) + denom = torch.sum(num, dim=-1, keepdims=True) + ref = num / denom + # ref = torch.softmax(d, 1) + elif epilogue == 'chain-dot': + ref = torch.matmul(d, w.to(torch.float32)) + else: + ref = d + return ref + golden = process_epilogue(dot, bias, w, epilogue) + + def grid(META): + return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),) + pgm = matmul_kernel[grid](a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, + M=M, N=N, K=K, + stride_am=a.stride(0), stride_ak=a.stride(1), + stride_bk=b.stride(0), stride_bn=b.stride(1), + stride_wm=w.stride(0), stride_wn=w.stride(1), + stride_zm=z.stride(0), stride_zn=z.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, + out_dtype=out_dtype, + USE_TMA_STORE=USE_TMA_STORE, + ADD_MATRIX=epilogue == 'add-matrix', + ADD_ROWS=epilogue == 'add-rows', + ADD_COLS=epilogue == 'add-cols', + DO_SOFTMAX=epilogue == 'softmax', + CHAIN_DOT=epilogue == 'chain-dot', + A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], + B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], + W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], + Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], + num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES, + enable_warp_specialization=ENABLE_WS) + + torch.set_printoptions(profile="full") + golden = torch.nn.functional.normalize(golden) + z = torch.nn.functional.normalize(z) + assert_close(z, golden, + rtol=1e-2, + atol=1e-3, + check_dtype=False) + + enable_mmav3 = os.environ.get('ENABLE_MMA_V3', 'not found').lower() + if enable_mmav3 in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: + ptx = pgm.asm['ptx'] + assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(BLOCK_N), ptx) diff --git a/python/test/unit/hopper/test_gemm_fusion.py b/python/test/unit/hopper/test_gemm_fusion.py new file mode 100644 index 000000000000..1fd53d5c4579 --- /dev/null +++ b/python/test/unit/hopper/test_gemm_fusion.py @@ -0,0 +1,166 @@ +# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files +# (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, +# publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import pytest +import torch + +import triton +import triton.language as tl + + +@triton.jit +def gemm_fusion_kernel(A, B, C, E, + M, N, K, + stride_am, stride_ak, stride_bn, stride_bk, stride_cn, stride_ck, stride_em, stride_ek, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr): + pid = tl.program_id(0) + + a_tile_ptr = tl.make_block_ptr(base=A, shape=(M, K), strides=(stride_am, stride_ak), offsets=(pid * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) + b_tile_ptr = tl.make_block_ptr(base=B, shape=(N, K), strides=(stride_bn, stride_bk), offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_K), order=(1, 0)) + c_tile_ptr = tl.make_block_ptr(base=C, shape=(N, K), strides=(stride_cn, stride_ck), offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_K), order=(1, 0)) + e_tile_ptr = tl.make_block_ptr(base=E, shape=(M, K), strides=(stride_em, stride_ek), offsets=(pid * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) + + acc_e = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32) + a = tl.load(a_tile_ptr) + for i in range(0, N, BLOCK_N): + b = tl.load(b_tile_ptr) + o_ab = tl.dot(a, tl.trans(b)) + c = tl.load(c_tile_ptr) + o_ab = o_ab.to(tl.float16) + acc_e += tl.dot(o_ab, c) + b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_N, 0]) + c_tile_ptr = tl.advance(c_tile_ptr, [BLOCK_N, 0]) + + acc_e = acc_e.to(tl.float16) + tl.store(e_tile_ptr, acc_e) + + +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="not passed on ampere") +def test_gemm_fusion(): + M, N, K = 4096, 4096, 64 + BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 64 + A = torch.empty( + (M, K), dtype=torch.float16, device='cuda').normal_( + mean=0.1, std=0.2) + B = torch.empty( + (N, K), dtype=torch.float16, device='cuda').normal_( + mean=0.1, std=0.2) + C = torch.empty( + (N, K), dtype=torch.float16, device='cuda').normal_( + mean=0.1, std=0.2) + E = torch.empty((M, K), dtype=torch.float16, device='cuda') + ref_out = torch.matmul(torch.matmul(A, B.T), C) + num_warps = 4 + grid = (triton.cdiv(M, BLOCK_M), 1) + gemm_fusion_kernel[grid](A, B, C, E, M, N, K, + A.stride(0), A.stride(1), B.stride(0), B.stride( + 1), C.stride(0), C.stride(1), E.stride(0), E.stride(1), + BLOCK_M, BLOCK_N, BLOCK_K, num_warps=num_warps) + + torch.testing.assert_close(ref_out, E, atol=1e-2, rtol=0) + + +@triton.jit +def batched_gemm_fusion( + Q, K, V, Out, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vk, stride_vn, + stride_oz, stride_oh, stride_om, stride_on, + Z, NH, N_CTX, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + q_tile_ptr = tl.make_block_ptr(base=Q, + shape=(Z, NH, N_CTX, BLOCK_DMODEL), + strides=(stride_qz, stride_qh, stride_qm, stride_qk), + offsets=(off_hz // NH, off_hz % NH, start_m, 0), + block_shape=(1, 1, BLOCK_M, BLOCK_DMODEL), + order=(3, 2, 1, 0)) + k_tile_ptr = tl.make_block_ptr(base=K, + shape=(Z, NH, N_CTX, BLOCK_DMODEL), + strides=(stride_kz, stride_kh, stride_kn, stride_kk), + offsets=(off_hz // NH, off_hz % NH, 0, 0), + block_shape=(1, 1, BLOCK_N, BLOCK_DMODEL), + order=(3, 2, 1, 0)) + v_tile_ptr = tl.make_block_ptr(base=V, + shape=(Z, NH, N_CTX, BLOCK_DMODEL), + strides=(stride_vz, stride_vh, stride_vk, stride_vn), + offsets=(off_hz // NH, off_hz % NH, 0, 0), + block_shape=(1, 1, BLOCK_N, BLOCK_DMODEL), + order=(3, 2, 1, 0)) + o_tile_ptr = tl.make_block_ptr(base=Out, + shape=(Z, NH, N_CTX, BLOCK_DMODEL), + strides=(stride_oz, stride_oh, stride_om, stride_on), + offsets=(off_hz // NH, off_hz % NH, start_m, 0), + block_shape=(1, 1, BLOCK_M, BLOCK_DMODEL), + order=(3, 2, 1, 0)) + + q = tl.load(q_tile_ptr, boundary_check=(0, 1, 2, 3)) + q = tl.view(q, (BLOCK_M, BLOCK_DMODEL)) + for i in range(0, N_CTX, BLOCK_N): + k = tl.load(k_tile_ptr, boundary_check=(0, 1, 2, 3)) + k = tl.view(k, (BLOCK_N, BLOCK_DMODEL)) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, tl.trans(k)) + + p = qk.to(tl.float16) + v = tl.load(v_tile_ptr, boundary_check=(0, 1, 2, 3)) + v = tl.view(v, (BLOCK_N, BLOCK_DMODEL)) + acc += tl.dot(p, v) + + k_tile_ptr = tl.advance(k_tile_ptr, [0, 0, BLOCK_N, 0]) + v_tile_ptr = tl.advance(v_tile_ptr, [0, 0, BLOCK_N, 0]) + + acc = tl.view(acc, (1, 1, BLOCK_M, BLOCK_DMODEL)) + acc = acc.to(tl.float16) + tl.store(o_tile_ptr, acc) + + +@pytest.mark.skip(reason="don't support 4d across stack, left for future") +def test_batched_gemm_fusion(): + Z = 4 + NH = 48 + H = 64 + N_CTX = 2048 + BLOCK_M, BLOCK_N, BLOCK_DMODEL = 128, 128, H + torch.manual_seed(20) + A = torch.empty((Z, NH, N_CTX, H), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2) + B = torch.empty((Z, NH, N_CTX, H), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2) + C = torch.empty((Z, NH, N_CTX, H), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2) + E = torch.empty_like(A) + BT = B.transpose(-1, -2) + ref_out = torch.matmul(torch.matmul(A, BT), C) + num_warps = 4 + grid = (triton.cdiv(N_CTX, BLOCK_M), B * NH) + batched_gemm_fusion[grid](A, B, C, E, + A.stride(0), A.stride(1), A.stride(2), A.stride(3), + B.stride(0), B.stride(1), B.stride(2), B.stride(3), + C.stride(0), C.stride(1), C.stride(2), C.stride(3), + E.stride(0), E.stride(1), E.stride(2), E.stride(3), + Z, NH, N_CTX, + BLOCK_M, BLOCK_DMODEL, BLOCK_N, num_warps=num_warps) + + torch.testing.assert_close(ref_out, E, atol=1e-2, rtol=0) diff --git a/python/test/unit/hopper/test_mixed_io.py b/python/test/unit/hopper/test_mixed_io.py new file mode 100644 index 000000000000..cecabbaa732f --- /dev/null +++ b/python/test/unit/hopper/test_mixed_io.py @@ -0,0 +1,89 @@ +import pytest +import torch +from torch.testing import assert_close + +import triton +import triton.language as tl + +dtype_mapping = { + 'float16': torch.float16, + 'float32': torch.float32, +} + + +@triton.jit +def add_kernel( + x_ptr, + y_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + x_block_ptr = tl.make_block_ptr( + base=x_ptr, shape=(n_elements, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ), + block_shape=(BLOCK_SIZE, ), order=(0, ) + ) + x = tl.load(x_block_ptr, boundary_check=(0, ), padding_option='zero') + + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.store(output_ptr + offsets, output, mask=mask) + + +@pytest.mark.parametrize('SIZE,BLOCK_SIZE,dtype_str', + [(98432, 1024, dtype_str) + for dtype_str in ['float16', 'float32'] + ]) +def test_add(SIZE, BLOCK_SIZE, dtype_str): + dtype = dtype_mapping[dtype_str] + output = torch.empty(SIZE, device='cuda', dtype=dtype) + x = torch.randn(SIZE, device='cuda', dtype=dtype) + y = torch.randn(SIZE, device='cuda', dtype=dtype) + + def grid(meta): + return (triton.cdiv(SIZE, meta['BLOCK_SIZE']),) + add_kernel[grid](x, y, output, SIZE, BLOCK_SIZE=BLOCK_SIZE) + + output_torch = x + y + torch.set_printoptions(profile='full') + assert_close(output, output_torch, rtol=1e-2, atol=1e-3, check_dtype=False) + + +@triton.jit +def load_reduce_kernel( + x_ptr, + y_ptr, + stride_xm, + stride_xn, + stride_y, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + x_ptr = tl.make_block_ptr( + base=x_ptr, shape=(BLOCK_M, BLOCK_N), strides=(stride_xm, stride_xn), + offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0) + ) + x = tl.load(x_ptr) + y = tl.max(x, axis=1) + tl.store(y_ptr + tl.arange(0, BLOCK_M), y) + + +@pytest.mark.parametrize('BLOCK_M,BLOCK_N,dtype_str', + [(128, 64, dtype_str) + for dtype_str in ['float16'] + ]) +def test_load_reduce(BLOCK_M, BLOCK_N, dtype_str): + dtype = dtype_mapping[dtype_str] + x = torch.randn((BLOCK_M, BLOCK_N), device='cuda', dtype=dtype) + y = torch.empty((BLOCK_M, ), device='cuda', dtype=dtype) + + load_reduce_kernel[(1,)](x, y, x.stride(0), x.stride(1), y.stride(0), BLOCK_M, BLOCK_N) + + golden = x.max(dim=1)[0] + torch.set_printoptions(profile='full') + assert_close(y, golden, rtol=1e-2, atol=1e-3, check_dtype=False) diff --git a/python/test/unit/hopper/test_persistent_warp_specialized_fused-attention.py b/python/test/unit/hopper/test_persistent_warp_specialized_fused-attention.py new file mode 100644 index 000000000000..868c052d69a2 --- /dev/null +++ b/python/test/unit/hopper/test_persistent_warp_specialized_fused-attention.py @@ -0,0 +1,392 @@ +# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files +# (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, +# publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +""" +Fused Attention +=============== +persistent warp specialized version of python/tutorials/06-fused-attention.py. +As of now, it only supports non-persistent warp specialized version of _fwd kernel. +""" + +import pytest +import torch + +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=4, num_stages=2, enable_warp_specialization=True), + ], + key=['Q', 'K', 'V'], +) +@triton.jit +def _fwd_kernel( + Q, K, V, sm_scale, + L, M, + Out, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vk, stride_vn, + stride_oz, stride_oh, stride_om, stride_on, + Z, H, N_CTX, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk + off_k = off_hz * stride_qh + offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kk + off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk + # Initialize pointers to Q, K, V + q_ptrs = Q + off_q + k_ptrs = K + off_k + v_ptrs = V + off_v + # initialize pointer to m and l + m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_prev = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # load q: it will stay in SRAM throughout + q = tl.load(q_ptrs) + # loop over k, v and update accumulator + for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N): + # -- compute qk ---- + k = tl.load(k_ptrs) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + # compute new m + m_curr = tl.maximum(tl.max(qk, 1), m_prev) + # correct old l + l_prev *= tl.exp(m_prev - m_curr) + # attention weights + p = tl.exp(qk - m_curr[:, None]) + l_curr = tl.sum(p, 1) + l_prev + # rescale operands of matmuls + l_rcp = 1. / l_curr + p *= l_rcp[:, None] + acc *= (l_prev * l_rcp)[:, None] + # update acc + p = p.to(Q.dtype.element_ty) + v = tl.load(v_ptrs) + acc += tl.dot(p, v) + # update m_i and l_i + l_prev = l_curr + m_prev = m_curr + # update pointers + k_ptrs += BLOCK_N * stride_kn + v_ptrs += BLOCK_N * stride_vk + # rematerialize offsets to save registers + start_m = tl.program_id(0) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # write back l and m + l_ptrs = L + off_hz * N_CTX + offs_m + m_ptrs = M + off_hz * N_CTX + offs_m + tl.store(l_ptrs, l_prev) + tl.store(m_ptrs, m_prev) + # initialize pointers to output + offs_n = tl.arange(0, BLOCK_DMODEL) + off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on + out_ptrs = Out + off_o + tl.store(out_ptrs, acc) + + +@triton.jit +def _bwd_preprocess( + Out, DO, L, + NewDO, Delta, + BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr, +): + off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + off_n = tl.arange(0, D_HEAD) + # load + o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + denom = tl.load(L + off_m).to(tl.float32) + # compute + do = do / denom[:, None] + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do) + tl.store(Delta + off_m, delta) + + +@triton.jit +def _bwd_kernel( + Q, K, V, sm_scale, Out, DO, + DQ, DK, DV, + L, M, + D, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vk, stride_vn, + Z, H, N_CTX, + num_block, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + off_hz = tl.program_id(0) + off_z = off_hz // H + off_h = off_hz % H + # offset pointers for batch/head + Q += off_z * stride_qz + off_h * stride_qh + K += off_z * stride_qz + off_h * stride_qh + V += off_z * stride_qz + off_h * stride_qh + DO += off_z * stride_qz + off_h * stride_qh + DQ += off_z * stride_qz + off_h * stride_qh + DK += off_z * stride_qz + off_h * stride_qh + DV += off_z * stride_qz + off_h * stride_qh + for start_n in range(0, num_block): + lo = start_n * BLOCK_M + # initialize row/col offsets + offs_qm = lo + tl.arange(0, BLOCK_M) + offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M) + offs_m = tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_DMODEL) + # initialize pointers to value-like data + q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) + v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) + do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + # pointer to row-wise quantities in value-like data + D_ptrs = D + off_hz * N_CTX + m_ptrs = M + off_hz * N_CTX + # initialize dv amd dk + dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # k and v stay in SRAM throughout + k = tl.load(k_ptrs) + v = tl.load(v_ptrs) + # loop over rows + for start_m in range(lo, num_block * BLOCK_M, BLOCK_M): + offs_m_curr = start_m + offs_m + # load q, k, v, do on-chip + q = tl.load(q_ptrs) + # recompute p = softmax(qk, dim=-1).T + # NOTE: `do` is pre-divided by `l`; no normalization here + qk = tl.dot(q, tl.trans(k)) + qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) + m = tl.load(m_ptrs + offs_m_curr) + p = tl.exp(qk * sm_scale - m[:, None]) + # compute dv + do = tl.load(do_ptrs) + dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do) + # compute dp = dot(v, do) + Di = tl.load(D_ptrs + offs_m_curr) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] + dp += tl.dot(do, tl.trans(v)) + # compute ds = p * (dp - delta[:, None]) + ds = p * dp * sm_scale + # compute dk = dot(ds.T, q) + dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q) + # compute dq + dq = tl.load(dq_ptrs) + dq += tl.dot(ds.to(Q.dtype.element_ty), k) + tl.store(dq_ptrs, dq) + # increment pointers + dq_ptrs += BLOCK_M * stride_qm + q_ptrs += BLOCK_M * stride_qm + do_ptrs += BLOCK_M * stride_qm + # write-back + dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) + dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) + tl.store(dv_ptrs, dv) + tl.store(dk_ptrs, dk) + + +empty = torch.empty(128, device="cuda") + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, sm_scale): + BLOCK = 128 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + o = torch.empty_like(q) + grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1) + L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + num_warps = 4 if Lk <= 64 else 8 + # Only support num_warps = 4 now + assert num_warps == 4 + + _fwd_kernel[grid]( + q, k, v, sm_scale, + L, m, + o, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + o.stride(0), o.stride(1), o.stride(2), o.stride(3), + q.shape[0], q.shape[1], q.shape[2], + BLOCK_M=BLOCK, BLOCK_N=BLOCK, + BLOCK_DMODEL=Lk, + ) + + ctx.save_for_backward(q, k, v, o, L, m) + ctx.grid = grid + ctx.sm_scale = sm_scale + ctx.BLOCK_DMODEL = Lk + return o + + @staticmethod + def backward(ctx, do): + BLOCK = 128 + q, k, v, o, l, m = ctx.saved_tensors + do = do.contiguous() + dq = torch.zeros_like(q, dtype=torch.float32) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + do_scaled = torch.empty_like(do) + delta = torch.empty_like(l) + _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )]( + o, do, l, + do_scaled, delta, + BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL, + ) + _bwd_kernel[(ctx.grid[1],)]( + q, k, v, ctx.sm_scale, + o, do_scaled, + dq, dk, dv, + l, m, + delta, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + q.shape[0], q.shape[1], q.shape[2], + ctx.grid[0], + BLOCK_M=BLOCK, BLOCK_N=BLOCK, + BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8, + num_stages=1, + ) + return dq, dk, dv, None + + +attention = _attention.apply + + +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 64)]) +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") +def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16): + torch.manual_seed(20) + q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_() + k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_() + v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_() + sm_scale = 0.2 + dout = torch.randn_like(q) + # reference implementation + M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + for z in range(Z): + for h in range(H): + p[:, :, M == 0] = float("-inf") + p = torch.softmax(p.float(), dim=-1).half() + # p = torch.exp(p) + ref_out = torch.matmul(p, v) + ref_out.backward(dout) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + # # triton implementation + tri_out = attention(q, k, v, sm_scale) + # print(ref_out) + # print(tri_out) + tri_out.backward(dout) + tri_dv, v.grad = v.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dq, q.grad = q.grad.clone(), None + # compare + assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0) + assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=0) + assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=0) + assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=0) + + +try: + from flash_attn.flash_attn_interface import flash_attn_func + HAS_FLASH = True +except BaseException: + HAS_FLASH = False + +BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 +# vary seq length for fixed head and batch=4 +configs = [triton.testing.Benchmark( + x_names=['N_CTX'], + # x_vals=[2**i for i in range(10, 14)], + x_vals=[2**i for i in range(10, 11)], + line_arg='provider', + line_vals=['triton'] + (['flash'] if HAS_FLASH else []), + line_names=['Triton'] + (['Flash'] if HAS_FLASH else []), + styles=[('red', '-'), ('blue', '-')], + ylabel='ms', + plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}', + args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode} + # ) for mode in ['fwd', 'bwd']] +) for mode in ['fwd']] + + +@triton.testing.perf_report(configs) +def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"): + assert mode in ['fwd', 'bwd'] + # warmup = 25 + # rep = 100 + warmup = 0 + rep = 1 + if provider == "triton": + q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + sm_scale = 1.3 + fn = lambda: attention(q, k, v, sm_scale) + if mode == 'bwd': + o = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + return ms + if provider == "flash": + lengths = torch.full((BATCH,), fill_value=N_CTX, device=device) + cu_seqlens = torch.zeros( + (BATCH + 1,), device=device, dtype=torch.int32) + cu_seqlens[1:] = lengths.cumsum(0) + qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True) + fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True) + if mode == 'bwd': + o = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + return ms + + +# # only works on post-Ampere GPUs right now +# bench_flash_attention.run(save_path='.', print_data=True) diff --git a/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py b/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py new file mode 100644 index 000000000000..fd7c14e6c85a --- /dev/null +++ b/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py @@ -0,0 +1,930 @@ +# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files +# (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, +# publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +import itertools + +import pytest +import torch +from torch.testing import assert_close + +import triton +import triton.language as tl + + +@triton.jit +def static_persistent_matmul_kernel( + a_ptr, b_ptr, c_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + NUM_SM: tl.constexpr, +): + start_tile = tl.program_id(axis=0) + m_tiles = tl.cdiv(M, BLOCK_M) + n_tiles = tl.cdiv(N, BLOCK_N) + num_tiles = m_tiles * n_tiles + offs_k = tl.arange(0, BLOCK_K) + + for tile_id in range(start_tile, num_tiles, NUM_SM): + pid_m = tile_id // n_tiles + pid_n = tile_id % n_tiles + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M + offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + for k in range(0, K, BLOCK_K): + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_K * stride_ak + b_ptrs += BLOCK_K * stride_bk + + offs_cm = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offs_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + + c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn + tl.store(c_ptrs, accumulator) + + +@triton.jit +def static_persistent_tma_matmul_kernel( + a_ptr, b_ptr, c_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + NUM_SM: tl.constexpr, +): + start_tile = tl.program_id(axis=0) + m_tiles = tl.cdiv(M, BLOCK_M) + n_tiles = tl.cdiv(N, BLOCK_N) + k_tiles = tl.cdiv(K, BLOCK_K) + num_tiles = m_tiles * n_tiles + + pre_pid_m = start_tile // n_tiles + pre_pid_n = start_tile % n_tiles + + block_offset_m = pre_pid_m * BLOCK_M + block_offset_n = pre_pid_n * BLOCK_N + a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) + b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1)) + for tile_id in range(start_tile, num_tiles, NUM_SM): + pid_m = tile_id // n_tiles + pid_n = tile_id % n_tiles + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + if tile_id >= NUM_SM: + a_tile_ptr = tl.advance(a_tile_ptr, [(pid_m - pre_pid_m) * BLOCK_M, -k_tiles * BLOCK_K]) + b_tile_ptr = tl.advance(b_tile_ptr, [-k_tiles * BLOCK_K, (pid_n - pre_pid_n) * BLOCK_N]) + + for k in range(0, K, BLOCK_K): + a = tl.load(a_tile_ptr) + b = tl.load(b_tile_ptr) + accumulator += tl.dot(a, b) + a_tile_ptr = tl.advance(a_tile_ptr, [0, BLOCK_K]) + b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_K, 0]) + + offs_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offs_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + + c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn + tl.store(c_ptrs, accumulator) + pre_pid_m = pid_m + pre_pid_n = pid_n + + +@pytest.mark.parametrize('M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,TRANS_A,TRANS_B,USE_TMA', + [(*shape, use_tma) + for shape in [ + [4096, 4096, 64, 64, 64, 16, 4, 1, False, True], + [4096, 4096, 64, 64, 64, 32, 4, 1, False, True], + [4096, 4096, 64, 256, 64, 16, 4, 1, False, True], + [4096, 4096, 64, 128, 128, 16, 4, 1, False, True], + # TODO: fix issue for 8-warp persistent kernel + # [4096, 4096, 64, 128, 128, 16, 8, 1, False, True], + # [4096, 4096, 64, 128, 256, 16, 8, 1, False, True], + ] + for use_tma in [False, True] + ]) +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") +def test_user_defined_persistent_non_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, TRANS_A, TRANS_B, USE_TMA): + if (TRANS_A): + a = .1 * torch.randn((K, M), device='cuda', dtype=torch.float16).T + else: + a = .1 * torch.randn((M, K), device='cuda', dtype=torch.float16) + + if (TRANS_B): + b = .1 * torch.randn((N, K), device='cuda', dtype=torch.float16).T + else: + b = .1 * torch.randn((K, N), device='cuda', dtype=torch.float16) + c = torch.empty((M, N), device=a.device, dtype=torch.float32) + + num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count + grid = lambda META: (num_SMs,) + + if USE_TMA: + static_persistent_tma_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0), stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1), stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS, num_ctas=NUM_CTAS) + else: + static_persistent_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0), stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1), stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS, num_ctas=NUM_CTAS) + + th_c = torch.matmul(a, b) + torch.testing.assert_allclose(th_c, c, atol=1e-2, rtol=0) + + +@triton.jit +def warp_specialized_matmul_kernel( + a_ptr, b_ptr, c_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, +): + tid = tl.program_id(axis=0) + n_tiles = tl.cdiv(N, BLOCK_N) + pid_m = tid // n_tiles + pid_n = tid % n_tiles + + offs_k = tl.arange(0, BLOCK_K) + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_am = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + offs_bn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + for k in range(0, K, BLOCK_K): + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_K * stride_ak + b_ptrs += BLOCK_K * stride_bk + accumulator = accumulator.to(c_ptr.dtype.element_ty) + + offs_cm = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offs_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + + c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn + mask = (offs_cm < M)[:, None] & (offs_cn < N)[None, :] + tl.store(c_ptrs, accumulator, mask=mask) + + +@triton.jit +def tma_warp_specialized_matmul_kernel( + a_ptr, b_ptr, c_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, +): + tid = tl.program_id(axis=0) + n_tiles = tl.cdiv(N, BLOCK_N) + pid_m = tid // n_tiles + pid_n = tid % n_tiles + + block_offset_m = pid_m * BLOCK_M + block_offset_n = pid_n * BLOCK_N + a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), + offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) + b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), + offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1)) + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + for k in range(0, K, BLOCK_K): + a = tl.load(a_tile_ptr) + b = tl.load(b_tile_ptr) + accumulator += tl.dot(a, b) + a_tile_ptr = tl.advance(a_tile_ptr, [0, BLOCK_K]) + b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_K, 0]) + accumulator = accumulator.to(c_ptr.dtype.element_ty) + + offs_cm = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offs_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + + c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn + mask = (offs_cm < M)[:, None] & (offs_cn < N)[None, :] + tl.store(c_ptrs, accumulator, mask=mask) + + +@pytest.mark.parametrize('M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_CTAS,TRANS_A,TRANS_B,USE_TMA', + [(*shape, use_tma) + for shape in [ + [2048, 2048, 64, 64, 64, 16, 1, False, True], + [4096, 4096, 64, 64, 64, 16, 1, False, True], + [128, 4096, 64, 64, 64, 16, 1, False, True], + [4096, 128, 64, 64, 64, 16, 1, False, True], + [4096, 4096, 64, 64, 64, 32, 1, False, True], + [4096, 4096, 256, 128, 128, 16, 1, False, True], + [4096, 4096, 320, 128, 64, 64, 1, False, True], + [4096, 4096, 320, 64, 128, 64, 1, False, True], + [4096, 4096, 320, 128, 128, 64, 1, False, True], + [4096, 4096, 256, 256, 64, 16, 1, False, True], + [4096, 4096, 256, 256, 64, 64, 1, False, True], + [4096, 4096, 256, 64, 256, 16, 1, False, True], + [4096, 4096, 256, 64, 256, 64, 1, False, True], + [4096, 4096, 256, 256, 128, 16, 1, False, True], + [4096, 4096, 256, 256, 128, 64, 1, False, True], + [4096, 4096, 256, 128, 256, 16, 1, False, True], + [4096, 4096, 256, 128, 256, 64, 1, False, True], + # numCTAs > 1 + [2048, 2048, 64, 128, 128, 64, 2, False, True], + [2048, 2048, 128, 256, 128, 64, 4, False, True], + [4096, 4096, 128, 256, 128, 64, 4, False, True], + [4096, 4096, 256, 128, 256, 64, 4, False, True], + [4096, 4096, 256, 256, 256, 64, 4, False, True], + ] + for use_tma in [False, True] + ]) +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") +def test_non_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_CTAS, TRANS_A, TRANS_B, USE_TMA): + if (TRANS_A): + a = .1 * torch.randn((K, M), device='cuda', dtype=torch.float16).T + else: + a = .1 * torch.randn((M, K), device='cuda', dtype=torch.float16) + + if (TRANS_B): + b = .1 * torch.randn((N, K), device='cuda', dtype=torch.float16).T + else: + b = .1 * torch.randn((K, N), device='cuda', dtype=torch.float16) + + c = torch.empty((M, N), device=a.device, dtype=torch.float32) + + grid = lambda META: (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),) + + if USE_TMA: + tma_warp_specialized_matmul_kernel[grid]( + a, b, c, + M, N, K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + BLOCK_M, BLOCK_N, BLOCK_K, + num_warps=4, + num_ctas=NUM_CTAS, + enable_warp_specialization=True) + else: + warp_specialized_matmul_kernel[grid]( + a, b, c, + M, N, K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + BLOCK_M, BLOCK_N, BLOCK_K, + num_warps=4, + num_ctas=NUM_CTAS, + enable_warp_specialization=True) + + th_c = torch.matmul(a, b) + torch.testing.assert_allclose(th_c, c, atol=1e-2, rtol=0) + + +@triton.jit +def static_persistent_warp_specialized_matmul_kernel( + a_ptr, b_ptr, c_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + NUM_SM: tl.constexpr, +): + start_tile = tl.program_id(axis=0) + m_tiles = tl.cdiv(M, BLOCK_M) + n_tiles = tl.cdiv(N, BLOCK_N) + num_tiles = m_tiles * n_tiles + offs_k = tl.arange(0, BLOCK_K) + + for tile_id in range(start_tile, num_tiles, NUM_SM): + pid_m = tile_id // n_tiles + pid_n = tile_id % n_tiles + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M + offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + for k in range(0, K, BLOCK_K): + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_K * stride_ak + b_ptrs += BLOCK_K * stride_bk + + offs_cm = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offs_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + + c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn + tl.store(c_ptrs, accumulator) + + +@triton.jit +def static_persistent_tma_warp_specialized_matmul_kernel( + a_ptr, b_ptr, c_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + NUM_SM: tl.constexpr, +): + start_tile = tl.program_id(axis=0) + m_tiles = tl.cdiv(M, BLOCK_M) + n_tiles = tl.cdiv(N, BLOCK_N) + k_tiles = tl.cdiv(K, BLOCK_K) + num_tiles = m_tiles * n_tiles + + pre_pid_m = start_tile // n_tiles + pre_pid_n = start_tile % n_tiles + + block_offset_m = pre_pid_m * BLOCK_M + block_offset_n = pre_pid_n * BLOCK_N + a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) + b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1)) + for tile_id in range(start_tile, num_tiles, NUM_SM): + pid_m = tile_id // n_tiles + pid_n = tile_id % n_tiles + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + if tile_id >= NUM_SM: + a_tile_ptr = tl.advance(a_tile_ptr, [(pid_m - pre_pid_m) * BLOCK_M, -k_tiles * BLOCK_K]) + b_tile_ptr = tl.advance(b_tile_ptr, [-k_tiles * BLOCK_K, (pid_n - pre_pid_n) * BLOCK_N]) + + for k in range(0, K, BLOCK_K): + a = tl.load(a_tile_ptr) + b = tl.load(b_tile_ptr) + accumulator += tl.dot(a, b) + a_tile_ptr = tl.advance(a_tile_ptr, [0, BLOCK_K]) + b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_K, 0]) + + offs_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offs_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + + c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn + tl.store(c_ptrs, accumulator) + pre_pid_m = pid_m + pre_pid_n = pid_n + + +@pytest.mark.parametrize('M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_CTAS,TRANS_A,TRANS_B,USE_TMA', + [(*shape, use_tma) + for shape in [ + [2048, 2048, 64, 64, 64, 16, 1, False, True], + [4096, 4096, 64, 64, 64, 16, 1, False, True], + [128, 4096, 64, 64, 64, 16, 1, False, True], + [4096, 128, 64, 64, 64, 16, 1, False, True], + [4096, 4096, 64, 64, 64, 32, 1, False, True], + [4096, 4096, 256, 128, 128, 16, 1, False, True], + [4096, 4096, 320, 128, 64, 64, 1, False, True], + [4096, 4096, 320, 64, 128, 64, 1, False, True], + [4096, 4096, 320, 128, 128, 64, 1, False, True], + [4096, 4096, 256, 256, 64, 16, 1, False, True], + [4096, 4096, 256, 256, 64, 64, 1, False, True], + [4096, 4096, 256, 64, 256, 16, 1, False, True], + [4096, 4096, 256, 64, 256, 64, 1, False, True], + [4096, 4096, 256, 256, 128, 16, 1, False, True], + [4096, 4096, 256, 256, 128, 64, 1, False, True], + [4096, 4096, 256, 128, 256, 16, 1, False, True], + [4096, 4096, 256, 128, 256, 64, 1, False, True], + # numCTAs > 1 + [2048, 2048, 64, 128, 128, 64, 2, False, True], + [2048, 2048, 128, 256, 128, 64, 4, False, True], + [4096, 4096, 128, 256, 128, 64, 4, False, True], + [4096, 4096, 256, 128, 256, 64, 4, False, True], + [4096, 4096, 256, 256, 256, 64, 4, False, True], + ] + for use_tma in [False, True] + ]) +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") +def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_CTAS, TRANS_A, TRANS_B, USE_TMA): + if (TRANS_A): + a = .1 * torch.randn((K, M), device='cuda', dtype=torch.float16).T + else: + a = .1 * torch.randn((M, K), device='cuda', dtype=torch.float16) + + if (TRANS_B): + b = .1 * torch.randn((N, K), device='cuda', dtype=torch.float16).T + else: + b = .1 * torch.randn((K, N), device='cuda', dtype=torch.float16) + c = torch.empty((M, N), device=a.device, dtype=torch.float32) + + num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count + grid = lambda META: (num_SMs,) + + if USE_TMA: + static_persistent_tma_warp_specialized_matmul_kernel[grid]( + a, b, c, + M, N, K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + BLOCK_M, BLOCK_N, BLOCK_K, num_SMs, + num_warps=4, num_ctas=NUM_CTAS, + enable_warp_specialization=True) + else: + static_persistent_warp_specialized_matmul_kernel[grid]( + a, b, c, + M, N, K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + BLOCK_M, BLOCK_N, BLOCK_K, num_SMs, + num_warps=4, num_ctas=NUM_CTAS, + enable_warp_specialization=True) + + th_c = torch.matmul(a, b) + torch.testing.assert_allclose(th_c, c, atol=1e-2, rtol=0) + + +@triton.jit +def static_persistent_matmul_no_scf_kernel( + a_ptr, b_ptr, c_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + FLOAT16_OUTPUT: tl.constexpr, USE_TMA_EPILOGUE: tl.constexpr, + NUM_SM: tl.constexpr, USE_TMA_LOAD: tl.constexpr, +): + start_tile = tl.program_id(axis=0) + m_tiles = tl.cdiv(M, BLOCK_M) + n_tiles = tl.cdiv(N, BLOCK_N) + num_tiles = m_tiles * n_tiles + offs_k = tl.arange(0, BLOCK_K) + pre_pid_m = start_tile // n_tiles + pre_pid_n = start_tile % n_tiles + block_offset_m = pre_pid_m * BLOCK_M + block_offset_n = pre_pid_n * BLOCK_N + + if USE_TMA_LOAD: + a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), + offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) + b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), + offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1)) + if USE_TMA_EPILOGUE: + c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), + offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0)) + + for tile_id in range(start_tile, num_tiles, NUM_SM): + pid_m = tile_id // n_tiles + pid_n = tile_id % n_tiles + + if USE_TMA_LOAD: + a_block_ptr = tl.advance(a_block_ptr, [(pid_m - pre_pid_m) * BLOCK_M, 0]) + b_block_ptr = tl.advance(b_block_ptr, [0, (pid_n - pre_pid_n) * BLOCK_N]) + a = tl.load(a_block_ptr) + b = tl.load(b_block_ptr) + else: + offs_am = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_bn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + + c = tl.dot(a, b) + + if FLOAT16_OUTPUT: + c = c.to(tl.float16) + + if USE_TMA_EPILOGUE: + c_block_ptr = tl.advance(c_block_ptr, [(pid_m - pre_pid_m) * BLOCK_M, (pid_n - pre_pid_n) * BLOCK_N]) + tl.store(c_block_ptr, c) + else: + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn + tl.store(c_ptrs, c) + + pre_pid_m = pid_m + pre_pid_n = pid_n + + +@pytest.mark.parametrize('M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE,USE_TMA_LOAD', + itertools.chain( + *[ + [ + # numCTAs = 1, no TMA multicast: + [64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD], + [64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD], + [64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD], + [64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], + [64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], + [64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], + [128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD], + [128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], + # small M, N + [16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], + [16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], + [32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], + [32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], + ] for USE_TMA_EPILOGUE in [True, False] + for USE_TMA_LOAD in [True, False] + ])) +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") +def test_static_persistent_matmul_no_scf_kernel(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE, USE_TMA_LOAD): + if (TRANS_A): + a = torch.randn((K, M), device='cuda', dtype=torch.float16).T + else: + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + if (TRANS_B): + b = torch.randn((N, K), device='cuda', dtype=torch.float16).T + else: + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + if OUTPUT_TYPE == "float16": + c = torch.empty((M, N), device=a.device, dtype=torch.float16) + else: + c = torch.empty((M, N), device=a.device, dtype=torch.float32) + + num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count + + # TODO: set `enable_warp_specialization=False` will lead to compilation error. + static_persistent_matmul_no_scf_kernel[(num_SMs,)](a_ptr=a, b_ptr=b, c_ptr=c, + M=M, N=N, K=K, + stride_am=a.stride(0), stride_ak=a.stride(1), + stride_bk=b.stride(0), stride_bn=b.stride(1), + stride_cm=c.stride(0), stride_cn=c.stride(1), + BLOCK_M=M if M < 128 else M // 2, BLOCK_N=N if N < 128 else N // 2, BLOCK_K=K, NUM_SM=num_SMs, + num_warps=NUM_WARPS, + num_ctas=NUM_CTAS, + FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"), + USE_TMA_EPILOGUE=USE_TMA_EPILOGUE, + USE_TMA_LOAD=USE_TMA_LOAD, + enable_warp_specialization=True) + a_f32 = a.to(torch.float32) + b_f32 = b.to(torch.float32) + golden = torch.matmul(a_f32, b_f32) + torch.set_printoptions(profile="full") + assert_close( + c, + golden, + rtol=1e-2, + atol=1e-3, + check_dtype=False) + + +@triton.jit +def full_static_persistent_matmul_kernel( + a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_wm, stride_wn, + stride_zm, stride_zn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, + out_dtype: tl.constexpr, USE_TMA_STORE: tl.constexpr, + ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, + DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr, + A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr, + B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr, + NUM_SM: tl.constexpr +): + start_pid = tl.program_id(axis=0) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_tiles = num_pid_m * num_pid_n + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = start_pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pre_pid_m = first_pid_m + ((start_pid % num_pid_in_group) % group_size_m) + pre_pid_n = (start_pid % num_pid_in_group) // group_size_m + + pre_block_offset_m = pre_pid_m * BLOCK_M + pre_block_offset_n = pre_pid_n * BLOCK_N + a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), + offsets=(pre_block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(A_ORDER_0, A_ORDER_1)) + b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), + offsets=(0, pre_block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(B_ORDER_0, B_ORDER_1)) + w_tile_ptr = tl.make_block_ptr(base=w_ptr, shape=(N, N), strides=(stride_wm, stride_wn), + offsets=(0, pre_block_offset_n), block_shape=(BLOCK_N, BLOCK_N), order=(0, 1)) + + if USE_TMA_STORE: + z_block_ptr = tl.make_block_ptr(base=z_ptr, shape=(M, N), strides=(stride_zm, stride_zn), + offsets=(pre_block_offset_m, pre_block_offset_n), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0)) + + for tile_id in range(start_pid, num_tiles, NUM_SM): + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + block_offset_m = pid_m * BLOCK_M + block_offset_n = pid_n * BLOCK_N + + offs_m = block_offset_m + tl.arange(0, BLOCK_M) + offs_n = block_offset_n + tl.arange(0, BLOCK_N) + z_ptrs = z_ptr + offs_m[:, None] * stride_zm + offs_n[None, :] * stride_zn + bias_ptrs = bias_ptr + offs_m[:, None] * stride_zm + offs_n[None, :] * stride_zn + mask = (offs_m < M)[:, None] & (offs_n < N)[None, :] + + # TODO: lib/Dialect/TritonGPU/Transforms/RewriteTensorPointer.cpp does not support scf.if yet. + # if tile_id >= NUM_SM: + # a_tile_ptr = tl.advance(a_tile_ptr, [(pid_m - pre_pid_m) * BLOCK_M, -tl.cdiv(K, BLOCK_K) * BLOCK_K]) + # b_tile_ptr = tl.advance(b_tile_ptr, [-tl.cdiv(K, BLOCK_K) * BLOCK_K, (pid_n - pre_pid_n) * BLOCK_N]) + + a_tile_ptr = tl.advance(a_tile_ptr, [(pid_m - pre_pid_m) * BLOCK_M, 0]) + b_tile_ptr = tl.advance(b_tile_ptr, [0, (pid_n - pre_pid_n) * BLOCK_N]) + z = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(0, K, BLOCK_K): + a = tl.load(a_tile_ptr, boundary_check=(0, 1)) + b = tl.load(b_tile_ptr, boundary_check=(0, 1)) + z += tl.dot(a, b) + a_tile_ptr = tl.advance(a_tile_ptr, [0, BLOCK_K]) + b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_K, 0]) + a_tile_ptr = tl.advance(a_tile_ptr, [0, -tl.cdiv(K, BLOCK_K) * BLOCK_K]) + b_tile_ptr = tl.advance(b_tile_ptr, [-tl.cdiv(K, BLOCK_K) * BLOCK_K, 0]) + + if (out_dtype == tl.constexpr(tl.float16)): + z = z.to(tl.float16) + + if ADD_MATRIX: + z += tl.load(bias_ptrs, mask=mask) + if ADD_ROWS: + ZRs = bias_ptr + offs_m * stride_zm + z += tl.load(ZRs)[:, None] + if ADD_COLS: + ZCs = bias_ptr + offs_n * stride_zn + z += tl.load(ZCs)[None, :] + if DO_SOFTMAX: + max = tl.max(z, 1) + z = z - max[:, None] + num = tl.exp(z.to(tl.float32)).to(max.dtype) + den = tl.sum(num, 1) + z = num / den[:, None] + if CHAIN_DOT: + w = tl.load(w_tile_ptr) + w_tile_ptr = tl.advance(w_tile_ptr, [0, (pid_n - pre_pid_n) * BLOCK_N]) + z = tl.dot(z.to(w.dtype), w) + if (out_dtype == tl.constexpr(tl.float16)): + z = z.to(tl.float16) + + if USE_TMA_STORE: + z_block_ptr = tl.advance(z_block_ptr, [(pid_m - pre_pid_m) * BLOCK_M, (pid_n - pre_pid_n) * BLOCK_N]) + tl.store(z_block_ptr, z, boundary_check=(0, 1)) + else: + tl.store(z_ptrs, z, mask=mask) + + pre_pid_m = pid_m + pre_pid_n = pid_n + + +@pytest.mark.parametrize('BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES,ENABLE_WS', + [ + # corner shapes + (128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3, enable_ws) + for shape_w_c in [ + [4096, 1, 1024, False, False], + [2048, 204, 1000, True, False], + [16, 524288, 32, False, True], + ] + for out_dtype in ['float16', 'float32'] + for use_tma_store in [False, True] + for enable_ws in [True] + ] + [ + # softmax epilogue + (*shape_w_c, trans_a, trans_b, epilogue, out_dtype, use_tma_store, num_stages, enable_ws) + # softmax works for one CTA + for shape_w_c in [ + [64, 64, 16, 4, 1, 64, 64, 64], + [128, 128, 64, 4, 1, None, None, None], + [16, 16, 64, 4, 1, 16, 16, 64], + # TODO: enable when num_warps != 4 is supported. + # [64, 64, 32, 8, 1, 64, 64, 64], + [128, 128, 64, 4, 1, 128, 128, 128], + ] + for epilogue in ['softmax'] + for out_dtype in ['float16', 'float32'] + for use_tma_store in [False, True] + for trans_a in [False,] + for trans_b in [True,] + for num_stages in [3] + for enable_ws in [True] + ] + [ + # loop over tile shapes and transpose combinations + (*shape_w_c, trans_a, trans_b, 'none', out_dtype, use_tma_store, num_stages, enable_ws) + for shape_w_c in [ + [64, 64, 32, 4, 1, 128, 256, 64], + [128, 128, 16, 4, 4, 512, 256, 64], + [128, 256, 32, 4, 8, 256, 256, 192], + [512, 256, 32, 4, 8, 1024, 256, 192], + # BLOCK_K >= 128 + [64, 128, 128, 4, 1, 512, 256, 256], + [128, 128, 128, 4, 1, 256, 256, 192], + [128, 128, 128, 4, 2, 256, 256, 192], + # small BLOCK_M and BLOCK_K + [16, 32, 32, 4, 1, 128, 256, 64], + [32, 32, 16, 4, 1, 256, 256, 192], + [16, 32, 64, 4, 4, 512, 256, 64], + ] + for out_dtype in ['float32',] + for use_tma_store in [False,] + for trans_a in [False, True] + for trans_b in [False, True] + for num_stages in [3] + for enable_ws in [True] + ] + [ + # loop over epilogues besides of softmax + (*shape_w_c, trans_a, trans_b, epilogue, out_dtype, use_tma_store, num_stages, enable_ws) + for shape_w_c in [ + [64, 64, 16, 4, 1, 128, 128, 64], + *[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4] for num_ctas in [1, 2, 4]], + # for chain-dot + [128, 128, 64, 4, 1, None, None, None], + [64, 64, 16, 4, 1, None, None, None], + # small BLOCK_M and BLOCK_K + [16, 16, 64, 4, 1, 128, 128, 64], + *[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4] for num_ctas in [1, 2]], + # # TODO: enable when num_warps != 4 is supported. + # # repeat + # # [64, 64, 32, 8, 1, 128, 256, 64], + # # [64, 64, 16, 8, 2, 128, 128, 64], + # irregular shape + [128, 128, 64, 4, 1, 500, 200, 128], + [128, 128, 64, 4, 1, 513, 193, 192], + ] + for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'] + for out_dtype in ['float16', 'float32'] + for use_tma_store in [False, True] + for trans_a in [False,] + for trans_b in [True,] + for num_stages in [3] + for enable_ws in [True] + if not (epilogue == 'chain-dot' and (shape_w_c[5] is not None or shape_w_c[0] != shape_w_c[1])) + ] + [ + # loop over instr shapes & pipeline stages + (64, n, 16, 4, 1, 512, 256, 256, False, True, 'none', out_dtype, use_tma_store, num_stages, enable_ws) + for n in [16, 32, 64, 128, 256] + for out_dtype in ['float32'] + for use_tma_store in [False,] + for num_stages in [2, 4, 5, 7] + for enable_ws in [True] + ] + [ + # irregular shapes + (*shape_w_c, *shape, False, True, 'none', out_dtype, use_tma_store, num_stages, enable_ws) + for shape_w_c in [ + [128, 128, 64, 4, 1], + [256, 128, 64, 4, 2], + [128, 128, 128, 4, 2] + ] + for shape in [ + [512, 360, 1024], + [360, 4096, 512], + ] + for out_dtype in ['float32'] + for use_tma_store in [False, True] + for num_stages in [3, 4] + for enable_ws in [True] + ] + ) +@pytest.mark.skipif(torch.cuda.get_device_capability() + [0] < 9, reason="Requires compute capability >= 9") +def test_full_static_persistent_matmul_kernel(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS): + if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS])) in [ + '128-128-128-4-1-256-256-192-none-float32-True-3-True', + ]: + pytest.skip('out of resource: shared memory, Required: 263168') + + if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ + '16-32-64-4-4-512-256-64-True-False', + '16-32-64-4-4-512-256-64-True-True', + '16-32-64-4-4-512-256-64-False-False', + '16-32-64-4-4-512-256-64-False-True', + ]: + pytest.skip('shapePerCTA[1] < 16 not supported') + + # with ENABLE_TMA=0 and ENABLE_MMA_V3=0 + if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ + '16-32-64-4-1-256-256-256-False', + '16-32-64-4-2-256-256-256-False', + '16-32-64-4-2-256-256-256-True', + '16-32-64-8-2-256-256-256-False', + '16-32-64-8-2-256-256-256-True', + ]: + pytest.skip('Known legacy issue, ldmatrix can only support x4') + + if epilogue == 'chain-dot': + pytest.skip('known failure: Assertion !region.empty() && unexpected empty region.') + + M = BLOCK_M if M is None else M + N = BLOCK_N if N is None else N + K = BLOCK_K if K is None else K + + if (TRANS_A): + a = torch.randn((K, M), device='cuda', dtype=torch.float16).T + a_order = [0, 1] + else: + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + a_order = [1, 0] + + if (TRANS_B): + b = torch.randn((N, K), device='cuda', dtype=torch.float16).T + b_order = [0, 1] + else: + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + b_order = [1, 0] + + if out_dtype == 'float16' and epilogue != 'softmax': + # TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will + # fail with the following error: 'llvm.fmul' op requires the same type + # for all operands and results + out_dtype = tl.float16 + torch_out_dtype = torch.float16 + else: + out_dtype = tl.float32 + torch_out_dtype = torch.float32 + + # avoid out of memory + if epilogue in ['add-matrix', 'add-rows', 'add-cols']: + bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) + else: + bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) + + if epilogue == 'chain-dot': + w = torch.randn((N, N), device='cuda', dtype=torch.float16).T + else: + w = torch.randn((1, 1), device='cuda', dtype=torch.float16).T + + z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) + + # torch result + a_f32 = a.to(torch.float32) + b_f32 = b.to(torch.float32) + dot = torch.matmul(a_f32, b_f32) + + def process_epilogue(d, bias, w, epilogue): + if epilogue == 'add-matrix': + ref = d + bias + elif epilogue == 'add-rows': + ref = d + bias[:, 0][:, None] + elif epilogue == 'add-cols': + ref = d + bias[0, :][None, :] + elif epilogue == 'softmax': + num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) + denom = torch.sum(num, dim=-1, keepdims=True) + ref = num / denom + # ref = torch.softmax(d, 1) + elif epilogue == 'chain-dot': + ref = torch.matmul(d, w.to(torch.float32)) + else: + ref = d + return ref + golden = process_epilogue(dot, bias, w, epilogue) + + num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count + + def grid(META): + return (num_SMs,) + full_static_persistent_matmul_kernel[grid]( + a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, + M=M, N=N, K=K, + stride_am=a.stride(0), stride_ak=a.stride(1), + stride_bk=b.stride(0), stride_bn=b.stride(1), + stride_wm=w.stride(0), stride_wn=w.stride(1), + stride_zm=z.stride(0), stride_zn=z.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, + out_dtype=out_dtype, + USE_TMA_STORE=USE_TMA_STORE, + ADD_MATRIX=epilogue == 'add-matrix', + ADD_ROWS=epilogue == 'add-rows', + ADD_COLS=epilogue == 'add-cols', + DO_SOFTMAX=epilogue == 'softmax', + CHAIN_DOT=epilogue == 'chain-dot', + A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], + B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], + num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES, + enable_warp_specialization=ENABLE_WS, + NUM_SM=num_SMs) + + torch.set_printoptions(profile="full") + golden = torch.nn.functional.normalize(golden) + z = torch.nn.functional.normalize(z) + assert_close(z, golden, + rtol=1e-2, + atol=1e-3, + check_dtype=False) diff --git a/python/test/unit/hopper/test_tma_store_gemm.py b/python/test/unit/hopper/test_tma_store_gemm.py new file mode 100644 index 000000000000..6d912d89caed --- /dev/null +++ b/python/test/unit/hopper/test_tma_store_gemm.py @@ -0,0 +1,92 @@ +# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files +# (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, +# publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + +import pytest +import torch +from torch.testing import assert_close + +import triton +import triton.language as tl + + +@triton.jit +def matmul_tma_load_store( + a_ptr, b_ptr, c_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + OUTPUT_F16: tl.constexpr +): + a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), + offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) + b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), + offsets=(0, 0), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1)) + c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), + offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0)) + a = tl.load(a_block_ptr) + b = tl.load(b_block_ptr) + + c = tl.dot(a, b) + if OUTPUT_F16: + c = c.to(tl.float16) + + tl.store(c_block_ptr, c) + + +@pytest.mark.parametrize('M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_F16', [ + [64, 64, 16, 1, 4, False, True, False], + [64, 64, 16, 1, 4, False, True, True], + [128, 64, 32, 1, 4, False, True, False], + [128, 64, 32, 1, 4, False, True, True], + [64, 128, 32, 1, 4, False, True, False], + [64, 128, 32, 1, 4, False, True, True], + [128, 128, 64, 1, 4, False, True, False], + [128, 128, 64, 1, 4, False, True, True], +]) +def test_tma_load_store(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_F16): + if (TRANS_A): + a = torch.randn((K, M), device='cuda', dtype=torch.float16).T + else: + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + if (TRANS_B): + b = torch.randn((N, K), device='cuda', dtype=torch.float16).T + else: + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + c = torch.empty((M, N), device=a.device, dtype=torch.float32) + if OUTPUT_F16: + c = torch.empty((M, N), device=a.device, dtype=torch.float16) + + matmul_tma_load_store[(1, 1)](a_ptr=a, b_ptr=b, c_ptr=c, + M=M, N=N, K=K, + stride_am=a.stride(0), stride_ak=a.stride(1), + stride_bk=b.stride(0), stride_bn=b.stride(1), + stride_cm=c.stride(0), stride_cn=c.stride(1), + BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, + num_warps=NUM_WARPS, + num_ctas=NUM_CTAS, + OUTPUT_F16=OUTPUT_F16) + golden = torch.matmul(a, b) + torch.set_printoptions(profile="full") + assert_close(c, golden, rtol=1e-2, atol=1e-3, check_dtype=False) diff --git a/python/test/unit/hopper/ttgir_tests/test_tma.py b/python/test/unit/hopper/ttgir_tests/test_tma.py new file mode 100644 index 000000000000..d48d2aa42986 --- /dev/null +++ b/python/test/unit/hopper/ttgir_tests/test_tma.py @@ -0,0 +1,70 @@ +# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files +# (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, +# publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import os + +import pytest +import torch +from torch.testing import assert_close + +import triton + + +@pytest.mark.parametrize('TTGIR,TRANS_A,TRANS_B', [ + # TODO: uncomment when it's done + # ["wgmma_tma_64_64_16_f16.ttgir", False, True], +]) +def test_tma_wgmma_64_64_16_f16(TTGIR, TRANS_A, TRANS_B): + capability = torch.cuda.get_device_capability() + if capability[0] < 9: + pytest.skip("Only test wgmma on devices with sm >= 90") + + SIZE_M = 64 + SIZE_N = 64 + SIZE_K = 16 + if (TRANS_A): + a = torch.randn((SIZE_K, SIZE_M), device='cuda', dtype=torch.float16).T + else: + a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16) + + if (TRANS_B): + b = torch.randn((SIZE_N, SIZE_K), device='cuda', dtype=torch.float16).T + else: + b = torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16) + + c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.float32) + + ttgir_path = os.path.dirname(__file__) + "/" + TTGIR + kernel = triton.compile(ttgir_path) + kernel[(1, 1, 1)](a.data_ptr(), b.data_ptr(), c.data_ptr(), + SIZE_M, SIZE_N, SIZE_K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0)) + + golden = torch.matmul(a, b) + torch.set_printoptions(profile="full", sci_mode=False) + assert_close( + c, + golden, + rtol=1e-2, + atol=1e-3, + check_dtype=False) diff --git a/python/test/unit/hopper/ttgir_tests/wgmma_64_64_16_f16_NT.ttgir b/python/test/unit/hopper/ttgir_tests/wgmma_64_64_16_f16_NT.ttgir new file mode 100644 index 000000000000..ec3dfb8f19b1 --- /dev/null +++ b/python/test/unit/hopper/ttgir_tests/wgmma_64_64_16_f16_NT.ttgir @@ -0,0 +1,52 @@ +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mma = #triton_gpu.mma<{versionMajor = 3, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> +#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + func.func public @matmul_kernel_0d1d2d3d4c5d6c7d8c(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma> + %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>> + %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %2 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>) -> tensor<64x1xi32, #blocked0> + %3 = tt.expand_dims %1 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xi32, #blocked1> + %4 = tt.splat %arg3 : (i32) -> tensor<64x1xi32, #blocked0> + %5 = tt.splat %arg0 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked0> + %6 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>> + %7 = tt.expand_dims %6 {axis = 0 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>) -> tensor<1x16xi32, #blocked0> + %8 = tt.broadcast %7 : (tensor<1x16xi32, #blocked0>) -> tensor<64x16xi32, #blocked0> + %9 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %10 = tt.expand_dims %9 {axis = 1 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<16x1xi32, #blocked2> + %11 = tt.splat %arg4 : (i32) -> tensor<16x1xi32, #blocked2> + %12 = tt.splat %arg1 : (!tt.ptr) -> tensor<16x1x!tt.ptr, #blocked2> + %13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %14 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %15 = tt.expand_dims %13 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x64xi32, #blocked2> + %16 = tt.broadcast %15 : (tensor<1x64xi32, #blocked2>) -> tensor<16x64xi32, #blocked2> + %17 = arith.muli %2, %4 : tensor<64x1xi32, #blocked0> + %18 = tt.addptr %5, %17 : tensor<64x1x!tt.ptr, #blocked0>, tensor<64x1xi32, #blocked0> + %19 = tt.broadcast %18 : (tensor<64x1x!tt.ptr, #blocked0>) -> tensor<64x16x!tt.ptr, #blocked0> + %20 = tt.addptr %19, %8 : tensor<64x16x!tt.ptr, #blocked0>, tensor<64x16xi32, #blocked0> + %21 = tt.load %20 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x16xf16, #blocked0> + %22 = arith.muli %10, %11 : tensor<16x1xi32, #blocked2> + %23 = tt.addptr %12, %22 : tensor<16x1x!tt.ptr, #blocked2>, tensor<16x1xi32, #blocked2> + %24 = tt.broadcast %23 : (tensor<16x1x!tt.ptr, #blocked2>) -> tensor<16x64x!tt.ptr, #blocked2> + %25 = tt.addptr %24, %16 : tensor<16x64x!tt.ptr, #blocked2>, tensor<16x64xi32, #blocked2> + %26 = tt.load %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64xf16, #blocked2> + %27 = triton_gpu.convert_layout %21 : (tensor<64x16xf16, #blocked0>) -> tensor<64x16xf16, #shared0> + %28 = triton_gpu.convert_layout %26 : (tensor<16x64xf16, #blocked2>) -> tensor<16x64xf16, #shared1> + %29 = tt.dot %27, %28, %cst {allowTF32 = true, transA = true, transB = true} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> + %30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1> + %31 = tt.splat %arg2 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked1> + %32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1> + %33 = tt.broadcast %32 : (tensor<1x64xi32, #blocked1>) -> tensor<64x64xi32, #blocked1> + %34 = arith.muli %3, %30 : tensor<64x1xi32, #blocked1> + %35 = tt.addptr %31, %34 : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> + %36 = tt.broadcast %35 : (tensor<64x1x!tt.ptr, #blocked1>) -> tensor<64x64x!tt.ptr, #blocked1> + %37 = tt.addptr %36, %33 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> + %38 = triton_gpu.convert_layout %29 : (tensor<64x64xf32, #mma>) -> tensor<64x64xf32, #blocked1> + tt.store %37, %38 : tensor<64x64xf32, #blocked1> + return + } +} diff --git a/python/test/unit/hopper/ttgir_tests/wgmma_64_64_16_f16_TN.ttgir b/python/test/unit/hopper/ttgir_tests/wgmma_64_64_16_f16_TN.ttgir new file mode 100644 index 000000000000..3e3aae12c814 --- /dev/null +++ b/python/test/unit/hopper/ttgir_tests/wgmma_64_64_16_f16_TN.ttgir @@ -0,0 +1,52 @@ +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mma = #triton_gpu.mma<{versionMajor = 3, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> +#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + func.func public @matmul_kernel_0d1d2d3d4c5d6c7d8c(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma> + %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>> + %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %2 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>) -> tensor<64x1xi32, #blocked0> + %3 = tt.expand_dims %1 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xi32, #blocked1> + %4 = tt.splat %arg3 : (i32) -> tensor<64x1xi32, #blocked0> + %5 = tt.splat %arg0 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked0> + %6 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>> + %7 = tt.expand_dims %6 {axis = 0 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>) -> tensor<1x16xi32, #blocked0> + %8 = tt.broadcast %7 : (tensor<1x16xi32, #blocked0>) -> tensor<64x16xi32, #blocked0> + %9 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %10 = tt.expand_dims %9 {axis = 1 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<16x1xi32, #blocked2> + %11 = tt.splat %arg4 : (i32) -> tensor<16x1xi32, #blocked2> + %12 = tt.splat %arg1 : (!tt.ptr) -> tensor<16x1x!tt.ptr, #blocked2> + %13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %14 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %15 = tt.expand_dims %13 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x64xi32, #blocked2> + %16 = tt.broadcast %15 : (tensor<1x64xi32, #blocked2>) -> tensor<16x64xi32, #blocked2> + %17 = arith.muli %2, %4 : tensor<64x1xi32, #blocked0> + %18 = tt.addptr %5, %17 : tensor<64x1x!tt.ptr, #blocked0>, tensor<64x1xi32, #blocked0> + %19 = tt.broadcast %18 : (tensor<64x1x!tt.ptr, #blocked0>) -> tensor<64x16x!tt.ptr, #blocked0> + %20 = tt.addptr %19, %8 : tensor<64x16x!tt.ptr, #blocked0>, tensor<64x16xi32, #blocked0> + %21 = tt.load %20 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x16xf16, #blocked0> + %22 = arith.muli %10, %11 : tensor<16x1xi32, #blocked2> + %23 = tt.addptr %12, %22 : tensor<16x1x!tt.ptr, #blocked2>, tensor<16x1xi32, #blocked2> + %24 = tt.broadcast %23 : (tensor<16x1x!tt.ptr, #blocked2>) -> tensor<16x64x!tt.ptr, #blocked2> + %25 = tt.addptr %24, %16 : tensor<16x64x!tt.ptr, #blocked2>, tensor<16x64xi32, #blocked2> + %26 = tt.load %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64xf16, #blocked2> + %27 = triton_gpu.convert_layout %21 : (tensor<64x16xf16, #blocked0>) -> tensor<64x16xf16, #shared0> + %28 = triton_gpu.convert_layout %26 : (tensor<16x64xf16, #blocked2>) -> tensor<16x64xf16, #shared1> + %29 = tt.dot %27, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> + %30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1> + %31 = tt.splat %arg2 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked1> + %32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1> + %33 = tt.broadcast %32 : (tensor<1x64xi32, #blocked1>) -> tensor<64x64xi32, #blocked1> + %34 = arith.muli %3, %30 : tensor<64x1xi32, #blocked1> + %35 = tt.addptr %31, %34 : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> + %36 = tt.broadcast %35 : (tensor<64x1x!tt.ptr, #blocked1>) -> tensor<64x64x!tt.ptr, #blocked1> + %37 = tt.addptr %36, %33 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> + %38 = triton_gpu.convert_layout %29 : (tensor<64x64xf32, #mma>) -> tensor<64x64xf32, #blocked1> + tt.store %37, %38 : tensor<64x64xf32, #blocked1> + return + } +} diff --git a/python/test/unit/hopper/ttgir_tests/wgmma_a_ldgsts_64_64_16_f16.ttgir b/python/test/unit/hopper/ttgir_tests/wgmma_a_ldgsts_64_64_16_f16.ttgir new file mode 100644 index 000000000000..b58dbd9212e9 --- /dev/null +++ b/python/test/unit/hopper/ttgir_tests/wgmma_a_ldgsts_64_64_16_f16.ttgir @@ -0,0 +1,59 @@ +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mma = #triton_gpu.mma<{versionMajor = 3, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> +#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + func.func public @matmul_kernel_0d1d2d3d4c5d6c7d8c(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma> + %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>> + %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %2 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>) -> tensor<64x1xi32, #blocked0> + %3 = tt.expand_dims %1 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xi32, #blocked1> + %4 = tt.splat %arg3 : (i32) -> tensor<64x1xi32, #blocked0> + %5 = tt.splat %arg0 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked0> + %6 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>> + %7 = tt.expand_dims %6 {axis = 0 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>) -> tensor<1x16xi32, #blocked0> + %8 = tt.broadcast %7 : (tensor<1x16xi32, #blocked0>) -> tensor<64x16xi32, #blocked0> + %9 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %10 = tt.expand_dims %9 {axis = 1 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<16x1xi32, #blocked2> + %11 = tt.splat %arg4 : (i32) -> tensor<16x1xi32, #blocked2> + %12 = tt.splat %arg1 : (!tt.ptr) -> tensor<16x1x!tt.ptr, #blocked2> + %13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %14 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %15 = tt.expand_dims %13 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x64xi32, #blocked2> + %16 = tt.broadcast %15 : (tensor<1x64xi32, #blocked2>) -> tensor<16x64xi32, #blocked2> + %17 = arith.muli %2, %4 : tensor<64x1xi32, #blocked0> + %18 = tt.addptr %5, %17 : tensor<64x1x!tt.ptr, #blocked0>, tensor<64x1xi32, #blocked0> + %19 = tt.broadcast %18 : (tensor<64x1x!tt.ptr, #blocked0>) -> tensor<64x16x!tt.ptr, #blocked0> + %20 = tt.addptr %19, %8 : tensor<64x16x!tt.ptr, #blocked0>, tensor<64x16xi32, #blocked0> + %ci0 = arith.constant 0 : i32 + %i1_true = arith.constant 1 : i1 + %i1_false = arith.constant 0 : i1 + %t = triton_gpu.alloc_tensor : tensor<1x64x16xf16, #shared0> + %mask0 = tt.splat %i1_true : (i1) -> tensor<64x16xi1, #blocked0> + %t0 = triton_gpu.insert_slice_async %20, %t, %ci0, %mask0 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, type = 1 : i32} : tensor<64x16x!tt.ptr, #blocked0> -> tensor<1x64x16xf16, #shared0> + triton_gpu.async_commit_group + triton_gpu.async_wait {num = 1 : i32} + %21 = triton_gpu.extract_slice %t0[%ci0, 0, 0][1, 64, 16][1, 1, 1] : tensor<1x64x16xf16, #shared0> to tensor<64x16xf16, #shared0> + %22 = arith.muli %10, %11 : tensor<16x1xi32, #blocked2> + %23 = tt.addptr %12, %22 : tensor<16x1x!tt.ptr, #blocked2>, tensor<16x1xi32, #blocked2> + %24 = tt.broadcast %23 : (tensor<16x1x!tt.ptr, #blocked2>) -> tensor<16x64x!tt.ptr, #blocked2> + %25 = tt.addptr %24, %16 : tensor<16x64x!tt.ptr, #blocked2>, tensor<16x64xi32, #blocked2> + %26 = tt.load %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64xf16, #blocked2> + %28 = triton_gpu.convert_layout %26 : (tensor<16x64xf16, #blocked2>) -> tensor<16x64xf16, #shared1> + %29 = tt.dot %21, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> + %30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1> + %31 = tt.splat %arg2 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked1> + %32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1> + %33 = tt.broadcast %32 : (tensor<1x64xi32, #blocked1>) -> tensor<64x64xi32, #blocked1> + %34 = arith.muli %3, %30 : tensor<64x1xi32, #blocked1> + %35 = tt.addptr %31, %34 : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> + %36 = tt.broadcast %35 : (tensor<64x1x!tt.ptr, #blocked1>) -> tensor<64x64x!tt.ptr, #blocked1> + %37 = tt.addptr %36, %33 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> + %38 = triton_gpu.convert_layout %29 : (tensor<64x64xf32, #mma>) -> tensor<64x64xf32, #blocked1> + tt.store %37, %38 : tensor<64x64xf32, #blocked1> + return + } +} diff --git a/python/test/unit/hopper/ttgir_tests/wgmma_a_ldgsts_mbarrier_64_64_16_f16.ttgir b/python/test/unit/hopper/ttgir_tests/wgmma_a_ldgsts_mbarrier_64_64_16_f16.ttgir new file mode 100644 index 000000000000..d44d7f879b45 --- /dev/null +++ b/python/test/unit/hopper/ttgir_tests/wgmma_a_ldgsts_mbarrier_64_64_16_f16.ttgir @@ -0,0 +1,63 @@ +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mma = #triton_gpu.mma<{versionMajor = 3, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> +#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + func.func public @matmul_kernel_0d1d2d3d4c5d6c7d8c(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma> + %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>> + %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %2 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>) -> tensor<64x1xi32, #blocked0> + %3 = tt.expand_dims %1 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xi32, #blocked1> + %4 = tt.splat %arg3 : (i32) -> tensor<64x1xi32, #blocked0> + %5 = tt.splat %arg0 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked0> + %6 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>> + %7 = tt.expand_dims %6 {axis = 0 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>) -> tensor<1x16xi32, #blocked0> + %8 = tt.broadcast %7 : (tensor<1x16xi32, #blocked0>) -> tensor<64x16xi32, #blocked0> + %9 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %10 = tt.expand_dims %9 {axis = 1 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<16x1xi32, #blocked2> + %11 = tt.splat %arg4 : (i32) -> tensor<16x1xi32, #blocked2> + %12 = tt.splat %arg1 : (!tt.ptr) -> tensor<16x1x!tt.ptr, #blocked2> + %13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %14 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %15 = tt.expand_dims %13 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x64xi32, #blocked2> + %16 = tt.broadcast %15 : (tensor<1x64xi32, #blocked2>) -> tensor<16x64xi32, #blocked2> + %17 = arith.muli %2, %4 : tensor<64x1xi32, #blocked0> + %18 = tt.addptr %5, %17 : tensor<64x1x!tt.ptr, #blocked0>, tensor<64x1xi32, #blocked0> + %19 = tt.broadcast %18 : (tensor<64x1x!tt.ptr, #blocked0>) -> tensor<64x16x!tt.ptr, #blocked0> + %20 = tt.addptr %19, %8 : tensor<64x16x!tt.ptr, #blocked0>, tensor<64x16xi32, #blocked0> + %ci0 = arith.constant 0 : i32 + %i1_true = arith.constant 1 : i1 + %i1_false = arith.constant 0 : i1 + %t = triton_gpu.alloc_tensor : tensor<1x64x16xf16, #shared0> + // TODO: even an empty init external call here will break the UT + %mbar = triton_nvidia_gpu.alloc_mbarrier { count = 128 : i32 } : tensor<1xi64, #shared0> + // %mbar0 = triton_nvidia_gpu.extract_mbarrier %mbar[%ci0] : tensor<1xi64, #shared0>, i32 -> !tt.ptr + %mask0 = tt.splat %i1_true : (i1) -> tensor<64x16xi1, #blocked0> + %t0 = triton_gpu.insert_slice_async %20, %t, %ci0, %mask0 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, type = 1 : i32} : tensor<64x16x!tt.ptr, #blocked0> -> tensor<1x64x16xf16, #shared0> + triton_gpu.async_wait {num = 1 : i32} + // triton_nvidia_gpu.mbarrier_arrive %mbar0 {trackAsyncOp = true} : !tt.ptr + // triton_nvidia_gpu.mbarrier_wait %mbar0, %i1_false : !tt.ptr + %21 = triton_gpu.extract_slice %t0[%ci0, 0, 0][1, 64, 16][1, 1, 1] : tensor<1x64x16xf16, #shared0> to tensor<64x16xf16, #shared0> + %22 = arith.muli %10, %11 : tensor<16x1xi32, #blocked2> + %23 = tt.addptr %12, %22 : tensor<16x1x!tt.ptr, #blocked2>, tensor<16x1xi32, #blocked2> + %24 = tt.broadcast %23 : (tensor<16x1x!tt.ptr, #blocked2>) -> tensor<16x64x!tt.ptr, #blocked2> + %25 = tt.addptr %24, %16 : tensor<16x64x!tt.ptr, #blocked2>, tensor<16x64xi32, #blocked2> + %26 = tt.load %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64xf16, #blocked2> + %28 = triton_gpu.convert_layout %26 : (tensor<16x64xf16, #blocked2>) -> tensor<16x64xf16, #shared1> + %29 = tt.dot %21, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> + %30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1> + %31 = tt.splat %arg2 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked1> + %32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1> + %33 = tt.broadcast %32 : (tensor<1x64xi32, #blocked1>) -> tensor<64x64xi32, #blocked1> + %34 = arith.muli %3, %30 : tensor<64x1xi32, #blocked1> + %35 = tt.addptr %31, %34 : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> + %36 = tt.broadcast %35 : (tensor<64x1x!tt.ptr, #blocked1>) -> tensor<64x64x!tt.ptr, #blocked1> + %37 = tt.addptr %36, %33 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> + %38 = triton_gpu.convert_layout %29 : (tensor<64x64xf32, #mma>) -> tensor<64x64xf32, #blocked1> + tt.store %37, %38 : tensor<64x64xf32, #blocked1> + return + } +} diff --git a/python/test/unit/hopper/ttgir_tests/wgmma_ldgsts_64_64_16_f16.ttgir b/python/test/unit/hopper/ttgir_tests/wgmma_ldgsts_64_64_16_f16.ttgir new file mode 100644 index 000000000000..3ee491653344 --- /dev/null +++ b/python/test/unit/hopper/ttgir_tests/wgmma_ldgsts_64_64_16_f16.ttgir @@ -0,0 +1,63 @@ +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mma = #triton_gpu.mma<{versionMajor = 3, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> +#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + func.func public @matmul_kernel_0d1d2d3d4c5d6c7d8c(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma> + %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>> + %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %2 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>) -> tensor<64x1xi32, #blocked0> + %3 = tt.expand_dims %1 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xi32, #blocked1> + %4 = tt.splat %arg3 : (i32) -> tensor<64x1xi32, #blocked0> + %5 = tt.splat %arg0 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked0> + %6 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>> + %7 = tt.expand_dims %6 {axis = 0 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>) -> tensor<1x16xi32, #blocked0> + %8 = tt.broadcast %7 : (tensor<1x16xi32, #blocked0>) -> tensor<64x16xi32, #blocked0> + %9 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %10 = tt.expand_dims %9 {axis = 1 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<16x1xi32, #blocked2> + %11 = tt.splat %arg4 : (i32) -> tensor<1x64xi32, #blocked2> + %12 = tt.splat %arg1 : (!tt.ptr) -> tensor<16x1x!tt.ptr, #blocked2> + %13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %14 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %15 = tt.expand_dims %13 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x64xi32, #blocked2> + %bstride1 = arith.muli %11, %15 : tensor<1x64xi32, #blocked2> + %16 = tt.broadcast %bstride1 : (tensor<1x64xi32, #blocked2>) -> tensor<16x64xi32, #blocked2> + %17 = arith.muli %2, %4 : tensor<64x1xi32, #blocked0> + %18 = tt.addptr %5, %17 : tensor<64x1x!tt.ptr, #blocked0>, tensor<64x1xi32, #blocked0> + %19 = tt.broadcast %18 : (tensor<64x1x!tt.ptr, #blocked0>) -> tensor<64x16x!tt.ptr, #blocked0> + %20 = tt.addptr %19, %8 : tensor<64x16x!tt.ptr, #blocked0>, tensor<64x16xi32, #blocked0> + %ci0 = arith.constant 0 : i32 + %i1_true = arith.constant 1 : i1 + %i1_false = arith.constant 0 : i1 + %at = triton_gpu.alloc_tensor : tensor<1x64x16xf16, #shared0> + %maska = tt.splat %i1_true : (i1) -> tensor<64x16xi1, #blocked0> + %at0 = triton_gpu.insert_slice_async %20, %at, %ci0, %maska {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, type = 1 : i32} : tensor<64x16x!tt.ptr, #blocked0> -> tensor<1x64x16xf16, #shared0> + triton_gpu.async_commit_group + triton_gpu.async_wait {num = 1 : i32} + %21 = triton_gpu.extract_slice %at0[%ci0, 0, 0][1, 64, 16][1, 1, 1] : tensor<1x64x16xf16, #shared0> to tensor<64x16xf16, #shared0> + %23 = tt.addptr %12, %10 : tensor<16x1x!tt.ptr, #blocked2>, tensor<16x1xi32, #blocked2> + %24 = tt.broadcast %23 : (tensor<16x1x!tt.ptr, #blocked2>) -> tensor<16x64x!tt.ptr, #blocked2> + %25 = tt.addptr %24, %16 : tensor<16x64x!tt.ptr, #blocked2>, tensor<16x64xi32, #blocked2> + %bt = triton_gpu.alloc_tensor : tensor<1x16x64xf16, #shared1> + %maskb = tt.splat %i1_true : (i1) -> tensor<16x64xi1, #blocked2> + %bt0 = triton_gpu.insert_slice_async %25, %bt, %ci0, %maskb {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, type = 1 : i32} : tensor<16x64x!tt.ptr, #blocked2> -> tensor<1x16x64xf16, #shared1> + triton_gpu.async_commit_group + triton_gpu.async_wait {num = 1 : i32} + %28 = triton_gpu.extract_slice %bt0[%ci0, 0, 0][1, 16, 64][1, 1, 1] : tensor<1x16x64xf16, #shared1> to tensor<16x64xf16, #shared1> + %29 = tt.dot %21, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> + %30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1> + %31 = tt.splat %arg2 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked1> + %32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1> + %33 = tt.broadcast %32 : (tensor<1x64xi32, #blocked1>) -> tensor<64x64xi32, #blocked1> + %34 = arith.muli %3, %30 : tensor<64x1xi32, #blocked1> + %35 = tt.addptr %31, %34 : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> + %36 = tt.broadcast %35 : (tensor<64x1x!tt.ptr, #blocked1>) -> tensor<64x64x!tt.ptr, #blocked1> + %37 = tt.addptr %36, %33 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> + %38 = triton_gpu.convert_layout %29 : (tensor<64x64xf32, #mma>) -> tensor<64x64xf32, #blocked1> + tt.store %37, %38 : tensor<64x64xf32, #blocked1> + return + } +} diff --git a/python/test/unit/hopper/ttgir_tests/wgmma_ldgsts_mbarrier_64_64_16_f16.ttgir b/python/test/unit/hopper/ttgir_tests/wgmma_ldgsts_mbarrier_64_64_16_f16.ttgir new file mode 100644 index 000000000000..3f175c2c414c --- /dev/null +++ b/python/test/unit/hopper/ttgir_tests/wgmma_ldgsts_mbarrier_64_64_16_f16.ttgir @@ -0,0 +1,67 @@ +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mma = #triton_gpu.mma<{versionMajor = 3, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> +#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + func.func public @matmul_kernel_0d1d2d3d4c5d6c7d8c(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma> + %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>> + %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %2 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>) -> tensor<64x1xi32, #blocked0> + %3 = tt.expand_dims %1 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xi32, #blocked1> + %4 = tt.splat %arg3 : (i32) -> tensor<64x1xi32, #blocked0> + %5 = tt.splat %arg0 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked0> + %6 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>> + %7 = tt.expand_dims %6 {axis = 0 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>) -> tensor<1x16xi32, #blocked0> + %8 = tt.broadcast %7 : (tensor<1x16xi32, #blocked0>) -> tensor<64x16xi32, #blocked0> + %9 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %10 = tt.expand_dims %9 {axis = 1 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<16x1xi32, #blocked2> + %11 = tt.splat %arg4 : (i32) -> tensor<1x64xi32, #blocked2> + %12 = tt.splat %arg1 : (!tt.ptr) -> tensor<16x1x!tt.ptr, #blocked2> + %13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %14 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %15 = tt.expand_dims %13 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x64xi32, #blocked2> + %bstride1 = arith.muli %11, %15 : tensor<1x64xi32, #blocked2> + %16 = tt.broadcast %bstride1 : (tensor<1x64xi32, #blocked2>) -> tensor<16x64xi32, #blocked2> + %17 = arith.muli %2, %4 : tensor<64x1xi32, #blocked0> + %18 = tt.addptr %5, %17 : tensor<64x1x!tt.ptr, #blocked0>, tensor<64x1xi32, #blocked0> + %19 = tt.broadcast %18 : (tensor<64x1x!tt.ptr, #blocked0>) -> tensor<64x16x!tt.ptr, #blocked0> + %20 = tt.addptr %19, %8 : tensor<64x16x!tt.ptr, #blocked0>, tensor<64x16xi32, #blocked0> + %ci0 = arith.constant 0 : i32 + %i1_true = arith.constant 1 : i1 + %i1_false = arith.constant 0 : i1 + %at = triton_gpu.alloc_tensor : tensor<1x64x16xf16, #shared0> + %maska = tt.splat %i1_true : (i1) -> tensor<64x16xi1, #blocked0> + %mbar0 = triton_nvidia_gpu.alloc_mbarrier { count = 128 : i32 } : !tt.ptr + %at0 = triton_gpu.insert_slice_async %20, %at, %ci0, %maska {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, type = 1 : i32} : tensor<64x16x!tt.ptr, #blocked0> -> tensor<1x64x16xf16, #shared0> + triton_nvidia_gpu.mbarrier_arrive %mbar0 {trackAsyncOp = true} : !tt.ptr + triton_nvidia_gpu.mbarrier_wait %mbar0, %i1_false : !tt.ptr + // triton_gpu.async_wait {num = 1 : i32} + %21 = triton_gpu.extract_slice %at0[%ci0, 0, 0][1, 64, 16][1, 1, 1] : tensor<1x64x16xf16, #shared0> to tensor<64x16xf16, #shared0> + %23 = tt.addptr %12, %10 : tensor<16x1x!tt.ptr, #blocked2>, tensor<16x1xi32, #blocked2> + %24 = tt.broadcast %23 : (tensor<16x1x!tt.ptr, #blocked2>) -> tensor<16x64x!tt.ptr, #blocked2> + %25 = tt.addptr %24, %16 : tensor<16x64x!tt.ptr, #blocked2>, tensor<16x64xi32, #blocked2> + %bt = triton_gpu.alloc_tensor : tensor<1x16x64xf16, #shared1> + %maskb = tt.splat %i1_true : (i1) -> tensor<16x64xi1, #blocked2> + %mbar1 = triton_nvidia_gpu.alloc_mbarrier { count = 128 : i32 } : !tt.ptr + %bt0 = triton_gpu.insert_slice_async %25, %bt, %ci0, %maskb {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, type = 1 : i32} : tensor<16x64x!tt.ptr, #blocked2> -> tensor<1x16x64xf16, #shared1> + triton_nvidia_gpu.mbarrier_arrive %mbar1 {trackAsyncOp = true} : !tt.ptr + triton_nvidia_gpu.mbarrier_wait %mbar1, %i1_false : !tt.ptr + // triton_gpu.async_wait {num = 1 : i32} + %28 = triton_gpu.extract_slice %bt0[%ci0, 0, 0][1, 16, 64][1, 1, 1] : tensor<1x16x64xf16, #shared1> to tensor<16x64xf16, #shared1> + %29 = tt.dot %21, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> + %30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1> + %31 = tt.splat %arg2 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked1> + %32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1> + %33 = tt.broadcast %32 : (tensor<1x64xi32, #blocked1>) -> tensor<64x64xi32, #blocked1> + %34 = arith.muli %3, %30 : tensor<64x1xi32, #blocked1> + %35 = tt.addptr %31, %34 : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> + %36 = tt.broadcast %35 : (tensor<64x1x!tt.ptr, #blocked1>) -> tensor<64x64x!tt.ptr, #blocked1> + %37 = tt.addptr %36, %33 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> + %38 = triton_gpu.convert_layout %29 : (tensor<64x64xf32, #mma>) -> tensor<64x64xf32, #blocked1> + tt.store %37, %38 : tensor<64x64xf32, #blocked1> + return + } +} diff --git a/python/test/unit/hopper/ttgir_tests/wgmma_ldgsts_mbarrier_vec_64_64_16_f16.ttgir b/python/test/unit/hopper/ttgir_tests/wgmma_ldgsts_mbarrier_vec_64_64_16_f16.ttgir new file mode 100644 index 000000000000..ee4909a9ca48 --- /dev/null +++ b/python/test/unit/hopper/ttgir_tests/wgmma_ldgsts_mbarrier_vec_64_64_16_f16.ttgir @@ -0,0 +1,69 @@ +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mma = #triton_gpu.mma<{versionMajor = 3, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> +#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + func.func public @matmul_kernel_0d1d2d3d4c5d6c7d8c(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma> + %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>> + %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %2 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>) -> tensor<64x1xi32, #blocked0> + %3 = tt.expand_dims %1 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xi32, #blocked1> + %4 = tt.splat %arg3 : (i32) -> tensor<64x1xi32, #blocked0> + %5 = tt.splat %arg0 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked0> + %6 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>> + %7 = tt.expand_dims %6 {axis = 0 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>) -> tensor<1x16xi32, #blocked0> + %8 = tt.broadcast %7 : (tensor<1x16xi32, #blocked0>) -> tensor<64x16xi32, #blocked0> + %9 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %10 = tt.expand_dims %9 {axis = 1 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<16x1xi32, #blocked2> + %11 = tt.splat %arg4 : (i32) -> tensor<1x64xi32, #blocked2> + %12 = tt.splat %arg1 : (!tt.ptr) -> tensor<16x1x!tt.ptr, #blocked2> + %13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %14 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %15 = tt.expand_dims %13 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x64xi32, #blocked2> + %bstride1 = arith.muli %11, %15 : tensor<1x64xi32, #blocked2> + %16 = tt.broadcast %bstride1 : (tensor<1x64xi32, #blocked2>) -> tensor<16x64xi32, #blocked2> + %17 = arith.muli %2, %4 : tensor<64x1xi32, #blocked0> + %18 = tt.addptr %5, %17 : tensor<64x1x!tt.ptr, #blocked0>, tensor<64x1xi32, #blocked0> + %19 = tt.broadcast %18 : (tensor<64x1x!tt.ptr, #blocked0>) -> tensor<64x16x!tt.ptr, #blocked0> + %20 = tt.addptr %19, %8 : tensor<64x16x!tt.ptr, #blocked0>, tensor<64x16xi32, #blocked0> + %ci0 = arith.constant 0 : i32 + %ci1 = arith.constant 1 : i32 + %i1_true = arith.constant 1 : i1 + %i1_false = arith.constant 0 : i1 + %at = triton_gpu.alloc_tensor : tensor<1x64x16xf16, #shared0> + %maska = tt.splat %i1_true : (i1) -> tensor<64x16xi1, #blocked0> + %mbar = triton_nvidia_gpu.alloc_mbarrier { count = 128 : i32 } : tensor<2xi64, #shared0> + %mbar0_s = triton_nvidia_gpu.extract_mbarrier %mbar[%ci0] : tensor<2xi64, #shared0>, i32 -> !tt.ptr + %at0 = triton_gpu.insert_slice_async %20, %at, %ci0, %maska {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, type = 1 : i32} : tensor<64x16x!tt.ptr, #blocked0> -> tensor<1x64x16xf16, #shared0> + triton_nvidia_gpu.mbarrier_arrive %mbar0_s {trackAsyncOp = true} : !tt.ptr + triton_nvidia_gpu.mbarrier_wait %mbar0_s, %i1_false : !tt.ptr + // triton_gpu.async_wait {num = 1 : i32} + %21 = triton_gpu.extract_slice %at0[%ci0, 0, 0][1, 64, 16][1, 1, 1] : tensor<1x64x16xf16, #shared0> to tensor<64x16xf16, #shared0> + %23 = tt.addptr %12, %10 : tensor<16x1x!tt.ptr, #blocked2>, tensor<16x1xi32, #blocked2> + %24 = tt.broadcast %23 : (tensor<16x1x!tt.ptr, #blocked2>) -> tensor<16x64x!tt.ptr, #blocked2> + %25 = tt.addptr %24, %16 : tensor<16x64x!tt.ptr, #blocked2>, tensor<16x64xi32, #blocked2> + %bt = triton_gpu.alloc_tensor : tensor<1x16x64xf16, #shared1> + %maskb = tt.splat %i1_true : (i1) -> tensor<16x64xi1, #blocked2> + %mbar1_s = triton_nvidia_gpu.extract_mbarrier %mbar[%ci1] : tensor<2xi64, #shared0>, i32 -> !tt.ptr + %bt0 = triton_gpu.insert_slice_async %25, %bt, %ci0, %maskb {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, type = 1 : i32} : tensor<16x64x!tt.ptr, #blocked2> -> tensor<1x16x64xf16, #shared1> + triton_nvidia_gpu.mbarrier_arrive %mbar1_s {trackAsyncOp = true} : !tt.ptr + triton_nvidia_gpu.mbarrier_wait %mbar1_s, %i1_false : !tt.ptr + // triton_gpu.async_wait {num = 1 : i32} + %28 = triton_gpu.extract_slice %bt0[%ci0, 0, 0][1, 16, 64][1, 1, 1] : tensor<1x16x64xf16, #shared1> to tensor<16x64xf16, #shared1> + %29 = tt.dot %21, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> + %30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1> + %31 = tt.splat %arg2 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked1> + %32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1> + %33 = tt.broadcast %32 : (tensor<1x64xi32, #blocked1>) -> tensor<64x64xi32, #blocked1> + %34 = arith.muli %3, %30 : tensor<64x1xi32, #blocked1> + %35 = tt.addptr %31, %34 : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> + %36 = tt.broadcast %35 : (tensor<64x1x!tt.ptr, #blocked1>) -> tensor<64x64x!tt.ptr, #blocked1> + %37 = tt.addptr %36, %33 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> + %38 = triton_gpu.convert_layout %29 : (tensor<64x64xf32, #mma>) -> tensor<64x64xf32, #blocked1> + tt.store %37, %38 : tensor<64x64xf32, #blocked1> + return + } +} diff --git a/python/test/unit/hopper/ttgir_tests/wgmma_tma_64_64_16_f16.ttgir b/python/test/unit/hopper/ttgir_tests/wgmma_tma_64_64_16_f16.ttgir new file mode 100644 index 000000000000..b05573e17b4e --- /dev/null +++ b/python/test/unit/hopper/ttgir_tests/wgmma_tma_64_64_16_f16.ttgir @@ -0,0 +1,64 @@ +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], CTAsPerCGA = [1, 1], order = [0, 1], CTAOrder = [1, 0]}> +#mma = #triton_gpu.mma<{versionMajor = 3, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> +#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + func.func public @matmul_kernel_0d1d2d3d4c5d6c7d8c(%aBasePtr : !tt.ptr {tt.divisibility = 16 : i32}, + %bBasePtr : !tt.ptr {tt.divisibility = 16 : i32}, + %cBasePtr : !tt.ptr {tt.divisibility = 16 : i32}, + %sizeM : i32 {tt.divisibility = 16 : i32}, + %sizeN : i32 {tt.divisibility = 16 : i32}, + %sizeK : i32 {tt.divisibility = 16 : i32}, + %aStride0 : i32 {tt.divisibility = 16 : i32}, + %aStride1 : i32 {tt.divisibility = 16 : i32}, + %bStride0 : i32 {tt.divisibility = 16 : i32}, + %bStride1 : i32 {tt.divisibility = 16 : i32}, + %cStride0 : i32 {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma> + %ci0 = arith.constant 0 : i32 + %ci1 = arith.constant 1 : i32 + %i1_true = arith.constant 1 : i1 + %aCoord0 = arith.constant 0 : i32 + %aCoord1 = arith.constant 0 : i32 + %bCoord0 = arith.constant 0 : i32 + %bCoord1 = arith.constant 0 : i32 + %mbar = triton_nvidia_gpu.alloc_mbarrier { count = 128 : i32, txCount = 2048 : i32 } : tensor<2xi64, #shared0> + %mbar_a = triton_nvidia_gpu.extract_mbarrier %mbar[%ci0] : tensor<2xi64, #shared0>, i32 -> !tt.ptr + %mbar_b = triton_nvidia_gpu.extract_mbarrier %mbar[%ci1] : tensor<2xi64, #shared0>, i32 -> !tt.ptr + %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %3 = tt.expand_dims %1 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xi32, #blocked1> + %14 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + + // Load A + %a_smem = triton_gpu.alloc_tensor : tensor<1x64x16xf16, #shared0> + %a_smem_loaded = triton_gpu.load_tile_async %aBasePtr[%sizeM, %sizeK][%aStride0, %aStride1][%aCoord0, %aCoord1], %mbar_a, %a_smem[%ci0] {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, order = dense<[1, 0]> : tensor<2xi32>} : !tt.ptr -> tensor<1x64x16xf16, #shared0> + triton_nvidia_gpu.mbarrier_arrive %mbar_a {trackAsyncOp = false} : !tt.ptr + triton_nvidia_gpu.mbarrier_wait %mbar_a, %i1_true : !tt.ptr + %21 = triton_gpu.extract_slice %a_smem_loaded[%ci0, 0, 0][1, 64, 16][1, 1, 1] : tensor<1x64x16xf16, #shared0> to tensor<64x16xf16, #shared0> + + // Load B + %b_smem = triton_gpu.alloc_tensor : tensor<1x16x64xf16, #shared1> + %b_smem_loaded= triton_gpu.load_tile_async %bBasePtr[%sizeK, %sizeN][%bStride0, %bStride1][%bCoord0, %bCoord1], %mbar_b, %b_smem[%ci0] {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, order = dense<[0, 1]> : tensor<2xi32>} : !tt.ptr -> tensor<1x16x64xf16, #shared1> + triton_nvidia_gpu.mbarrier_arrive %mbar_b {trackAsyncOp = false} : !tt.ptr + triton_nvidia_gpu.mbarrier_wait %mbar_b, %i1_true : !tt.ptr + %28 = triton_gpu.extract_slice %b_smem_loaded[%ci0, 0, 0][1, 16, 64][1, 1, 1] : tensor<1x16x64xf16, #shared1> to tensor<16x64xf16, #shared1> + + // Calling MMA + %29 = tt.dot %21, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> + + // Epilogue + %30 = tt.splat %cStride0: (i32) -> tensor<64x1xi32, #blocked1> + %31 = tt.splat %cBasePtr: (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked1> + %32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1> + %33 = tt.broadcast %32 : (tensor<1x64xi32, #blocked1>) -> tensor<64x64xi32, #blocked1> + %34 = arith.muli %3, %30 : tensor<64x1xi32, #blocked1> + %35 = tt.addptr %31, %34 : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> + %36 = tt.broadcast %35 : (tensor<64x1x!tt.ptr, #blocked1>) -> tensor<64x64x!tt.ptr, #blocked1> + %37 = tt.addptr %36, %33 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> + %38 = triton_gpu.convert_layout %29 : (tensor<64x64xf32, #mma>) -> tensor<64x64xf32, #blocked1> + tt.store %37, %38 : tensor<64x64xf32, #blocked1> + return + } +} diff --git a/python/test/unit/language/print_helper.py b/python/test/unit/language/print_helper.py index afdd12960737..feb0d219d7f2 100644 --- a/python/test/unit/language/print_helper.py +++ b/python/test/unit/language/print_helper.py @@ -28,6 +28,11 @@ def kernel_static_print(X, Y, BLOCK: tl.constexpr): tl.store(Y + tl.arange(0, BLOCK), x) +@triton.jit +def kernel_no_arg_print(): + print("", tl.program_id(0)) + + def test_print(func: str, data_type: str): shape = (128, ) # limit the range of integers so that the sum does not overflow @@ -39,7 +44,11 @@ def test_print(func: str, data_type: str): kernel_print[(1,)](x, y, BLOCK=shape[0]) elif func == "static_print": kernel_static_print[(1,)](x, y, BLOCK=shape[0]) - assert_close(y, x) + elif func == "no_arg_print": + kernel_no_arg_print[(1,)](num_warps=4) + + if func != "no_arg_print": + assert_close(y, x) if __name__ == "__main__": diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 8d73c02638df..da329be6267b 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -21,6 +21,10 @@ dtypes_with_bfloat16 = dtypes + ['bfloat16'] torch_dtypes = ['bool'] + int_dtypes + ['uint8'] + float_dtypes + ['bfloat16'] +# TODO: enable multiple cta cluster testing. +# num_ctas_list = [1, 4] if torch.cuda.get_device_capability()[0] == 9 else [1] +num_ctas_list = [1] + def _bitwidth(dtype: str) -> int: # ex.: "int64" -> 64 @@ -42,7 +46,7 @@ def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, h high = iinfo.max if high is None else min(high, iinfo.max) dtype = getattr(np, dtype_str) x = rs.randint(low, high, shape, dtype=dtype) - x[x == 0] = 1 # Hack. Never return zero so tests of division don't error out. + x[x == 0] = 1 # Workaround. Never return zero so tests of division don't error out. return x elif dtype_str and 'float8' in dtype_str: x = rs.randint(20, 40, shape, dtype=np.int8) @@ -119,37 +123,49 @@ def check_type_supported(dtype, device): cc = torch.cuda.get_device_capability() if cc[0] < 8 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16): pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80") + if cc[0] < 9 and (dtype is tl.float8e4nv or dtype == "float8e4"): + pytest.skip("float8e4 is only supported on NVGPU with cc >= 90") class MmaLayout: - def __init__(self, version, warps_per_cta): + def __init__(self, version, warps_per_cta, ctas_per_cga, cta_split_num, cta_order, instr_shape): self.version = version - self.warps_per_cta = warps_per_cta + self.warps_per_cta = str(warps_per_cta) + self.ctas_per_cga = str(ctas_per_cga) + self.cta_split_num = str(cta_split_num) + self.cta_order = str(cta_order) + self.instr_shape = str(instr_shape) def __str__(self): - return f"#triton_gpu.mma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA={self.warps_per_cta}}}>" + return f"#triton_gpu.mma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA={self.warps_per_cta}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}, instrShape={self.instr_shape}}}>" class BlockedLayout: - def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order): - self.sz_per_thread = size_per_thread - self.threads_per_warp = threads_per_warp - self.warps_per_cta = warps_per_cta - self.order = order + def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga, cta_split_num, cta_order): + self.sz_per_thread = str(size_per_thread) + self.threads_per_warp = str(threads_per_warp) + self.warps_per_cta = str(warps_per_cta) + self.order = str(order) + self.ctas_per_cga = str(ctas_per_cga) + self.cta_split_num = str(cta_split_num) + self.cta_order = str(cta_order) def __str__(self): - return f"#triton_gpu.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}}}>" + return f"#triton_gpu.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" class SharedLayout: - def __init__(self, vec, per_phase, max_phase, order): + def __init__(self, vec, per_phase, max_phase, order, ctas_per_cga, cta_split_num, cta_order): self.vec = str(vec) self.per_phase = str(per_phase) self.max_phase = str(max_phase) self.order = str(order) + self.ctas_per_cga = str(ctas_per_cga) + self.cta_split_num = str(cta_split_num) + self.cta_order = str(cta_order) def __str__(self): - return f"#triton_gpu.shared<{{vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}}}>" + return f"#triton_gpu.shared<{{vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" @pytest.mark.parametrize("dtype_x", list(dtypes) + ["bfloat16"]) @@ -165,7 +181,7 @@ def kernel(X, SIZE: tl.constexpr): # generic test functions -def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'): +def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda', num_ctas=1): check_type_supported(dtype_x, device) # early return if dtype_x is not supported SIZE = 128 # define the kernel / launch-grid @@ -187,7 +203,7 @@ def kernel(Z, X, SIZE: tl.constexpr): # triton result x_tri = to_triton(x, device=device, dst_type=dtype_x) z_tri = to_triton(np.empty_like(z_ref), device=device, dst_type=dtype_x) - kernel[(1, )](z_tri, x_tri, SIZE=SIZE, num_warps=4) + kernel[(1, )](z_tri, x_tri, SIZE=SIZE, num_warps=4, num_ctas=num_ctas) # compare np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) @@ -223,7 +239,7 @@ def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]: return overrides.get(key) -def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', y_low=None, y_high=None): +def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', num_ctas=1, y_low=None, y_high=None): check_type_supported(dtype_x, device) # early return if dtype_x is not supported check_type_supported(dtype_y, device) SIZE = 128 @@ -255,7 +271,8 @@ def kernel(Z, X, Y, SIZE: tl.constexpr): x_tri = to_triton(x, device=device, dst_type=dtype_x) y_tri = to_triton(y, device=device, dst_type=dtype_y) z_tri = to_triton(np.empty(SIZE, dtype=z_ref.dtype), device=device) - kernel[(1, )](z_tri, x_tri, y_tri, SIZE=SIZE, num_warps=4) + kernel[(1, )](z_tri, x_tri, y_tri, SIZE=SIZE, + num_warps=4, num_ctas=num_ctas) np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=expr, rtol=0.01) @@ -295,7 +312,8 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool: for dtype_x in dtypes_with_bfloat16 for dtype_y in dtypes_with_bfloat16 ]) -def test_bin_op(dtype_x, dtype_y, op, device): +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_bin_op(dtype_x, dtype_y, op, num_ctas, device): expr = f' x {op} y' if op == '%' and dtype_x in int_dtypes + uint_dtypes and dtype_y in int_dtypes + uint_dtypes: # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders. @@ -313,28 +331,107 @@ def test_bin_op(dtype_x, dtype_y, op, device): numpy_expr = None if op == '%' and _mod_operation_ill_conditioned(dtype_x, dtype_y): with pytest.raises(AssertionError, match='Not equal to tolerance'): - _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device) + _test_binary( + dtype_x, + dtype_y, + expr, + numpy_expr, + device=device, + num_ctas=num_ctas) elif (op in ('%', '/') and ((dtype_x in int_dtypes and dtype_y in uint_dtypes) or (dtype_x in uint_dtypes and dtype_y in int_dtypes))): with pytest.raises(triton.CompilationError) as exc_info: - _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device) + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) assert re.match('Cannot use .* because they have different signedness', str(exc_info.value.__cause__)) else: - _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device) + _test_binary( + dtype_x, + dtype_y, + expr, + numpy_expr, + device=device, + num_ctas=num_ctas) + + +@pytest.mark.parametrize("dtype, order", [(dtype, order) for dtype in dtypes_with_bfloat16 for order in [0, 1]]) +def test_addptr(dtype, order, device): + check_type_supported(dtype, device) + + @triton.jit + def kernel(x, y, ORDER: tl.constexpr, SIZE: tl.constexpr): + offs = tl.arange(0, SIZE) + if ORDER == 0: + tl.store(y + offs, tl.load(x + offs)) + else: + tl.store(offs + y, tl.load(offs + x)) + + SIZE = 1024 + rs = RandomState(17) + x = numpy_random(SIZE, dtype_str=dtype, rs=rs) + y = numpy_random(SIZE, dtype_str=dtype, rs=rs) + x_tri = to_triton(x, dst_type=dtype, device=device) + y_tri = to_triton(y, dst_type=dtype, device=device) + y = x + kernel[1,](x_tri, y_tri, order, SIZE) + np.testing.assert_allclose(y, to_numpy(y_tri)) @pytest.mark.parametrize("dtype_x, dtype_y", [(dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes] + [(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes] ) -def test_floordiv(dtype_x, dtype_y, device): +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_floordiv(dtype_x, dtype_y, num_ctas, device): # Triton has IEEE, not numpy/torch, semantics for %, and those carry # through to //, so we have to use a nonstandard expression to get a # reference result for //. expr = 'x // y' numpy_expr = '((x - np.fmod(x, y)) / y)' - _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device) + _test_binary( + dtype_x, + dtype_y, + expr, + numpy_expr, + device=device, + num_ctas=num_ctas) + + +def test_unsigned_name_mangling(device='cuda'): + # Test that uint32 and int32 are mangled differently by the compiler + SIZE = 128 + # define the kernel / launch-grid + + @triton.jit + def kernel(O1, O2, X, Y, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + y = tl.load(Y + off) + out1 = tl.abs(x) # uint32 -> nop + out2 = tl.abs(-y) # int32 -> should have an effect + tl.store(O1 + off, out1) + tl.store(O2 + off, out2) + + dtype_x = 'uint32' + dtype_y = 'int32' + # inputs + rs = RandomState(17) + x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs) + y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs) + # reference result + expect = (np.abs(x), np.abs(-y)) + # triton result + x_tri = to_triton(x, device=device, dst_type=dtype_x) + y_tri = to_triton(y, device=device, dst_type=dtype_y) + actual = tuple( + to_triton(np.empty_like(e), device=device) + for e in expect + ) + kernel[(1, )](actual[0], actual[1], x_tri, y_tri, SIZE=SIZE, num_warps=4) + + # Bitwise op, so expect exact equality + assert (expect[0] == to_numpy(actual[0])).all() + assert (expect[1] == to_numpy(actual[1])).all() def test_unsigned_name_mangling(device): @@ -383,7 +480,8 @@ def kernel(O1, O2, X, Y, SIZE: tl.constexpr): for dtype_x in dtypes + dtypes_with_bfloat16 for dtype_y in dtypes + dtypes_with_bfloat16 ]) -def test_bitwise_op(dtype_x, dtype_y, op, device): +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_bitwise_op(dtype_x, dtype_y, op, num_ctas, device): expr = f'x {op} y' if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})' @@ -393,11 +491,17 @@ def test_bitwise_op(dtype_x, dtype_y, op, device): numpy_expr = None if 'float' in dtype_x + dtype_y: with pytest.raises(triton.CompilationError) as exc_info: - _test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device) + _test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device, num_ctas=num_ctas) # The CompilationError must have been caused by a C++ exception with this text. assert re.match('invalid operands of type', str(exc_info.value.__cause__)) else: - _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device) + _test_binary( + dtype_x, + dtype_y, + expr, + numpy_expr, + device=device, + num_ctas=num_ctas) @pytest.mark.parametrize("dtype_x, dtype_y, op", [ @@ -406,7 +510,8 @@ def test_bitwise_op(dtype_x, dtype_y, op, device): for dtype_x in int_dtypes + uint_dtypes for dtype_y in int_dtypes + uint_dtypes ]) -def test_shift_op(dtype_x, dtype_y, op, device): +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_shift_op(dtype_x, dtype_y, op, num_ctas, device): expr = f'x {op} y' bw = max(_bitwidth(dtype_x), _bitwidth(dtype_y)) if dtype_x.startswith('int'): @@ -414,7 +519,7 @@ def test_shift_op(dtype_x, dtype_y, op, device): else: dtype_z = f'uint{bw}' numpy_expr = f'x.astype(np.{dtype_z}) {op} y.astype(np.{dtype_z})' - _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, y_low=0, y_high=65) + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas, y_low=0, y_high=65) # --------------- @@ -439,7 +544,8 @@ def test_shift_op(dtype_x, dtype_y, op, device): ('nan', 'nan')] ]) -def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device): +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, num_ctas, device): expr = f'x {op} y' if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})' @@ -447,7 +553,36 @@ def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device): numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})' else: numpy_expr = None - _test_binary(dtype_x, dtype_y, expr, numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device) + _test_binary(dtype_x, dtype_y, expr, numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device, num_ctas=num_ctas) + + +# --------------- +# test broadcast +# --------------- +@pytest.mark.parametrize("dtype", dtypes_with_bfloat16) +def test_broadcast(dtype): + @triton.jit + def broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.constexpr): + offset1 = tl.arange(0, M) + offset2 = tl.arange(0, N) + x = tl.load(x_ptr + N * offset1[:, None] + offset2[None, :]) + y = tl.load(y_ptr + offset2) + _, y_broadcasted = tl.broadcast(x, y) + tl.store(y_broadcasted_ptr + N * offset1[:, None] + offset2[None, :], y_broadcasted) + + M = 32 + N = 64 + rs = RandomState(17) + x = numpy_random((M, N), dtype_str=dtype, rs=rs) + y = numpy_random(N, dtype_str=dtype, rs=rs) + _, y_broadcasted_np = np.broadcast_arrays(x, y) + + x_tri = to_triton(x, device='cuda', dst_type=dtype) + y_tri = to_triton(y, device='cuda', dst_type=dtype) + y_broadcasted_tri = to_triton(np.empty((M, N), dtype=y_broadcasted_np.dtype), device='cuda', dst_type=dtype) + + broadcast_kernel[(1,)](x_tri, y_tri, y_broadcasted_tri, M=M, N=N) + assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all() # --------------- @@ -592,7 +727,8 @@ def _kernel(dst): # test where # --------------- @pytest.mark.parametrize("dtype", dtypes_with_bfloat16 + ["*int32"]) -def test_where(dtype, device): +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_where(dtype, num_ctas, device): select_ptrs = False if dtype == "*int32": dtype = "int64" @@ -633,7 +769,7 @@ def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements, z_tri = to_triton(np.empty(SIZE, dtype=z.dtype), device=device, dst_type=dtype) grid = lambda meta: (triton.cdiv(SIZE, meta['BLOCK_SIZE']),) - where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, TEST_SCALAR_POINTERS=False) + where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, TEST_SCALAR_POINTERS=False, num_ctas=num_ctas) assert (z == to_numpy(z_tri)).all() if select_ptrs: where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, TEST_SCALAR_POINTERS=True) @@ -641,7 +777,8 @@ def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements, assert (z == to_numpy(z_tri)).all() -def test_where_broadcast(device): +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_where_broadcast(num_ctas, device): @triton.jit def where_kernel(cond_ptr, a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): xoffsets = tl.arange(0, BLOCK_SIZE)[:, None] @@ -672,7 +809,7 @@ def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): z_tri = to_triton(np.empty((SIZE, SIZE), dtype=z.dtype), device=device, dst_type=dtype) where_kernel[(1,)](cond_tri, x_tri, z_tri, SIZE) assert (z == to_numpy(z_tri)).all() - where_scalar_condition[(1,)](x_tri, z_tri, SIZE) + where_scalar_condition[(1,)](x_tri, z_tri, SIZE, num_ctas=num_ctas) z = np.where(0, x, 0) assert (z == to_numpy(z_tri)).all() @@ -686,8 +823,9 @@ def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): ] + [ (dtype_x, ' ~x') for dtype_x in int_dtypes ]) -def test_unary_op(dtype_x, expr, device): - _test_unary(dtype_x, expr, device=device) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_unary_op(dtype_x, expr, num_ctas, device): + _test_unary(dtype_x, expr, device=device, num_ctas=num_ctas) # ---------------- # test math ops @@ -711,7 +849,7 @@ def test_abs(dtype_x, device): _test_unary(dtype_x, 'tl.abs(x)', 'np.abs(x) ', device=device) -@pytest.mark.parametrize("in_dtype", [tl.float8e4b15, tl.float8e4, tl.float8e5]) +@pytest.mark.parametrize("in_dtype", [tl.float8e4b15, tl.float8e4nv, tl.float8e5]) def test_abs_fp8(in_dtype, device): @triton.jit @@ -760,7 +898,8 @@ def make_ptr_str(name, shape): ':, :, None'] for d in ['int32', 'uint32', 'uint16'] ]) -def test_index1d(expr, dtype_str, device): +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_index1d(expr, dtype_str, num_ctas, device): rank_x = expr.count(':') rank_y = expr.count(',') + 1 shape_x = [32 for _ in range(rank_x)] @@ -802,7 +941,8 @@ def generate_kernel(shape_x, shape_z): def catch_compilation_error(kernel): try: - kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0]) + kernel[(1, )](z_tri, x_tri, num_warps=1, + SIZE=shape_x[0], num_ctas=num_ctas) except triton.CompilationError as e: np.testing.assert_(True) except BaseException: @@ -1004,20 +1144,21 @@ def kernel(X, Z): assert f"atom.global.gpu.{sem_str}" in h.asm["ptx"] -def test_atomic_rmw_predicate(device): +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_atomic_rmw_predicate(num_ctas, device): @triton.jit def kernel(X): val = tl.program_id(0) if val < 64: tl.atomic_max(X, val) x = torch.zeros((1,), device=device, dtype=torch.int32) - kernel[(4096,)](x) + kernel[(4096,)](x, num_ctas=num_ctas) assert x.item() == 63 -@pytest.mark.parametrize("shape, axis", - [(shape, axis) for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32)] for axis in [0, 1]]) -def test_tensor_atomic_rmw(shape, axis, device): +@pytest.mark.parametrize("shape, axis, num_ctas", + [(shape, axis, num_ctas) for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32), (64, 64)] for axis in [0, 1] for num_ctas in num_ctas_list]) +def test_tensor_atomic_rmw(shape, axis, num_ctas, device): shape0, shape1 = shape # triton kernel @@ -1039,11 +1180,12 @@ def kernel(Z, X, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr) x_tri = to_triton(x, device=device) z_shape = (shape0, ) if axis == 1 else (shape1, ) z_tri = to_triton(np.zeros(z_shape, dtype="float32"), device=device) - kernel[(1,)](z_tri, x_tri, axis, shape0, shape1) + kernel[(1,)](z_tri, x_tri, axis, shape0, shape1, num_ctas=num_ctas) np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4) -def test_tensor_atomic_rmw_block(device): +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_tensor_atomic_rmw_block(num_ctas, device): shape = (8, 8) @triton.jit @@ -1055,12 +1197,13 @@ def kernel(X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr): x = X + offs tl.atomic_min(x, val) x = torch.ones((8, 8), device=device, dtype=torch.float32) - kernel[(2,)](x, shape[0], shape[1]) + kernel[(2,)](x, shape[0], shape[1], num_ctas=num_ctas) assert torch.min(x).item() == 0.0 @pytest.mark.parametrize("sem", [None, 'acquire', 'release', 'acq_rel', 'relaxed']) -def test_atomic_cas(sem, device): +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_atomic_cas(sem, num_ctas, device): # 1. make sure that atomic_cas changes the original value (Lock) @triton.jit def change_value(Lock): @@ -1086,7 +1229,7 @@ def serialized_add(data, Lock, SEM: tl.constexpr): Lock = torch.zeros((1,), device=device, dtype=torch.int32) data = torch.zeros((128,), device=device, dtype=torch.float32) ref = torch.full((128,), 64.0) - h = serialized_add[(64,)](data, Lock, SEM=sem) + h = serialized_add[(64,)](data, Lock, SEM=sem, num_ctas=num_ctas) sem_str = "acq_rel" if sem is None else sem np.testing.assert_allclose(to_numpy(data), to_numpy(ref)) assert f"atom.global.{sem_str}" in h.asm["ptx"] @@ -1112,7 +1255,8 @@ def serialized_add(data, Lock, SEM: tl.constexpr): ] + [ (f'int{x}', f'uint{x}', True) for x in [8, 16, 32, 64] ]) -def test_cast(dtype_x, dtype_z, bitcast, device): +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_cast(dtype_x, dtype_z, bitcast, num_ctas, device): # bfloat16 on cc < 80 will not be tested check_type_supported(dtype_x, device) check_type_supported(dtype_z, device) @@ -1145,7 +1289,7 @@ def kernel(X, Z, BITCAST: tl.constexpr, SIZE: tl.constexpr): z_tri = torch.empty((size,), dtype=getattr(torch, dtype_z), device=device) else: z_tri = to_triton(np.empty((size, ), dtype=getattr(np, dtype_z_np)), device=device) - kernel[(1, )](x_tri, z_tri, BITCAST=bitcast, SIZE=size, num_warps=1) + kernel[(1, )](x_tri, z_tri, BITCAST=bitcast, SIZE=size, num_warps=1, num_ctas=num_ctas) # torch result if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat'): assert bitcast is False @@ -1182,7 +1326,8 @@ def kernel(X, Y, Z, N: tl.constexpr): @pytest.mark.parametrize("dtype_str", list(torch_dtypes)) -def test_store_constant(dtype_str, device): +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_store_constant(dtype_str, num_ctas, device): check_type_supported(dtype_str, device) """Tests that boolean True is stored as 1""" @@ -1198,7 +1343,7 @@ def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): block_size = 128 ref = torch.ones([block_size], dtype=getattr(torch, dtype_str), device=device) output = torch.zeros([block_size], dtype=getattr(torch, dtype_str), device=device) - kernel[(1,)](output, block_size, BLOCK_SIZE=block_size) + kernel[(1,)](output, block_size, BLOCK_SIZE=block_size, num_ctas=num_ctas) assert torch.all(output == ref) @@ -1236,8 +1381,8 @@ def convert_float_to_float32(fp: torch.tensor, dtype=None): extended_exp = ((1 << (tl.float32.primitive_bitwidth - tl.float32.fp_mantissa_width - 1)) - 1) << tl.float32.fp_mantissa_width # special cases, exp is 0b11..1 - if dtype in [tl.float8e4, tl.float8e4b15]: - # float8e4m3 does not have infinities + if dtype in [tl.float8e4nv, tl.float8e4b15]: + # float8e4m3nv does not have infinities output[fp == 0b01111111] = torch.nan output[fp == 0b11111111] = torch.nan else: @@ -1264,43 +1409,39 @@ def test_convert_float16_to_float32(in_dtype, device): def serialize_fp8(np_data, in_dtype): - return np_data -# def serialize_fp8(np_data, in_dtype): -# if in_dtype == tl.float8e4b15: -# # triton's f8e4b15 format is optimized for software emulation -# # as a result, each pack of 4xfp8 values: -# # s0b0s1b1s2b2s3b3 (for s, b sign and bits respectively) -# # is actually internally stored as -# # s0s2b0b2s1s3b1b3 -# # we apply the conversion here -# f8x4 = np_data.view(np.uint32) -# s = [(f8x4 & (0x80000000 >> i)) << i for i in range(0, 32, 8)] -# b = [(f8x4 & (0x7f000000 >> i)) << i for i in range(0, 32, 8)] -# signs = (s[0] >> 0) | (s[1] >> 16) | (s[2] >> 1) | (s[3] >> 17) -# bits = (b[0] >> 1) | (b[1] >> 17) | (b[2] >> 8) | (b[3] >> 24) -# # tensor of triton fp8 data -# return (signs | bits).view(np.int8) -# else: -# return np_data + if in_dtype == tl.float8e4b15x4: + # triton's f8e4b15 format is optimized for software emulation + # as a result, each pack of 4xfp8 values: + # s0b0s1b1s2b2s3b3 (for s, b sign and bits respectively) + # is actually internally stored as + # s0s2b0b2s1s3b1b3 + # we apply the conversion here + f8x4 = np_data.view(np.uint32) + s = [(f8x4 & (0x80000000 >> i)) << i for i in range(0, 32, 8)] + b = [(f8x4 & (0x7f000000 >> i)) << i for i in range(0, 32, 8)] + signs = (s[0] >> 0) | (s[1] >> 16) | (s[2] >> 1) | (s[3] >> 17) + bits = (b[0] >> 1) | (b[1] >> 17) | (b[2] >> 8) | (b[3] >> 24) + # tensor of triton fp8 data + return (signs | bits).view(np.int8) + else: + return np_data # inverse of `serialize_fp8` def deserialize_fp8(np_data, in_dtype): - return np_data -# def deserialize_fp8(np_data, in_dtype): -# if in_dtype == tl.float8e4b15: -# f8x4 = np_data.view(np.uint32) -# s = [(f8x4 & (0x80000000 >> i)) << i for i in [0, 16, 1, 17]] -# b = [(f8x4 & (0x7f000000 >> i)) << i for i in [1, 17, 8, 24]] -# signs = (s[0] >> 0) | (s[1] >> 8) | (s[2] >> 16) | (s[3] >> 24) -# bits = (b[0] >> 0) | (b[1] >> 8) | (b[2] >> 16) | (b[3] >> 24) -# return (signs | bits).view(np.int8) -# else: -# return np_data - - -@pytest.mark.parametrize("in_dtype", [tl.float8e4b15, tl.float8e4, tl.float8e5]) + if in_dtype == tl.float8e4b15x4: + f8x4 = np_data.view(np.uint32) + s = [(f8x4 & (0x80000000 >> i)) << i for i in [0, 16, 1, 17]] + b = [(f8x4 & (0x7f000000 >> i)) << i for i in [1, 17, 8, 24]] + signs = (s[0] >> 0) | (s[1] >> 8) | (s[2] >> 16) | (s[3] >> 24) + bits = (b[0] >> 0) | (b[1] >> 8) | (b[2] >> 16) | (b[3] >> 24) + return (signs | bits).view(np.int8) + else: + return np_data + + +@pytest.mark.parametrize("in_dtype", [tl.float8e4b15, tl.float8e4b15x4, tl.float8e4nv, tl.float8e5]) @pytest.mark.parametrize("out_dtype", [torch.float16, torch.float32]) def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device): """ @@ -1309,6 +1450,7 @@ def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device): - conversion tri_fp8 = convert(input=tri_fp16, out=out_dtype) matches the original this is only possible if both conversions are correct """ + check_type_supported(in_dtype, device) check_type_supported(out_dtype, device) @triton.jit @@ -1346,8 +1488,6 @@ def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): def get_reduced_dtype(dtype_str, op): if op in ('argmin', 'argmax'): return 'int32' - if dtype_str in ['int8', 'uint8', 'int16', 'uint16']: - return 'int32' if dtype_str == 'bfloat16': return 'float32' return dtype_str @@ -1363,7 +1503,8 @@ def get_reduced_dtype(dtype_str, op): 'sum'] for dtype in dtypes_with_bfloat16 for shape in [32, 64, 128, 512]]) -def test_reduce1d(op, dtype_str, shape, device): +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_reduce1d(op, dtype_str, shape, num_ctas, device): check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested # triton kernel @@ -1409,7 +1550,7 @@ def kernel(X, Z, BLOCK: tl.constexpr): # triton result z_tri = to_triton(numpy_random((1,), dtype_str=z_dtype_str, rs=rs), device=device, dst_type=z_tri_dtype_str) - kernel[(1,)](x_tri, z_tri, BLOCK=shape) + kernel[(1,)](x_tri, z_tri, BLOCK=shape, num_ctas=num_ctas) z_tri = to_numpy(z_tri) # compare if op == 'sum': @@ -1452,7 +1593,8 @@ def kernel(X, Z, BLOCK: tl.constexpr): @pytest.mark.parametrize("op, dtype_str, shape, axis", reduce_configs1 + reduce_configs2) -def test_reduce2d(op, dtype_str, shape, axis, device): +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_reduce2d(op, dtype_str, shape, axis, num_ctas, device): check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested # triton kernel @@ -1492,7 +1634,8 @@ def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexp ret_numel = 1 if axis is None else shape[1 - axis] z_tri = to_triton(numpy_random((ret_numel,), dtype_str=z_dtype_str, rs=rs), device=device, dst_type=z_tri_dtype_str) - kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis) + kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], + BLOCK_N=shape[1], AXIS=axis, num_ctas=num_ctas) z_tri = to_numpy(z_tri) # compare if op == 'sum': @@ -1559,17 +1702,17 @@ def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexp scan_layouts = [ - BlockedLayout([1, 4], [4, 8], [4, 1], [0, 1]), - BlockedLayout([1, 4], [8, 4], [4, 1], [0, 1]), - BlockedLayout([4, 1], [4, 8], [1, 4], [0, 1]), - BlockedLayout([2, 2], [4, 8], [2, 2], [0, 1]), - BlockedLayout([2, 2], [8, 4], [2, 2], [0, 1]), - - BlockedLayout([1, 4], [4, 8], [4, 1], [1, 0]), - BlockedLayout([1, 4], [8, 4], [4, 1], [1, 0]), - BlockedLayout([4, 1], [4, 8], [1, 4], [1, 0]), - BlockedLayout([2, 2], [4, 8], [2, 2], [1, 0]), - BlockedLayout([2, 2], [8, 4], [2, 2], [1, 0]), + BlockedLayout([1, 4], [4, 8], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [8, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([4, 1], [4, 8], [1, 4], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [4, 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [8, 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + + BlockedLayout([1, 4], [4, 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [8, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([4, 1], [4, 8], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [4, 8], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [8, 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), ] @@ -1579,7 +1722,7 @@ def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexp def test_scan_layouts(M, N, src_layout, axis, device): ir = f""" #blocked = {src_layout} - module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{ + module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{ tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked> %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>> @@ -1628,11 +1771,11 @@ def test_scan_layouts(M, N, src_layout, axis, device): layouts = [ - BlockedLayout([1, 4], [8, 4], [4, 1], [1, 0]), - BlockedLayout([1, 4], [8, 4], [4, 1], [0, 1]), - BlockedLayout([4, 4], [2, 16], [4, 1], [1, 0]), - MmaLayout(version=(2, 0), warps_per_cta=[4, 1]), - MmaLayout(version=(2, 0), warps_per_cta=[2, 2]) + BlockedLayout([1, 4], [8, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [8, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([4, 4], [2, 16], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], instr_shape=[16, 8]), + MmaLayout(version=(2, 0), warps_per_cta=[2, 2], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], instr_shape=[16, 8]) ] @@ -1643,10 +1786,11 @@ def test_reduce_layouts(M, N, src_layout, axis, device): rdims_2d = f"1x{N}" if axis == 0 else f"{M}x1" rdims_1d = f"{N}" if axis == 0 else f"{M}" store_range = "%7" if axis == 0 else "%1" + blocked = BlockedLayout([1, 1], [32, 1], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]) ir = f""" - #blocked = #triton_gpu.blocked<{{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}}> + #blocked = {blocked} #src = {src_layout} - module attributes {{"triton_gpu.num-warps" = 4 : i32}} {{ + module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{ tt.func public @kernel_0d1d2c3d4c(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: i32 {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>> %1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>) -> tensor<{M}x1xi32, #blocked> @@ -1701,9 +1845,9 @@ def test_reduce_layouts(M, N, src_layout, axis, device): layouts = [ - BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0]), - BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0]), - MmaLayout(version=(2, 0), warps_per_cta=[4, 1]) + BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], instr_shape=[16, 8]) ] @@ -1712,7 +1856,7 @@ def test_reduce_layouts(M, N, src_layout, axis, device): def test_store_op(M, src_layout, device): ir = f""" #src = {src_layout} - module attributes {{"triton_gpu.num-warps" = 4 : i32}} {{ + module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{ tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> %1 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x!tt.ptr, #triton_gpu.slice<{{dim = 1, parent = #src}}>> @@ -1747,9 +1891,9 @@ def test_store_op(M, src_layout, device): layouts = [ - BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0]), - BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0]), - MmaLayout(version=(2, 0), warps_per_cta=[4, 1]) + BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], instr_shape=[16, 8]) ] @@ -1762,7 +1906,7 @@ def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device): ir = f""" #dst = {dst_layout} #src = {src_layout} - module attributes {{"triton_gpu.num-warps" = 4 : i32}} {{ + module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{ tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ %0 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x!tt.ptr, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>> %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>> @@ -1806,10 +1950,10 @@ def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2): layouts = [ - BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0]), - BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0]), - BlockedLayout([1, 4], [1, 32], [1, 4], [1, 0]), - BlockedLayout([1, 4], [8, 4], [2, 2], [0, 1]) + BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [1, 32], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [8, 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]) ] @@ -1830,7 +1974,7 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis): tt.reduce.return %14 : i32""" ir = f""" #src = {src_layout} - module attributes {{"triton_gpu.num-warps" = 4 : i32}} {{ + module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{ tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src> %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> @@ -1916,7 +2060,8 @@ def var_mean_kernel(X, out_mean, out_var, BLOCK: tl.constexpr): for dtype in ['float8e4b15', 'float16', 'float32'] for shape in [(64, 64), (128, 128)] for perm in [(1, 0)]]) -def test_permute(dtype_str, shape, perm, device): +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_permute(dtype_str, shape, perm, num_ctas, device): check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested # triton kernel @@ -1937,10 +2082,12 @@ def kernel(X, stride_xm, stride_xn, x_tri = to_triton(x, device=device, dst_type=dtype_str) pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), z_tri, z_tri.stride(1), z_tri.stride(0), - BLOCK_M=shape[0], BLOCK_N=shape[1]) + BLOCK_M=shape[0], BLOCK_N=shape[1], + num_ctas=num_ctas) pgm_contiguous = kernel[(1, 1)](x_tri, x_tri.stride(1), x_tri.stride(0), z_tri_contiguous, z_tri_contiguous.stride(0), z_tri_contiguous.stride(1), - BLOCK_M=shape[0], BLOCK_N=shape[1]) + BLOCK_M=shape[0], BLOCK_N=shape[1], + num_ctas=num_ctas) # numpy result if dtype_str == 'float8e4b15': ty = tl.float8e4b15 @@ -1960,6 +2107,7 @@ def kernel(X, stride_xm, stride_xn, assert 'ld.global.v4' in ptx assert 'st.global.v4' in ptx + # --------------- # test dot # --------------- @@ -1993,7 +2141,8 @@ def kernel(X, stride_xm, stride_xn, ('float16', 'float16'), ('float16', 'float32'), ('float32', 'float32')]]) -def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype, device): +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype, num_ctas, device): check_cuda_only(device) capability = torch.cuda.get_device_capability() @@ -2002,7 +2151,7 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o if capability[0] < 8: if in_dtype == 'int8': pytest.skip("Only test int8 on devices with sm >= 80") - elif in_dtype == 'float32' and allow_tf32: + elif allow_tf32: pytest.skip("Only test tf32 on devices with sm >= 80") if capability[0] == 7: if (M, N, K, num_warps) == (128, 256, 32, 8): @@ -2013,18 +2162,22 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o torch.backends.cuda.matmul.allow_tf32 = allow_tf32 + if num_ctas > 1 and in_dtype == 'int8': + # FIXME: mma v2 with num_ctas > 1 does not work + pytest.skip() + # triton kernel @triton.jit def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, stride_wl, Z, stride_zm, stride_zn, - out_dtype: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, ALLOW_TF32: tl.constexpr, DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr, - COL_A: tl.constexpr, COL_B: tl.constexpr): + COL_A: tl.constexpr, COL_B: tl.constexpr, + out_dtype: tl.constexpr = tl.float32): off_m = tl.arange(0, BLOCK_M) off_n = tl.arange(0, BLOCK_N) off_l = tl.arange(0, BLOCK_N) @@ -2052,7 +2205,7 @@ def kernel(X, stride_xm, stride_xk, z = num / den[:, None] if CHAIN_DOT: w = tl.load(Ws) - z = tl.dot(z.to(w.dtype), w, out_dtype=out_dtype) + z = tl.dot(z.to(w.dtype), w, allow_tf32=ALLOW_TF32, out_dtype=out_dtype) tl.store(Zs, z) # input rs = RandomState(17) @@ -2098,7 +2251,6 @@ def kernel(X, stride_xm, stride_xk, y_tri, y_tri.stride(0), y_tri.stride(1), w_tri, w_tri.stride(0), w_tri.stride(1), z_tri, z_tri.stride(0), z_tri.stride(1), - out_dtype, COL_A=col_a, COL_B=col_b, BLOCK_M=M, BLOCK_K=K, BLOCK_N=N, ADD_MATRIX=epilogue == 'add-matrix', @@ -2107,14 +2259,21 @@ def kernel(X, stride_xm, stride_xk, DO_SOFTMAX=epilogue == 'softmax', CHAIN_DOT=epilogue == 'chain-dot', ALLOW_TF32=allow_tf32, - num_warps=num_warps) + num_warps=num_warps, num_ctas=num_ctas, + out_dtype=out_dtype) if epilogue == 'softmax' and (in_dtype != 'float32' or allow_tf32): ptx = pgm.asm["ptx"] start = ptx.find("shfl.sync") end = ptx.find("cvt.rn.f16.f32") red_code = ptx[start:end] assert len(red_code) > 0 - assert "shared" not in red_code + import os + enable_mmav3 = os.environ.get('ENABLE_MMA_V3', 'not found').lower() + enable_tma = os.environ.get('ENABLE_TMA', 'not found').lower() + # skip this check on hopper because there are some functions whose name contain "shared" in ptx. + # TODO: we should eliminate these unused functions in ptx code. + if not (enable_mmav3 in ["on", "true", "1"] and enable_tma in ["on", "true", "1"]): + assert "shared" not in red_code assert "bar.sync" not in red_code # torch result if in_dtype == 'int8': @@ -2151,17 +2310,28 @@ def kernel(X, stride_xm, stride_xk, assert 'ld.global.v4' in ptx assert 'st.global.v4' in ptx if in_dtype == 'float32' and allow_tf32: - assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx - elif in_dtype == 'float32' and allow_tf32: - assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx + assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k8(?:.row.col)?.f32.tf32.tf32', ptx) + elif in_dtype == 'float16' and out_dtype == tl.float32: + if capability[0] == 7 and capability[1] == 5: # Turing + assert re.search(r'mma.sync.aligned.m\d+n\d+k8(?:.row.col)?.f32.f16.f16', ptx) + else: + assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.f16.f16', ptx) + elif in_dtype == 'float16' and out_dtype == tl.float16: + if capability[0] == 7 and capability[1] == 5: # Turing + assert re.search(r'mma.sync.aligned.m\d+n\d+k8(?:.row.col)?.f16.f16.f16', ptx) + else: + assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f16.f16.f16', ptx) elif in_dtype == 'int8': - assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx - elif out_dtype == tl.float16: - assert 'mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16' in ptx + assert 'wgmma.mma_async.sync.aligned' in ptx or\ + 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx @pytest.mark.parametrize('in_dtype', ['float32']) def test_dot_mulbroadcastred(in_dtype, device): + capability = torch.cuda.get_device_capability() + if capability[0] < 8: + pytest.skip("Requires sm >= 80 to run") + @triton.jit def kernel(Z, X, Y, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, @@ -2197,7 +2367,16 @@ def kernel(Z, X, Y, z_ref = np.matmul(x, y) np.testing.assert_allclose(z_ref, to_numpy(z_tri), atol=0.01) assert "tt.dot" in h.asm['ttir'] - assert "triton_gpu.async_wait {num = 2 : i32}" in h.asm['ttgir'] + # with option ENABLE_MMA_V3 on, we will not pipeline the load op for Y + # as the loaded value is in rowmajor. But MMAv3 requires it's second + # operand is in colmajor because transpose is not supported for MMAv3 + # with float32 input. + import os + enable_mmav3 = os.environ.get('ENABLE_MMA_V3', 'not found').lower() + if enable_mmav3 in ["on", "true", "1"]: + assert "triton_gpu.async_wait {num = 1 : i32}" in h.asm['ttgir'] + else: + assert "triton_gpu.async_wait {num = 2 : i32}" in h.asm['ttgir'] @pytest.mark.parametrize("dtype_str", int_dtypes + uint_dtypes + float_dtypes + ['bfloat16']) @@ -2250,11 +2429,14 @@ def kernel(out_ptr): @pytest.mark.parametrize("dtype_str", ['float32', 'float16']) def test_dot_without_load(dtype_str, device): + capability = torch.cuda.get_device_capability() + allow_tf32 = capability[0] > 7 + @triton.jit - def _kernel(out): + def _kernel(out, ALLOW_TF32: tl.constexpr): a = GENERATE_TEST_HERE b = GENERATE_TEST_HERE - c = tl.dot(a, b) + c = tl.dot(a, b, allow_tf32=ALLOW_TF32) out_ptr = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :] tl.store(out_ptr, c) kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.full((32, 32), 1.0, tl.{dtype_str})"}) @@ -2262,7 +2444,7 @@ def _kernel(out): b = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device=device) out_ref = torch.matmul(a, b) out = torch.zeros((32, 32), dtype=getattr(torch, dtype_str), device=device) - kernel[(1,)](out) + kernel[(1,)](out, ALLOW_TF32=allow_tf32) assert torch.all(out == out_ref) # --------------- @@ -2271,7 +2453,8 @@ def _kernel(out): @pytest.mark.parametrize("start", [0, 1, 7, 16]) -def test_arange(start, device): +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_arange(start, num_ctas, device): BLOCK = 128 z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device) @@ -2281,7 +2464,7 @@ def _kernel(z, BLOCK: tl.constexpr, off = tl.arange(0, BLOCK) val = tl.arange(START, END) tl.store(z + off, val) - _kernel[(1,)](z_tri, START=start, END=start + BLOCK, BLOCK=BLOCK) + _kernel[(1,)](z_tri, START=start, END=start + BLOCK, BLOCK=BLOCK, num_ctas=num_ctas) z_ref = torch.arange(start, BLOCK + start, dtype=torch.int32, device=device) np.testing.assert_allclose(to_numpy(z_tri), to_numpy(z_ref)) @@ -2291,7 +2474,8 @@ def _kernel(z, BLOCK: tl.constexpr, @pytest.mark.parametrize("dtype_str, size, size_diff", [(dtype_str, size, size_diff) for dtype_str in torch_dtypes for size in [128, 512] for size_diff in [0, 1, 2, 3, 4]]) -def test_masked_load(dtype_str, size, size_diff, device): +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_masked_load(dtype_str, size, size_diff, num_ctas, device): dtype = getattr(torch, dtype_str) check_type_supported(dtype, device) # bfloat16 on cc < 80 will not be tested @@ -2316,7 +2500,7 @@ def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr): mask_str = "mask=in_offsets < in_size, other=1" if size_diff > 0 else "None" kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.load(in_ptr + in_offsets, {mask_str})"}) - kernel[(1,)](input, output, input_size, output_size) + kernel[(1,)](input, output, input_size, output_size, num_ctas=num_ctas) reference_out = torch.cat((input, torch.ones((size_diff,), dtype=dtype, device=device))) # print((output - reference_out).nonzero()) @@ -2325,6 +2509,7 @@ def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr): # Testing masked loads with an intermate copy to shared memory run. +# FIXME: Shape too small for ldmatrix when num_ctas=4 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) def test_masked_load_shared_memory(dtype, device): check_type_supported(dtype, device) # bfloat16 on cc < 80 will not be tested @@ -2399,16 +2584,19 @@ def _kernel(dst, src, CACHE: tl.constexpr): @pytest.mark.parametrize("N", [16, 10, 11, 1024]) -def test_vectorization(N, device): - src = torch.empty(1024, device=device) - dst = torch.empty(1024, device=device) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_vectorization(N, num_ctas, device): + block_size = 1024 * num_ctas + src = torch.empty(block_size, device=device) + dst = torch.empty(block_size, device=device) @triton.jit def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr): offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) x = tl.load(src + offsets, mask=offsets < N) tl.store(dst + offsets, x, mask=offsets < N) - pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0]) + pgm = _kernel[(1,)]( + dst, src, N=N, BLOCK_SIZE=block_size) ptx = pgm.asm["ptx"] if N % 16 == 0: assert "ld.global.v4.b32" in ptx @@ -2686,7 +2874,8 @@ def vecmul_kernel(ptr, n_elements, rep, type: tl.constexpr): @pytest.mark.parametrize("type", ["inline", "noinline"]) -def test_call(type, device): +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_call(type, num_ctas, device): @triton.jit def kernel(ptr, n_elements, num1, num2, type: tl.constexpr): @@ -2698,7 +2887,7 @@ def kernel(ptr, n_elements, num1, num2, type: tl.constexpr): rand_val_tri = to_triton(rand_val, device=device) err_msg = "" try: - kernel[(size // 128,)](rand_val_tri, size, 3, 5, type) + kernel[(size // 128,)](rand_val_tri, size, 3, 5, type, num_ctas=num_ctas) except Exception as e: err_msg = str(e) @@ -2712,8 +2901,10 @@ def kernel(ptr, n_elements, num1, num2, type: tl.constexpr): # test if # ------------- +# TODO(Keren): if_exp_dynamic -@pytest.mark.parametrize("if_type", ["if", "if_exp", "if_and_dynamic", "if_and_static"]) + +@pytest.mark.parametrize("if_type", ["if", "if_and_dynamic", "if_exp_static", "if_and_static"]) def test_if(if_type, device): @triton.jit @@ -2725,8 +2916,10 @@ def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr, BoolVar: tl.constexpr tl.store(Ret, tl.load(XTrue)) else: tl.store(Ret, tl.load(XFalse)) - elif IfType == "if_exp": - tl.store(Ret, tl.load(XTrue)) if pid % 2 else tl.store(Ret, tl.load(XFalse)) + elif IfType == "if_exp_dynamic": + tl.store(Ret, tl.load(XTrue)) if pid % 2 == 0 else tl.store(Ret, tl.load(XFalse)) + elif IfType == "if_exp_static": + tl.store(Ret, tl.load(XTrue)) if BoolVar else tl.store(Ret, tl.load(XFalse)) elif IfType == "if_and_dynamic": if BoolVar and pid % 2 == 0: tl.store(Ret, tl.load(XTrue)) @@ -2741,7 +2934,7 @@ def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr, BoolVar: tl.constexpr cond = torch.ones(1, dtype=torch.int32, device=device) x_true = torch.tensor([3.14], dtype=torch.float32, device=device) x_false = torch.tensor([1.51], dtype=torch.float32, device=device) - ret = torch.empty(1, dtype=torch.float32, device=device) + ret = torch.zeros(1, dtype=torch.float32, device=device) kernel[(1,)](cond, x_true, x_false, ret, if_type, True, 1) assert torch.equal(ret, x_true) @@ -2772,7 +2965,8 @@ def _kernel(dst): ('float32', 'math.pow', tl.math.libdevice_path()), ('float64', 'math.pow_dtype', tl.math.libdevice_path()), ('float64', 'math.norm4d', '')]) -def test_math_tensor(dtype_str, expr, lib_path, device): +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_math_tensor(dtype_str, expr, lib_path, num_ctas, device): @triton.jit def kernel(X, Y, BLOCK: tl.constexpr): @@ -2816,7 +3010,7 @@ def kernel(X, Y, BLOCK: tl.constexpr): x_tri = to_triton(x, device=device) # triton result y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device=device) - kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path}) + kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path}, num_ctas=num_ctas) # compare if expr == 'math.ffs': np.testing.assert_equal(y_ref, to_numpy(y_tri)) @@ -2828,7 +3022,8 @@ def kernel(X, Y, BLOCK: tl.constexpr): [('float32', 'math.pow', ''), ('float64', 'math.pow_dtype', ''), ('float64', 'math.pow', tl.math.libdevice_path())]) -def test_math_scalar(dtype_str, expr, lib_path, device): +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_math_scalar(dtype_str, expr, lib_path, num_ctas, device): @triton.jit def kernel(X, Y, BLOCK: tl.constexpr): @@ -2855,10 +3050,64 @@ def kernel(X, Y, BLOCK: tl.constexpr): # triton result x_tri = to_triton(x, device=device)[0].item() y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device=device) - kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path}) + kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path}, num_ctas=num_ctas) # compare np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01) + +# ----------------------- +# test inline asm +# ----------------------- + +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_inline_asm(num_ctas, device): + check_cuda_only(device) + + @triton.jit + def kernel(X, Y, Z, n: tl.constexpr, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.load(Y + tl.arange(0, BLOCK)) + s = tl.full([BLOCK], n, tl.int32) + z = tl.inline_asm_elementwise("shf.l.wrap.b32 $0, $1, $2, $3;", "=r,r, r, r", [x, y, s], dtype=tl.int32, is_pure=True, pack=1) + tl.store(Z + tl.arange(0, BLOCK), z) + + shape = (128, ) + rs = RandomState(17) + x = numpy_random(shape, dtype_str='uint32', rs=rs) + y = numpy_random(shape, dtype_str='uint32', rs=rs) + x_tri = to_triton(x, device=device) + y_tri = to_triton(y, device=device) + n = 17 + z_tri = to_triton(numpy_random(shape, dtype_str='uint32', rs=rs), device=device) + kernel[(1,)](x_tri, y_tri, z_tri, n, BLOCK=shape[0], num_ctas=num_ctas) + y_ref = (y << n) | (x >> (32 - n)) + # compare + np.testing.assert_equal(y_ref, to_numpy(z_tri)) + + +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_inline_asm_packed(num_ctas, device): + check_cuda_only(device) + + @triton.jit + def kernel(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + # shift 4x8bits values together. + y = tl.inline_asm_elementwise("and.b32 $0, $1, 0x1F1F1F1F; \ + shl.b32 $0, $0, 3;", + "=r,r", [x,], dtype=tl.int8, is_pure=True, pack=4) + tl.store(Y + tl.arange(0, BLOCK), y) + + shape = (512, ) + rs = RandomState(17) + x = numpy_random(shape, dtype_str='uint8', rs=rs) + x_tri = to_triton(x, device=device) + y_tri = to_triton(numpy_random(shape, dtype_str='uint8', rs=rs), device=device) + kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], num_ctas=num_ctas) + y_ref = x << 3 + # compare + np.testing.assert_equal(y_ref, to_numpy(y_tri)) + # ----------------------- # test control flow # ----------------------- @@ -2966,8 +3215,9 @@ def add_fn_static_cond(x, cond: tl.constexpr): return x + 1 +# TODO(Keren): if_exp @pytest.mark.parametrize("call_type", ["attribute", "attribute_jit", - "jit", "jit_if", "jit_ifexp", "jit_expr", + "jit", "jit_if", "jit_expr", "jit_static_cond", "jit_noinline", "jit_extern"]) def test_if_call(call_type, device): @triton.jit @@ -2998,7 +3248,7 @@ def kernel(Out, call_type: tl.constexpr): a = o a = add_fn_return(a, pid) o = a - elif call_type == "jit_ifexp": + elif call_type == "jit_if_exp": # ifexp expression if pid == 0: a = o @@ -3157,8 +3407,7 @@ def kernel(Out1, Out2): out2 = to_triton(np.zeros((1,), dtype=np.int64), device=device) h = kernel[(1,)](out1, out2) assert out2[0] > 0 - # 2 inlined globaltimers + one extra in the wrapper extern function - assert h.asm["ptx"].count("%globaltimer") == 3 + assert h.asm["ptx"].count("%globaltimer") == 2 def test_smid(device): @@ -3171,7 +3420,7 @@ def kernel(Out): out = to_triton(np.zeros((1024,), dtype=np.int32), device=device) h = kernel[(out.shape[0],)](out) assert out.sort()[0].unique().shape[0] > 0 - assert h.asm["ptx"].count("%smid") == 2 + assert h.asm["ptx"].count("%smid") == 1 # ----------------------- # test layout conversions @@ -3180,24 +3429,24 @@ def kernel(Out): layouts = [ - # MmaLayout(version=1, warps_per_cta=[1, 4]), - MmaLayout(version=(2, 0), warps_per_cta=[1, 4]), - # MmaLayout(version=1, warps_per_cta=[4, 1]), - MmaLayout(version=(2, 0), warps_per_cta=[4, 1]), - BlockedLayout([1, 8], [2, 16], [4, 1], [1, 0]), - BlockedLayout([1, 4], [4, 8], [2, 2], [1, 0]), - BlockedLayout([1, 1], [1, 32], [2, 2], [1, 0]), - BlockedLayout([8, 1], [16, 2], [1, 4], [0, 1]), - BlockedLayout([4, 1], [8, 4], [2, 2], [0, 1]), - BlockedLayout([1, 1], [32, 1], [2, 2], [0, 1]), - BlockedLayout([4, 4], [1, 32], [4, 1], [1, 0]) + # MmaLayout(1, [1, 4], [1, 1], [0, 1]), + # MmaLayout((2, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]), + # MmaLayout(1, [4, 1], [1, 1], [0, 1]), + # MmaLayout((2, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 8]), + BlockedLayout([1, 8], [2, 16], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [4, 8], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 1], [1, 32], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([8, 1], [16, 2], [1, 4], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([4, 1], [8, 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 1], [32, 1], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([4, 4], [1, 32], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]) ] intermediate_layouts = [ None, - SharedLayout(1, 1, 1, [1, 0]), - SharedLayout(4, 2, 4, [1, 0]), - SharedLayout(2, 2, 4, [1, 0]), + SharedLayout(1, 1, 1, [1, 0], [1, 1], [1, 1], [0, 1]), + SharedLayout(4, 2, 4, [1, 0], [1, 1], [1, 1], [0, 1]), + SharedLayout(2, 2, 4, [1, 0], [1, 1], [1, 1], [0, 1]), ] @@ -3235,7 +3484,7 @@ def test_convert2d(dtype, shape, src_layout, interm_layout, dst_layout, device): """ ir = layouts + """ - module attributes {"triton_gpu.num-warps" = 4 : i32} { + module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { tt.func public @kernel_0d1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { %cst = arith.constant dense<128> : tensor<128x1xi32, #src> %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>> diff --git a/python/test/unit/language/test_core_amd.py b/python/test/unit/language/test_core_amd.py index 9ef4e99bf3c0..8add0a85e289 100644 --- a/python/test/unit/language/test_core_amd.py +++ b/python/test/unit/language/test_core_amd.py @@ -13,6 +13,7 @@ import triton import triton._C.libtriton.triton as _triton import triton.language as tl +from triton.common.build import is_hip from triton.runtime.jit import JITFunction, TensorWrapper, reinterpret int_dtypes = ['int8', 'int16', 'int32', 'int64'] @@ -22,6 +23,13 @@ dtypes_with_bfloat16 = dtypes + ['bfloat16'] torch_dtypes = ['bool'] + int_dtypes + ['uint8'] + float_dtypes + ['bfloat16'] +if is_hip(): + GPU_DIALECT = "triton_gpu" + THREADS_PER_WARP = 64 +else: + GPU_DIALECT = "triton_gpu" + THREADS_PER_WARP = 32 + def _bitwidth(dtype: str) -> int: # ex.: "int64" -> 64 @@ -922,6 +930,7 @@ def kernel(in_out_ptr): kernel[(65536,)](x, num_warps=16) assert torch.all(x == 2) + def convert_float_to_float32(fp: torch.tensor, dtype=None): if not dtype: dtype = getattr(tl, torch_dtype_name(fp.dtype)) @@ -941,8 +950,8 @@ def convert_float_to_float32(fp: torch.tensor, dtype=None): extended_exp = ((1 << (tl.float32.primitive_bitwidth - tl.float32.fp_mantissa_width - 1)) - 1) << tl.float32.fp_mantissa_width # special cases, exp is 0b11..1 - if dtype in [tl.float8e4, tl.float8e4b15]: - # float8e4m3 does not have infinities + if dtype in [tl.float8e4nv, tl.float8e4b15]: + # float8e4m3nv does not have infinities output[fp == 0b01111111] = torch.nan output[fp == 0b11111111] = torch.nan else: @@ -969,43 +978,39 @@ def test_convert_float16_to_float32(in_dtype, device): def serialize_fp8(np_data, in_dtype): - return np_data -# def serialize_fp8(np_data, in_dtype): -# if in_dtype == tl.float8e4b15: -# # triton's f8e4b15 format is optimized for software emulation -# # as a result, each pack of 4xfp8 values: -# # s0b0s1b1s2b2s3b3 (for s, b sign and bits respectively) -# # is actually internally stored as -# # s0s2b0b2s1s3b1b3 -# # we apply the conversion here -# f8x4 = np_data.view(np.uint32) -# s = [(f8x4 & (0x80000000 >> i)) << i for i in range(0, 32, 8)] -# b = [(f8x4 & (0x7f000000 >> i)) << i for i in range(0, 32, 8)] -# signs = (s[0] >> 0) | (s[1] >> 16) | (s[2] >> 1) | (s[3] >> 17) -# bits = (b[0] >> 1) | (b[1] >> 17) | (b[2] >> 8) | (b[3] >> 24) -# # tensor of triton fp8 data -# return (signs | bits).view(np.int8) -# else: -# return np_data + if in_dtype == tl.float8e4b15x4: + # triton's f8e4b15 format is optimized for software emulation + # as a result, each pack of 4xfp8 values: + # s0b0s1b1s2b2s3b3 (for s, b sign and bits respectively) + # is actually internally stored as + # s0s2b0b2s1s3b1b3 + # we apply the conversion here + f8x4 = np_data.view(np.uint32) + s = [(f8x4 & (0x80000000 >> i)) << i for i in range(0, 32, 8)] + b = [(f8x4 & (0x7f000000 >> i)) << i for i in range(0, 32, 8)] + signs = (s[0] >> 0) | (s[1] >> 16) | (s[2] >> 1) | (s[3] >> 17) + bits = (b[0] >> 1) | (b[1] >> 17) | (b[2] >> 8) | (b[3] >> 24) + # tensor of triton fp8 data + return (signs | bits).view(np.int8) + else: + return np_data # inverse of `serialize_fp8` def deserialize_fp8(np_data, in_dtype): - return np_data -# def deserialize_fp8(np_data, in_dtype): -# if in_dtype == tl.float8e4b15: -# f8x4 = np_data.view(np.uint32) -# s = [(f8x4 & (0x80000000 >> i)) << i for i in [0, 16, 1, 17]] -# b = [(f8x4 & (0x7f000000 >> i)) << i for i in [1, 17, 8, 24]] -# signs = (s[0] >> 0) | (s[1] >> 8) | (s[2] >> 16) | (s[3] >> 24) -# bits = (b[0] >> 0) | (b[1] >> 8) | (b[2] >> 16) | (b[3] >> 24) -# return (signs | bits).view(np.int8) -# else: -# return np_data - - -@pytest.mark.parametrize("in_dtype", [tl.float8e4b15, tl.float8e4, tl.float8e5]) + if in_dtype == tl.float8e4b15x4: + f8x4 = np_data.view(np.uint32) + s = [(f8x4 & (0x80000000 >> i)) << i for i in [0, 16, 1, 17]] + b = [(f8x4 & (0x7f000000 >> i)) << i for i in [1, 17, 8, 24]] + signs = (s[0] >> 0) | (s[1] >> 8) | (s[2] >> 16) | (s[3] >> 24) + bits = (b[0] >> 0) | (b[1] >> 8) | (b[2] >> 16) | (b[3] >> 24) + return (signs | bits).view(np.int8) + else: + return np_data + + +@pytest.mark.parametrize("in_dtype", [tl.float8e4b15, tl.float8e4b15x4, tl.float8e4nv, tl.float8e5]) @pytest.mark.parametrize("out_dtype", [torch.float16, torch.float32]) def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device): """ @@ -1014,6 +1019,7 @@ def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device): - conversion tri_fp8 = convert(input=tri_fp16, out=out_dtype) matches the original this is only possible if both conversions are correct """ + check_type_supported(in_dtype, device) check_type_supported(out_dtype, device) @triton.jit @@ -1043,11 +1049,152 @@ def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): copy_kernel[(1,)](tri_fp16, triton.reinterpret(ref_fp8, in_dtype), tri_fp16.shape[0], BLOCK_SIZE=1024) assert torch.all(tri_fp8 == ref_fp8) + +@pytest.mark.parametrize("M, N, K, a_type, b_type, out_dtype", + [(*shape, *ab_type, out_dtype) + for shape in [[128, 256, 32], + [128, 16, 32], + [32, 128, 64], + [128, 128, 64], + [64, 128, 128], + [32, 128, 64], + [64, 64, 32], + [32, 32, 128], + [128, 128, 64], + [64, 128, 128]] + for ab_type in [[tl.float8e4nv, tl.float16], + [tl.float8e4b15, tl.float16], + [tl.float8e4b15x4, tl.float16], + [tl.float8e5, tl.float16], + [tl.float16, tl.float8e4nv], + [tl.float16, tl.float8e4b15], + [tl.float16, tl.float8e4b15x4], + [tl.float16, tl.float8e5]] + for out_dtype in [torch.float16, torch.float32] + ]) +def test_gemm_fp816_mixed_inputs(M, N, K, a_type, b_type, out_dtype, device = 'cuda'): + + check_type_supported(out_dtype, device) + + @triton.jit + def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + input = tl.load(input_ptr + offsets, mask=mask) + output = input + tl.store(output_ptr + offsets, output, mask=mask) + + @triton.jit + def matmul_kernel( + a_ptr, b_ptr, c_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + compute_type:tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + ): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + # We accumulate along the K dimension. + accumulator += tl.dot(a, b, out_dtype=compute_type) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + c = accumulator + + # ----------------------------------------------------------- + # Write back the block of the output matrix C with masks. + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + def matmul(a, b, c_type): + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + M, K = a.shape + K, N = b.shape + + if c_type == torch.float16: + comp_type = tl.float16 + else: + comp_type = tl.float32 + + + c = torch.empty((M, N), device = a.device, dtype=c_type) + # 1D launch kernel where each block gets its own program. + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + ) + matmul_kernel[grid]( + a, b, c, + M, N, K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + compute_type = comp_type, + BLOCK_SIZE_M=32, + BLOCK_SIZE_N=64, + BLOCK_SIZE_K=64, + GROUP_SIZE_M=4, + num_stages=1, + num_warps=2, + ) + + return c + + + def gen_input(M, N, d_type, seed, device='cuda'): + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + if d_type == tl.float16: + input = torch.randn((M, K), dtype=torch.float16, device=device) + input_f16 = input + else: # d_type is float8 + f8_tensor = torch.randn((M, N), dtype=torch.float32, device='cuda') * 10 + f8_tensor = f8_tensor.to(torch.int8) + # keep only two bits of exponent to avoid overflow + f8_tensor = f8_tensor & 0b00111111 + input = triton.reinterpret(f8_tensor, d_type) + input_f16 = torch.empty_like(f8_tensor, dtype=torch.float16) + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + n_elements = f8_tensor.numel() + copy_kernel[grid](input, input_f16, n_elements, BLOCK_SIZE=1024) + return input, input_f16 + + a, a_f16 = gen_input(M, K, a_type, 11, device=device) + b, b_f16 = gen_input(K, N, b_type, 22, device=device) + + # call torch function to compute gold + golden = torch.matmul(a_f16, b_f16) + + c = matmul(a, b, out_dtype) + torch.testing.assert_close(c.to(golden.dtype), golden, rtol=1e-2, atol=6e-2) + + # --------------- # test reduce # --------------- - - def get_reduced_dtype(dtype_str, op): if op in ('argmin', 'argmax'): return 'int32' @@ -1999,27 +2146,43 @@ def kernel(ptr, n_elements, num1, num2): # test if # ------------- +# TODO(Keren): if_exp_dynamic + -@pytest.mark.parametrize("if_type", ["if", "if_exp"]) -def test_if(if_type): +@pytest.mark.parametrize("if_type", ["if", "if_and_dynamic", "if_exp_static", "if_and_static"]) +def test_if(if_type, device): @triton.jit - def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr): + def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr, BoolVar: tl.constexpr, StaticVaue: tl.constexpr): pid = tl.program_id(0) cond = tl.load(Cond) if IfType == "if": - if pid % 2: + if pid % 2 == 0: + tl.store(Ret, tl.load(XTrue)) + else: + tl.store(Ret, tl.load(XFalse)) + elif IfType == "if_exp_dynamic": + tl.store(Ret, tl.load(XTrue)) if pid % 2 == 0 else tl.store(Ret, tl.load(XFalse)) + elif IfType == "if_exp_static": + tl.store(Ret, tl.load(XTrue)) if BoolVar else tl.store(Ret, tl.load(XFalse)) + elif IfType == "if_and_dynamic": + if BoolVar and pid % 2 == 0: + tl.store(Ret, tl.load(XTrue)) + else: + tl.store(Ret, tl.load(XFalse)) + elif IfType == "if_and_static": + if StaticVaue != 0 and StaticVaue != 0: tl.store(Ret, tl.load(XTrue)) else: tl.store(Ret, tl.load(XFalse)) - else: - tl.store(Ret, tl.load(XTrue)) if pid % 2 else tl.store(Ret, tl.load(XFalse)) - cond = torch.ones(1, dtype=torch.int32, device='cuda') - x_true = torch.tensor([3.14], dtype=torch.float32, device='cuda') - x_false = torch.tensor([1.51], dtype=torch.float32, device='cuda') - ret = torch.empty(1, dtype=torch.float32, device='cuda') - kernel[(1,)](cond, x_true, x_false, ret, if_type) + cond = torch.ones(1, dtype=torch.int32, device=device) + x_true = torch.tensor([3.14], dtype=torch.float32, device=device) + x_false = torch.tensor([1.51], dtype=torch.float32, device=device) + ret = torch.zeros(1, dtype=torch.float32, device=device) + + kernel[(1,)](cond, x_true, x_false, ret, if_type, True, 1) + assert torch.equal(ret, x_true) def test_num_warps_pow2(): @@ -2294,46 +2457,58 @@ def kernel(InitI, Bound, CutOff, OutI, OutJ): # ----------------------- # TODO: backend should be tested separately - class MmaLayout: - def __init__(self, version, warps_per_cta): + def __init__(self, version, warps_per_cta, ctas_per_cga, cta_split_num, cta_order, instr_shape): self.version = version self.warps_per_cta = str(warps_per_cta) + self.ctas_per_cga = str(ctas_per_cga) + self.cta_split_num = str(cta_split_num) + self.cta_order = str(cta_order) + self.instr_shape = str(instr_shape) def __str__(self): - return f"#triton_gpu.mma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA={self.warps_per_cta}}}>" + return f"#{GPU_DIALECT}.mma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA={self.warps_per_cta}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}, instrShape={self.instr_shape}}}>" class MfmaLayout: - def __init__(self, non_k_dim, warps_per_cta, isTransposed): + def __init__(self, non_k_dim, warps_per_cta, is_transposed, ctas_per_cga, cta_split_num, cta_order): self.non_k_dim = str(non_k_dim) self.warps_per_cta = str(warps_per_cta) - self.isTransposed = str(isTransposed).lower() + self.is_transposed = str(is_transposed).lower() + self.ctas_per_cga = str(ctas_per_cga) + self.cta_split_num = str(cta_split_num) + self.cta_order = str(cta_order) def __str__(self): - return f"#triton_gpu.mfma<{{nonKDim = {self.non_k_dim}, warpsPerCTA = {self.warps_per_cta}, isTransposed = {self.isTransposed}}}>" + return f"#{GPU_DIALECT}.mfma<{{nonKDim = {self.non_k_dim}, warpsPerCTA = {self.warps_per_cta}, isTransposed = {self.is_transposed}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" class BlockedLayout: - def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order): + def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga, cta_split_num, cta_order): self.sz_per_thread = str(size_per_thread) self.threads_per_warp = str(threads_per_warp) self.warps_per_cta = str(warps_per_cta) self.order = str(order) + self.ctas_per_cga = str(ctas_per_cga) + self.cta_split_num = str(cta_split_num) + self.cta_order = str(cta_order) def __str__(self): - return f"#triton_gpu.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}}}>" + return f"#{GPU_DIALECT}.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" class SharedLayout: - def __init__(self, vec, per_phase, max_phase, order): + def __init__(self, vec, per_phase, max_phase, order, ctas_per_cga, cta_split_num, cta_order): self.vec = str(vec) self.per_phase = str(per_phase) self.max_phase = str(max_phase) self.order = str(order) + self.ctas_per_cga = str(ctas_per_cga) + self.cta_split_num = str(cta_split_num) + self.cta_order = str(cta_order) def __str__(self): - return f"#triton_gpu.shared<{{vec = {self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}}}>" + return f"#{GPU_DIALECT}.shared<{{vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" @pytest.mark.parametrize("vec_size", [2, 4]) @@ -2366,17 +2541,20 @@ def test_dot_mfma_vector_load(vec_size, swizzle, transposeA, transposeB): if vec_size != 4: pytest.skip() - shared_a = SharedLayout(vec=vec_size, per_phase=1, max_phase=max_phase, order=[0, 1] if transposeA else [1, 0]) - shared_b = SharedLayout(vec=vec_size, per_phase=1, max_phase=max_phase, order=[0, 1] if transposeB else [1, 0]) + blocked = BlockedLayout(size_per_thread=[1, 4], threads_per_warp=[8, 8], warps_per_cta=[4, 1], order=[1, 0], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1]) + shared_a = SharedLayout(vec=vec_size, per_phase=1, max_phase=max_phase, order=[0, 1] if transposeA else [1, 0], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1]) + shared_b = SharedLayout(vec=vec_size, per_phase=1, max_phase=max_phase, order=[0, 1] if transposeB else [1, 0], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1]) + mfma = MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=False, ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1]) ir = f""" -#blocked = #triton_gpu.blocked<{{sizePerThread = [1, 4], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}}> +#blocked = {blocked} #shared1 = {shared_a} #shared2 = {shared_b} +#mfma = {mfma} """ + """ -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { tt.func public @kernel_0d1d2d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) { - %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>> + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mfma> %cst_0 = arith.constant dense<32> : tensor<32x1xi32, #blocked> %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> %1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32x1xi32, #blocked> @@ -2398,12 +2576,12 @@ def test_dot_mfma_vector_load(vec_size, swizzle, transposeA, transposeB): %17 = tt.addptr %16, %8 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> %18 = tt.load %9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf16, #blocked> %19 = triton_gpu.convert_layout %18 : (tensor<32x32xf16, #blocked>) -> tensor<32x32xf16, #shared1> - %20 = triton_gpu.convert_layout %19 : (tensor<32x32xf16, #shared1>) -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=8}>> + %20 = triton_gpu.convert_layout %19 : (tensor<32x32xf16, #shared1>) -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=4}>> %21 = tt.load %13 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf16, #blocked> %22 = triton_gpu.convert_layout %21 : (tensor<32x32xf16, #blocked>) -> tensor<32x32xf16, #shared2> - %23 = triton_gpu.convert_layout %22 : (tensor<32x32xf16, #shared2>) -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=8}>> - %24 = tt.dot %20, %23, %cst {allowTF32 = false} : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=8}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_gpu.mfma<{nonKDim = 32, kDim = 8, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=8}>> -> tensor<32x32xf32, #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>> - %25 = triton_gpu.convert_layout %24 : (tensor<32x32xf32, #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>>) -> tensor<32x32xf32, #blocked> + %23 = triton_gpu.convert_layout %22 : (tensor<32x32xf16, #shared2>) -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth=4}>> + %24 = tt.dot %20, %23, %cst {allowTF32 = false} : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=4}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth=4}>> -> tensor<32x32xf32, #mfma> + %25 = triton_gpu.convert_layout %24 : (tensor<32x32xf32, #mfma>) -> tensor<32x32xf32, #blocked> %26 = arith.truncf %25 : tensor<32x32xf32, #blocked> to tensor<32x32xf16, #blocked> tt.store %17, %26 {cache = 1 : i32, evict = 1 : i32} : tensor<32x32xf16, #blocked> tt.return @@ -2451,28 +2629,28 @@ def _get_warp_size(): # MmaLayout(version=(2, 0), warps_per_cta=[1, 4]), # MmaLayout(version=1, warps_per_cta=[4, 1]), # MmaLayout(version=(2, 0), warps_per_cta=[4, 1]), - BlockedLayout([1, 2], [2, 32], [2, 2], [1, 0]), - BlockedLayout([2, 2], [4, 16], [2, 2], [1, 0]), - BlockedLayout([1, 1], [1, 64], [2, 2], [1, 0]), - BlockedLayout([4, 2], [16, 4], [1, 4], [0, 1]), - BlockedLayout([4, 2], [8, 8], [2, 2], [0, 1]), - BlockedLayout([1, 1], [32, 2], [2, 2], [0, 1]), - BlockedLayout([4, 2], [1, 64], [4, 1], [1, 0]) + BlockedLayout([1, 2], [2, 32], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [4, 16], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 1], [1, 64], [2, 2], [1, 0],[1, 1], [1, 1], [0, 1]), + BlockedLayout([4, 2], [16, 4], [1, 4], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([4, 2], [8, 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 1], [32, 2], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([4, 2], [1, 64], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]) ] else: layouts = [ # MmaLayout(version=1, warps_per_cta=[1, 4]), - MmaLayout(version=(2, 0), warps_per_cta=[1, 4]), + MmaLayout(version=(2, 0), warps_per_cta=[1, 4], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1]), # MmaLayout(version=1, warps_per_cta=[4, 1]), - MmaLayout(version=(2, 0), warps_per_cta=[4, 1]), - BlockedLayout([1, 2], [2, 16], [2, 2], [1, 0]), - BlockedLayout([2, 2], [4, 8], [2, 2], [1, 0]), - BlockedLayout([1, 1], [1, 32], [2, 2], [1, 0]), - BlockedLayout([4, 2], [16, 2], [1, 4], [0, 1]), - BlockedLayout([4, 2], [8, 4], [2, 2], [0, 1]), - BlockedLayout([4, 2], [4, 8], [2, 2], [0, 1]), - BlockedLayout([1, 1], [16, 2], [2, 2], [0, 1]), - BlockedLayout([4, 2], [1, 32], [4, 1], [1, 0]) + MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1]), + BlockedLayout([1, 2], [2, 16], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [4, 8], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 1], [1, 32], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([4, 2], [16, 2], [1, 4], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([4, 2], [8, 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([4, 2], [4, 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 1], [16, 2], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([4, 2], [1, 32], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]) ] @@ -2490,7 +2668,7 @@ def test_convert2d(dtype, shape, src_layout, dst_layout, device='cuda'): #src = {src_layout} #dst = {dst_layout} """ + """ -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = """ + str(_get_warp_size()) + """ : i32} { +module attributes {"triton_gpu.num-ctas" = 1, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = """ + str(_get_warp_size()) + """ : i32} { tt.func public @kernel_0d1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { %cst = arith.constant dense<128> : tensor<128x1xi32, #src> %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>> @@ -2530,15 +2708,18 @@ def test_convert2d(dtype, shape, src_layout, dst_layout, device='cuda'): if torch.version.hip is not None and _get_warp_size() == 64: layouts = [ - MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], isTransposed=True), - MfmaLayout(non_k_dim=32, warps_per_cta=[2, 2], isTransposed=False), + MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=True, ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0]), + MfmaLayout(non_k_dim=32, warps_per_cta=[2, 2], is_transposed=False, ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0]), ] shapes = [[128, 32], [128, 128], [32, 128], [64, 64]] else: layouts = [ - BlockedLayout([1, 4], [8, 4], [4, 1], [1, 0]), - BlockedLayout([1, 4], [8, 4], [4, 1], [0, 1]), - MmaLayout(version=(2, 0), warps_per_cta=[4, 1]) + BlockedLayout([1, 4], [8, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [8, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([4, 4], [2, 16], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], instr_shape=[16, 8]), + MmaLayout(version=(2, 0), warps_per_cta=[2, 2], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], instr_shape=[16, 8]), + MmaLayout(version=(3, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0], instr_shape=[16, 16, 16]), ] shapes = [[128, 16], [128, 128], [32, 128]] @@ -2548,15 +2729,16 @@ def test_convert2d(dtype, shape, src_layout, dst_layout, device='cuda'): @pytest.mark.parametrize("axis", [0, 1]) def test_reduce_layouts(M, N, src_layout, axis, device='cuda'): if torch.version.hip is not None and _get_warp_size() == 64: - if src_layout.isTransposed and axis == 0: + if src_layout.is_transposed and axis == 0: pytest.skip("Reduce along axis 0 is not supported in transposed mfma layout") rdims_2d = f"1x{N}" if axis == 0 else f"{M}x1" rdims_1d = f"{N}" if axis == 0 else f"{M}" store_range = "%7" if axis == 0 else "%1" + blocked = BlockedLayout([1, 1], [32, 1], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]) ir = f""" - #blocked = #triton_gpu.blocked<{{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}}> + #blocked = {blocked} #src = {src_layout} - module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = {_get_warp_size()} : i32}} {{ + module attributes {{"triton_gpu.num-ctas" = 1, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = {_get_warp_size()} : i32}} {{ tt.func public @kernel_0d1d2c3d4c(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: i32 {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>> %1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>) -> tensor<{M}x1xi32, #blocked> @@ -2612,17 +2794,17 @@ def test_reduce_layouts(M, N, src_layout, axis, device='cuda'): scan_layouts = [ - BlockedLayout([1, 4], [4, 16], [4, 1], [0, 1]), - BlockedLayout([1, 4], [8, 8], [4, 1], [0, 1]), - BlockedLayout([4, 1], [4, 16], [1, 4], [0, 1]), - BlockedLayout([2, 2], [4, 16], [2, 2], [0, 1]), - BlockedLayout([2, 2], [8, 8], [2, 2], [0, 1]), - - BlockedLayout([1, 4], [4, 16], [4, 1], [1, 0]), - BlockedLayout([1, 4], [8, 8], [4, 1], [1, 0]), - BlockedLayout([4, 1], [4, 16], [1, 4], [1, 0]), - BlockedLayout([2, 2], [4, 16], [2, 2], [1, 0]), - BlockedLayout([2, 2], [8, 8], [2, 2], [1, 0]), + BlockedLayout([1, 4], [4, 16], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [8, 8], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([4, 1], [4, 16], [1, 4], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [4, 16], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [8, 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + + BlockedLayout([1, 4], [4, 16], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [8, 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([4, 1], [4, 16], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [4, 16], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [8, 8], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), ] @@ -2632,7 +2814,7 @@ def test_reduce_layouts(M, N, src_layout, axis, device='cuda'): def test_scan_layouts(M, N, src_layout, axis, device): ir = f""" #blocked = {src_layout} - module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32}} {{ + module attributes {{"triton_gpu.num-ctas" = 1, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32}} {{ tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked> %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>> @@ -2682,14 +2864,15 @@ def test_scan_layouts(M, N, src_layout, axis, device): @pytest.mark.parametrize("shape", [(64, 64)]) @pytest.mark.parametrize("dtype", ['float16']) -@pytest.mark.parametrize("src_layout", [MfmaLayout(non_k_dim=32, warps_per_cta=[2, 1], isTransposed=False), MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], isTransposed=True)]) -@pytest.mark.parametrize("dst_layout", [BlockedLayout([1, 4], [4, 16], [1, 1], [1, 0])]) +@pytest.mark.parametrize("src_layout", [MfmaLayout(non_k_dim=32, warps_per_cta=[2, 1], is_transposed=False, ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0]), + MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=True, ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0])]) +@pytest.mark.parametrize("dst_layout", [BlockedLayout([1, 4], [4, 16], [1, 1], [1, 0], [1, 1], [1, 1], [0, 1])]) def test_make_range(dtype, shape, src_layout, dst_layout, device='cuda'): ir = f""" #src = {src_layout} #dst = {dst_layout} """ + """ -module attributes {"triton_gpu.num-warps" = """ + str(128 // _get_warp_size()) + """ : i32, "triton_gpu.threads-per-warp" = """ + str(_get_warp_size()) + """ : i32} { +module attributes {"triton_gpu.num-ctas" = 1, "triton_gpu.num-warps" = """ + str(128 // _get_warp_size()) + """ : i32, "triton_gpu.threads-per-warp" = """ + str(_get_warp_size()) + """ : i32} { tt.func public @kernel_0d1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { %cst = arith.constant dense<64> : tensor<64x1xi32, #src> %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #src}>> diff --git a/python/test/unit/language/test_line_info.py b/python/test/unit/language/test_line_info.py index 8a482d5f5cc3..2823cf9299b2 100644 --- a/python/test/unit/language/test_line_info.py +++ b/python/test/unit/language/test_line_info.py @@ -116,5 +116,3 @@ def test_line_info(func: str): assert (check_file_lines(file_lines, "standard.py", 33)) assert (check_file_lines(file_lines, "standard.py", 34)) assert (check_file_lines(file_lines, "standard.py", 36)) - # core.py is changed frequently, so we only check if it exists - assert (check_file_lines(file_lines, "core.py", -1)) diff --git a/python/test/unit/language/test_subprocess.py b/python/test/unit/language/test_subprocess.py index 0baf89fce851..78b8d09fb9f8 100644 --- a/python/test/unit/language/test_subprocess.py +++ b/python/test/unit/language/test_subprocess.py @@ -15,7 +15,7 @@ @pytest.mark.parametrize("func_type, data_type", - [("device_print", data_type) for data_type in torch_types] + [("print", "int32"), ("static_print", "int32")]) + [("device_print", data_type) for data_type in torch_types] + [("print", "int32"), ("static_print", "int32"), ("no_arg_print", "int32")]) def test_print(func_type: str, data_type: str): proc = subprocess.Popen([sys.executable, print_path, func_type, data_type], stdout=subprocess.PIPE, shell=False) outs, _ = proc.communicate() @@ -29,10 +29,9 @@ def test_print(func_type: str, data_type: str): new_lines.add(value) except Exception as e: print(e) - if func_type != "static_print": + if func_type != "static_print" and func_type != "no_arg_print": for i in range(128): assert i in new_lines - assert len(new_lines) == 128 else: assert len(new_lines) == 1 diff --git a/python/test/unit/operators/test_flash_attention.py b/python/test/unit/operators/test_flash_attention.py index b5f01388a002..8bf7f2e21ec9 100644 --- a/python/test/unit/operators/test_flash_attention.py +++ b/python/test/unit/operators/test_flash_attention.py @@ -13,6 +13,13 @@ @pytest.mark.parametrize('causal', [True, False]) @pytest.mark.parametrize('seq_par', [True, False]) def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par): + # with ENABLE_TMA=0 and ENABLE_MMA_V3=0 + import os + enable_mmav3 = os.environ.get('ENABLE_MMA_V3', 'not found').lower() + enable_tma = os.environ.get('ENABLE_TMA', 'not found').lower() + if enable_mmav3 in ["on", "true", "1"] and enable_tma in ["on", "true", "1"]: + pytest.skip('Segmentation fault') + capability = torch.cuda.get_device_capability() if torch.version.hip is not None: if dtype != torch.float16: diff --git a/python/test/unit/operators/test_matmul.py b/python/test/unit/operators/test_matmul.py index 4e441e3a57d8..a7afa02f10b3 100644 --- a/python/test/unit/operators/test_matmul.py +++ b/python/test/unit/operators/test_matmul.py @@ -26,93 +26,96 @@ def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr): @pytest.mark.parametrize( - "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE", + "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32", itertools.chain( *[ [ # 1 warp - (16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE), - (32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE), - (16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE), - (16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE), - (32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE), - (16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE), - (16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE), - (64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE), - (16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE), + (16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), + (32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), + (16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), + (16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), + (32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), + (16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), + (16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), + (64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), + (16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), # 2 warp - (64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE), - (32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE), - (64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE), - (32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE), - (128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE), - (32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE), + (64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), + (32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), + (64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), + (32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), + (128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), + (32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), # 4 warp - (128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE), - (64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE), - (128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE), - (32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE), - (128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE), - (32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE), + (128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), + (64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), + (128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), + (32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), + (128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), + (32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), # 8 warp - (128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE), - (256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE), - (256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE), - # split-k - (64, 64, 16, 2, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE), - (64, 64, 16, 4, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE), - (64, 64, 16, 8, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE), + (128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), + (256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), + (256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), # variable input - (128, 128, 32, 1, 4, 2, 256, 384, 160, AT, BT, DTYPE, DTYPE), - (128, 128, 32, 1, 4, 2, 107, 233, 256, AT, BT, DTYPE, DTYPE), - (128, 128, 32, 1, 4, 2, 107, 233, 311, AT, BT, DTYPE, DTYPE), - (128, 256, 64, 1, 8, 3, 256, 512, 160, AT, BT, DTYPE, DTYPE), + (128, 128, 32, 1, 4, 2, 256, 384, 160, AT, BT, DTYPE, DTYPE, True), + (128, 128, 32, 1, 4, 2, 107, 233, 128, AT, BT, DTYPE, DTYPE, True), + (128, 128, 32, 1, 4, 2, 107, 233, 83, AT, BT, DTYPE, DTYPE, True), + (128, 256, 64, 1, 8, 3, 256, 512, 160, AT, BT, DTYPE, DTYPE, True), ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] ], # n-stage *[ [ - (16, 16, 16, 1, 1, STAGES, 32, 32, 80, AT, BT, DTYPE, DTYPE), - (64, 32, 64, 1, 2, STAGES, 128, 64, 128, AT, BT, DTYPE, DTYPE), - (128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE), - (256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE), - (128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE), - # split-k - (64, 64, 16, 8, 4, STAGES, 128, 128, 768, AT, BT, DTYPE, DTYPE), - (64, 64, 16, 8, 4, STAGES, 128, 128, 32, AT, BT, DTYPE, DTYPE), + (16, 16, 16, 1, 1, STAGES, 32, 32, 80, AT, BT, DTYPE, DTYPE, True), + (64, 32, 64, 1, 2, STAGES, 128, 64, 128, AT, BT, DTYPE, DTYPE, True), + (128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE, True), + (256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE, True), + (128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE, True), ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [4] ], # mixed-precision *[ [ - (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE), - (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE), - (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE), - (128, 128, 32, 8, 4, 2, 256, 256, 128, AT, BT, ADTYPE, BDTYPE), - ] for ADTYPE, BDTYPE in [("float8e4", "float8e5"), - ("float8e4", "float16"), + (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True), + (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True), + (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, True), + ] for ADTYPE, BDTYPE in [("float8e4nv", "float8e5"), + ("float8e4nv", "float8e4nv"), + ("float8e5", "float8e4nv"), + ("float8e5", "float8e5"), + ("float8e4b15", "float8e4b15"), + ("float8e4nv", "float16"), ("float16", "float8e5"), ("float16", "float32"), ("float32", "float16"), ("bfloat16", "float32"), ("float32", "bfloat16")] for AT in [False, True] for BT in [False, True] ], + # mixed-precision block layout *[ - # float8e4b15 only supports row-col layout [ - (128, 128, 32, 1, 4, 2, None, None, None, False, True, ADTYPE, BDTYPE), - ] for ADTYPE, BDTYPE in [("float8e4b15", "float8e5"), - ("float8e4b15", "float16"), - ("float16", "float8e4b15")] - ] + (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False), + (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False), + (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, False), + ] for ADTYPE, BDTYPE in [("float8e4nv", "float16"), + ("float16", "float8e5"), + ("float16", "float32"), + ("float32", "float16"), + ("bfloat16", "float32"), + ("float32", "bfloat16")] for AT in [False, True] for BT in [False, True] + ], ), ) -def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE): +def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32): capability = torch.cuda.get_device_capability() if capability[0] < 7: pytest.skip("Only test tl.dot() on devices with sm >= 70") if capability[0] < 8 and (ADTYPE == "bfloat16" or BDTYPE == "bfloat16"): pytest.skip("Only test bfloat16 on devices with sm >= 80") + if capability[0] < 9 and (ADTYPE == "float8e4nv" or BDTYPE == "float8e4nv"): + pytest.skip("Only test float8e4nv on devices with sm >= 90") if (ADTYPE == "bfloat16" or BDTYPE == "bfloat16") and SPLIT_K != 1: pytest.skip("bfloat16 matmuls don't allow split_k for now") torch.manual_seed(0) @@ -128,42 +131,52 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, M = BLOCK_M if M is None else M N = BLOCK_N if N is None else N K = BLOCK_K * SPLIT_K if K is None else K - a_fp8 = "float8" in ADTYPE - b_fp8 = "float8" in BDTYPE def maybe_upcast(x, dtype, is_float8): if is_float8: return f8_to_f16(x, dtype) return x - def init_input(n, m, t, dtype, is_float8): - if t: - return init_input(m, n, False, dtype, is_float8).t() - if is_float8: - return torch.randint(20, 50, (n, m), device="cuda", dtype=torch.int8) + def init_input(m, n, dtype): + if 'float8' in dtype: + ewidth = {'float8e4b15': 4, 'float8e4nv': 4, 'float8e5': 5}[dtype] + sign = torch.randint(2, size=(m, n), device="cuda", dtype=torch.int8) * 128 + val = torch.randint(2**3 - 1, size=(m, n), device="cuda", dtype=torch.int8) << 7 - ewidth + return sign | val + if dtype == "int8": + return torch.randint(-128, 127, (m, n), device="cuda", dtype=torch.int8) dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[dtype] - return .1 * torch.randn((n, m), device="cuda", dtype=dtype) + exponents = torch.randint(-10, 0, size=(m, n)) + ret = (2. ** exponents).to(dtype).to("cuda") + return ret # allocate/transpose inputs - a = init_input(M, K, AT, ADTYPE, a_fp8) - b = init_input(K, N, BT, BDTYPE, b_fp8) + a = init_input(M, K, ADTYPE) + b = init_input(K, N, BDTYPE) + a = a if not AT else a.T.contiguous().T + b = b if not BT else b.T.contiguous().T # run test - th_a = maybe_upcast(a, ADTYPE, a_fp8).to(torch.float32) + a_fp8 = "float8" in ADTYPE + b_fp8 = "float8" in BDTYPE + th_a = maybe_upcast(a, ADTYPE, a_fp8) if AT and a_fp8: th_a = th_a.view(th_a.shape[::-1]).T - th_b = maybe_upcast(b, BDTYPE, b_fp8).to(torch.float32) + th_b = maybe_upcast(b, BDTYPE, b_fp8) if BT and b_fp8: th_b = th_b.view(th_b.shape[::-1]).T - th_c = torch.matmul(th_a, th_b) + if th_a.is_floating_point(): + ab_dtype = th_a.dtype if th_a.element_size() > th_b.element_size() else th_b.dtype + else: + ab_dtype = torch.float32 + th_c = torch.matmul(th_a.to(ab_dtype), th_b.to(ab_dtype)) + if ADTYPE == "int8" or BDTYPE == "int8": + th_c = th_c.to(torch.int8) try: if a_fp8: a = triton.reinterpret(a, getattr(tl, ADTYPE)) if b_fp8: b = triton.reinterpret(b, getattr(tl, BDTYPE)) - tt_c = triton.ops.matmul(a, b) - atol, rtol = 1e-2, 0 - if ADTYPE == torch.bfloat16 or BDTYPE == torch.bfloat16: - atol, rtol = 3.5e-2, 0 - torch.testing.assert_allclose(th_c, tt_c, atol=atol, rtol=rtol) + tt_c = triton.ops.matmul(a, b, None, ALLOW_TF32) + torch.testing.assert_allclose(th_c, tt_c, atol=0, rtol=0) except triton.OutOfResources as e: pytest.skip(str(e)) diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index e13921079992..6f9b94d907f2 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -70,7 +70,8 @@ def test_nested1_change(): def reset_tmp_dir(): os.environ["TRITON_CACHE_DIR"] = tmpdir if os.path.exists(tmpdir): - shutil.rmtree(tmpdir) + # https://stackoverflow.com/questions/303200/how-do-i-remove-delete-a-folder-that-is-not-empty + shutil.rmtree(tmpdir, ignore_errors=True) def test_reuse(): @@ -98,7 +99,7 @@ def inc_counter(*args, **kwargs): reset_tmp_dir() x = torch.empty(1, dtype=torch.int32, device='cuda') function = {'enable': kernel, 'disable': kernel_nospec}[mode] - target = {'enable': 3, 'disable': 1}[mode] + target = {'enable': 4, 'disable': 1}[mode] for i in [1, 2, 4, 8, 16, 32]: function[(1,)](x, i, BLOCK=512) assert counter == target diff --git a/python/test/unit/runtime/test_launch.py b/python/test/unit/runtime/test_launch.py index fc79ab67aac6..d3f9fd01bda8 100644 --- a/python/test/unit/runtime/test_launch.py +++ b/python/test/unit/runtime/test_launch.py @@ -38,7 +38,7 @@ def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr): kernel[(10,)](inp, out, 10, XBLOCK=16) gc.collect() end, _ = tracemalloc.get_traced_memory() - assert end - begin < 5000 + assert end - begin < 30000 finally: tracemalloc.stop() diff --git a/python/test/unit/runtime/test_subproc.py b/python/test/unit/runtime/test_subproc.py index 49092bbd3cfe..f1958ffe20c0 100644 --- a/python/test/unit/runtime/test_subproc.py +++ b/python/test/unit/runtime/test_subproc.py @@ -14,10 +14,10 @@ def reset_tmp_dir(): os.environ["TRITON_CACHE_DIR"] = tmpdir if os.path.exists(tmpdir): - shutil.rmtree(tmpdir) + shutil.rmtree(tmpdir, ignore_errors=True) -instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"]) +instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"]) def get_device_type(): @@ -54,7 +54,7 @@ def kernel_sub(a, b, o, N: tl.constexpr): def test_compile_in_subproc() -> None: cc, device_type = get_device_type() - config = instance_descriptor(tuple(range(4)), ()) + config = instance_descriptor(tuple(range(4)), (), (), ()) multiprocessing.set_start_method('fork') proc = multiprocessing.Process( @@ -87,7 +87,7 @@ def kernel_dot(Z): def test_compile_in_forked_subproc() -> None: reset_tmp_dir() cc, device_type = get_device_type() - config = instance_descriptor(tuple(range(1)), ()) + config = instance_descriptor(tuple(range(1)), (), (), ()) assert multiprocessing.get_start_method() == 'fork' proc = multiprocessing.Process( diff --git a/python/test/unit/tools/test_aot.py b/python/test/unit/tools/test_aot.py index 69919a265200..2c6fe88a2615 100644 --- a/python/test/unit/tools/test_aot.py +++ b/python/test/unit/tools/test_aot.py @@ -23,26 +23,43 @@ def mul(x, y): import kernel_utils @triton.jit -def kernel(C, A, B, +def kernel(C, A, B, M, N, K, stride_cm, stride_cn, stride_am, stride_ak, stride_bk, stride_bn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr): - ms = tl.arange(0, BLOCK_M) - ns = tl.arange(0, BLOCK_N) - ks = tl.arange(0, BLOCK_K) - a = tl.load(A + ms[:, None] * stride_am + ks[None, :] * stride_ak) - b = tl.load(B + ks[:, None] * stride_bk + ns[None, :] * stride_bn) - c = tl.dot(a, b) - c = kernel_utils.mul(c, c) - tl.store(C + ms[:, None] * stride_cm + ns[None, :] * stride_cn, c) + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M + offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N + offs_k = tl.arange(0, BLOCK_K) + a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + # Load the next block of A and B, generate a mask by checking the K dimension. + # If it is out of bounds, set it to 0. + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0) + # We accumulate along the K dimension. + accumulator += tl.dot(a, b) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_K * stride_ak + b_ptrs += BLOCK_K * stride_bk + + c = kernel_utils.mul(accumulator, accumulator) + # Write back the block of the output matrix C with masks. + offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + tl.store(c_ptrs, c) """ - -def gen_test_bin(dir, M, N, K, BM, BN, BK): - test_src = ''' +test_utils_src = ''' #include #include #include @@ -78,10 +95,23 @@ def gen_test_bin(dir, M, N, K, BM, BN, BK): fclose(file); }''' - test_src += f''' + +def gen_kernel_library(dir, libname): + c_files = glob.glob(os.path.join(dir, "*.c")) + subprocess.run(["gcc"] + c_files + ["-I", cuda_include_dir(), + "-c", "-fPIC"], + check=True, cwd=dir) + o_files = glob.glob(os.path.join(dir, "*.o")) + subprocess.run(["gcc"] + o_files + ["-shared", + "-o", libname, + "-L", libcuda_dirs()[0]], + check=True, cwd=dir) + + +def gen_test_bin(dir, M, N, K, exe="test", algo_id=0): + test_src = f''' int main(int argc, char **argv) {{ int M = {M}, N = {N}, K = {K}; - int BM = {M}, BN = {N}, BK = {K}; // initialize CUDA handles CUdevice dev; @@ -96,7 +126,7 @@ def gen_test_bin(dir, M, N, K, BM, BN, BK): cuMemAlloc(&B, K * N * 2); cuMemAlloc(&C, M * N * 4); cuStreamCreate(&stream, 0); - load_matmul_fp16xfp16_16x16x16(); + load_matmul_fp16(); // initialize input data int16_t hA[M*K]; @@ -110,7 +140,13 @@ def gen_test_bin(dir, M, N, K, BM, BN, BK): // launch kernel cuStreamSynchronize(stream); - CUresult ret = matmul_fp16xfp16_16x16x16(stream, M/BM, N/BN, 1, C, A, B, N, 1, K, 1, N, 1); + CUresult ret; + int algo_id = {algo_id}; + if (algo_id == 0) {{ + ret = matmul_fp16_default(stream, C, A, B, M, N, K, N, 1, K, 1, N, 1); + }} else {{ + ret = matmul_fp16(stream, C, A, B, M, N, K, N, 1, K, 1, N, 1, {algo_id}); + }} if (ret != 0) fprintf(stderr, "kernel launch failed\\n"); assert(ret == 0); @@ -123,41 +159,51 @@ def gen_test_bin(dir, M, N, K, BM, BN, BK): write_buffer_to_csv(argv[3], hC, M*N); // free cuda handles - unload_matmul_fp16xfp16_16x16x16(); + unload_matmul_fp16(); cuMemFree(A); cuMemFree(B); cuMemFree(C); cuCtxDestroy(ctx); }} ''' - + src = test_utils_src + test_src with open(os.path.join(dir, "test.c"), "w") as file: - file.write(test_src) - c_files = glob.glob(os.path.join(dir, "*.c")) - subprocess.run(["gcc"] + c_files + ["-I", cuda_include_dir(), - "-L", libcuda_dirs()[0], - "-l", "cuda", - "-o", "test"], check=True, cwd=dir) + file.write(src) + subprocess.run(["gcc"] + ["test.c", + "-I", cuda_include_dir(), + "-L", libcuda_dirs()[0], + "-l", "cuda", + "-L", dir, + "-l", "kernel", + "-o", exe], check=True, cwd=dir) -def generate_matmul_launcher(dir, dtype, BM, BN, BK, ha_hb_hints): +def write_triton_kernels(dir, src, util_src): kernel_path = os.path.join(dir, "kernel.py") with open(kernel_path, "w") as file: - file.write(kernel_src) + file.write(src) kernel_utils_path = os.path.join(dir, "kernel_utils.py") with open(kernel_utils_path, "w") as file: - file.write(kernel_utils_src) + file.write(util_src) + return kernel_path + + +def compile_aot_kernels(dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints): compiler_path = os.path.join(triton.tools.__path__[0], "compile.py") - linker_path = os.path.join(triton.tools.__path__[0], "link.py") # compile all desired configs for ha in ha_hb_hints: for hb in ha_hb_hints: - sig = f'*fp32:16, *{dtype}:16, *{dtype}:16, i32{ha}, i32:1, i32{hb}, i32:1, i32:16, i32:1, {BM}, {BN}, {BK}' - name = f"matmul_{dtype}x{dtype}_{BM}x{BN}x{BK}" - subprocess.run([sys.executable, compiler_path, "-n", "kernel", "--signature", sig, "--out-name", name, "-o", name, "-w", "1", kernel_path], check=True, cwd=dir) + sig = f'*fp32:16, *{dtype}:16, *{dtype}:16, i32, i32, i32, i32{ha}, i32:1, i32{hb}, i32:1, i32:16, i32:1, {BM}, {BN}, {BK}' + name = f"matmul_{dtype}" + grid = f'M/{BM}, N/{BN}, 1' + subprocess.run([sys.executable, compiler_path, "-n", "kernel", "--signature", sig, "--out-name", name, "-o", name, "-w", "1", "-g", grid, kernel_path], check=True, cwd=dir) + + +def link_aot_kernels(dir): + linker_path = os.path.join(triton.tools.__path__[0], "link.py") # link all desired configs h_files = glob.glob(os.path.join(dir, "*.h")) @@ -183,17 +229,22 @@ def test_compile_link_matmul(): dtype = "fp16" BM, BN, BK = 16, 16, 16 - generate_matmul_launcher(tmp_dir, dtype, BM, BN, BK, ha_hb_hints=["", ":16"]) + kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src) + compile_aot_kernels(tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=["", ":16"]) + link_aot_kernels(tmp_dir) # compile test case M, N, K = 16, 16, 16 - gen_test_bin(tmp_dir, M, N, K, BM, BN, BK) + gen_kernel_library(tmp_dir, "libkernel.so") + gen_test_bin(tmp_dir, M, N, K) # initialize test data a, b, a_path, b_path, c_path = generate_matmul_test_data(tmp_dir, M, N, K) # run test case - subprocess.run(["./test", a_path, b_path, c_path], check=True, cwd=tmp_dir) + env = os.environ.copy() + env["LD_LIBRARY_PATH"] = tmp_dir + subprocess.run(["./test", a_path, b_path, c_path], env=env, check=True, cwd=tmp_dir) # read data and compare against reference c = np.genfromtxt(c_path, delimiter=",", dtype=np.int32) @@ -209,26 +260,76 @@ def test_launcher_has_no_available_kernel(): dtype = "fp16" BM, BN, BK = 16, 16, 16 - generate_matmul_launcher(tmp_dir, dtype, BM, BN, BK, ha_hb_hints=[":1"]) + kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src) + compile_aot_kernels(tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=[":1"]) + link_aot_kernels(tmp_dir) # compile test case M, N, K = 16, 16, 16 - gen_test_bin(tmp_dir, M, N, K, BM, BN, BK) + gen_kernel_library(tmp_dir, "libkernel.so") + gen_test_bin(tmp_dir, M, N, K) # initialize test data a, b, a_path, b_path, c_path = generate_matmul_test_data(tmp_dir, M, N, K) # run test case - result = subprocess.run(["./test", a_path, b_path, c_path], cwd=tmp_dir, capture_output=True, text=True) + env = os.environ.copy() + env["LD_LIBRARY_PATH"] = tmp_dir + result = subprocess.run(["./test", a_path, b_path, c_path], env=env, cwd=tmp_dir, capture_output=True, text=True) # It should fail since the launcher requires all the strides be 1 while they are not. assert result.returncode == -6 assert "kernel launch failed" in result.stderr +def test_compile_link_autotune_matmul(): + np.random.seed(3) + + with tempfile.TemporaryDirectory() as tmp_dir: + + dtype = "fp16" + + kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src) + + tile_sizes = [ + [16, 16, 16], + [32, 32, 16], + [32, 32, 32], + [64, 64, 32], + ] + + for ts in tile_sizes: + BM, BN, BK = ts[0], ts[1], ts[2] + compile_aot_kernels(tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=["", ":16"]) + + link_aot_kernels(tmp_dir) + + gen_kernel_library(tmp_dir, "libkernel.so") + + # compile test case + M, N, K = 64, 64, 64 + # initialize test data + a, b, a_path, b_path, c_path = generate_matmul_test_data(tmp_dir, M, N, K) + c_ref = np.matmul(a.astype(np.float32), b.astype(np.float32)) + + for algo_id in range(len(tile_sizes)): + # generate and run test case + test_name = f"test_{algo_id}" + gen_test_bin(tmp_dir, M, N, K, exe=test_name, algo_id=algo_id) + + env = os.environ.copy() + env["LD_LIBRARY_PATH"] = tmp_dir + subprocess.run([f"./{test_name}", a_path, b_path, c_path], check=True, cwd=tmp_dir, env=env) + + # read data and compare against reference + c = np.genfromtxt(c_path, delimiter=",", dtype=np.int32) + c_tri = c.reshape((M, N)).view(np.float32) + np.testing.assert_allclose(c_tri, c_ref * c_ref, atol=1e-4, rtol=1e-4) + + def test_ttgir_to_ptx(): src = """ -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32, "triton_gpu.num-ctas" = 1 : i32} { tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr, %arg1: !tt.ptr) { tt.return } diff --git a/python/triton/common/build.py b/python/triton/common/build.py index 2e1e1d92b761..e7948e7b2171 100644 --- a/python/triton/common/build.py +++ b/python/triton/common/build.py @@ -18,7 +18,7 @@ def is_hip(): @functools.lru_cache() def libcuda_dirs(): - libs = subprocess.check_output(["ldconfig", "-p"]).decode() + libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode() # each line looks like the following: # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1 locs = [line.split()[-1] for line in libs.splitlines() if "libcuda.so" in line] @@ -27,13 +27,22 @@ def libcuda_dirs(): if locs: msg += 'Possible files are located at %s.' % str(locs) msg += 'Please create a symlink of libcuda.so to any of the file.' + else: + msg += 'Please make sure GPU is setup and then run "/sbin/ldconfig"' + msg += ' (requires sudo) to refresh the linker cache.' assert any(os.path.exists(os.path.join(path, 'libcuda.so')) for path in dirs), msg return dirs @functools.lru_cache() def rocm_path_dir(): - return os.getenv("ROCM_PATH", default="/opt/rocm") + default_path = os.path.join(os.path.dirname(__file__), "..", "third_party", "rocm") + # Check if include files have been populated locally. If so, then we are + # most likely in a whl installation and he rest of our libraries should be here + if (os.path.exists(default_path+"/include/hip/hip_runtime.h")): + return default_path + else: + return os.getenv("ROCM_PATH", default="/opt/rocm") @contextlib.contextmanager diff --git a/python/triton/compiler/__init__.py b/python/triton/compiler/__init__.py index 9593deafbdb9..89f46a1fe1ca 100644 --- a/python/triton/compiler/__init__.py +++ b/python/triton/compiler/__init__.py @@ -1,4 +1,5 @@ -from .compiler import CompiledKernel, compile, instance_descriptor +from .compiler import (CompiledKernel, compile, get_arch_default_num_stages, + get_arch_default_num_warps, instance_descriptor) from .errors import CompilationError -__all__ = ["compile", "instance_descriptor", "CompiledKernel", "CompilationError"] +__all__ = ["compile", "instance_descriptor", "CompiledKernel", "CompilationError", "get_arch_default_num_warps", "get_arch_default_num_stages"] diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index e9bca5a076c4..328474642aea 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -21,16 +21,8 @@ def mangle_ty(ty): SIGNED = language.dtype.SIGNEDNESS.SIGNED prefix = 'i' if ty.int_signedness == SIGNED else 'u' return prefix + str(ty.int_bitwidth) - if ty.is_fp8(): - return 'fp8' - if ty.is_fp16(): - return 'fp16' - if ty.is_bf16(): - return 'bf16' - if ty.is_fp32(): - return 'fp32' - if ty.is_fp64(): - return 'fp64' + if ty.is_floating(): + return str(ty) if ty.is_block(): elt = mangle_ty(ty.scalar) shape = '_'.join(map(str, ty.shape)) @@ -64,6 +56,10 @@ def _is_triton_scalar(o: Any) -> bool: return _is_triton_tensor(o) and (not o.type.is_block() or o.type.numel == 1) +def _is_list_like(o: Any) -> bool: + return isinstance(o, (list, tuple)) + + def _unwrap_if_constexpr(o: Any): return o.value if isinstance(o, constexpr) else o @@ -284,6 +280,9 @@ def _set_insertion_point_and_loc(self, ip, loc): # AST visitor # def visit_compound_statement(self, stmts): + # Ensure that stmts is iterable + if not _is_list_like(stmts): + stmts = [stmts] for stmt in stmts: ret_type = self.visit(stmt) if ret_type is not None and isinstance(stmt, ast.Return): @@ -350,7 +349,8 @@ def visit_FunctionDef(self, node): continue else: if i in self.attributes: - fn.set_arg_attr(idx, "tt.divisibility", self.attributes[i][1]) + for name, value in self.attributes[i]: + fn.set_arg_attr(idx, name, value) arg_values.append(tensor(fn.args(idx), self.prototype.param_types[idx])) idx += 1 @@ -412,9 +412,9 @@ def visit_Assign(self, node): raise UnsupportedLanguageConstruct(None, node, "simultaneous multiple assignment is not supported.") names = _names[0] values = self.visit(node.value) - if not isinstance(names, tuple): + if not _is_list_like(names): names = [names] - if not isinstance(values, tuple): + if not _is_list_like(values): values = [values] native_nontensor_types = (language.dtype, ) for name, value in zip(names, values): @@ -496,7 +496,7 @@ def visit_then_else_blocks(self, node, liveins, then_block, else_block): # check type for defs, block_name in [(then_defs, 'then'), (else_defs, 'else')]: if name in defs: - assert defs[name].type == liveins[name].type,\ + assert defs[name].type == liveins[name].type, \ f'initial value for `{name}` is of type {liveins[name].type}, '\ f'but the {block_name} block redefines it as {defs[name].type}' if name in then_defs or name in else_defs: @@ -516,7 +516,7 @@ def visit_then_else_blocks(self, node, liveins, then_block, else_block): continue then_ty = then_defs[name].type else_ty = else_defs[name].type - assert then_ty == else_ty,\ + assert then_ty == else_ty, \ f'mismatched type for {name} between then block ({then_ty}) '\ f'and else block ({else_ty})' names.append(name) @@ -618,11 +618,19 @@ def visit_If(self, node): def visit_IfExp(self, node): cond = self.visit(node.test) if _is_triton_tensor(cond): - cond = cond.to(language.int1, _builder=self.builder) - if _unwrap_if_constexpr(cond): - return self.visit(node.body) + raise UnsupportedLanguageConstruct( + None, node, + "Triton does not support `if` expressions (ternary operators) with dynamic conditions, use `if` statements instead") else: - return self.visit(node.orelse) + cond = _unwrap_if_constexpr(cond) + if type(cond) not in _condition_types: # not isinstance - we insist the real thing, no subclasses and no ducks + raise UnsupportedLanguageConstruct( + None, node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format( + ', '.join(_.__name__ for _ in _condition_types), type(cond).__name__)) + if cond: + return self.visit(node.body) + else: + return self.visit(node.orelse) def visit_Pass(self, node): pass @@ -814,7 +822,7 @@ def visit_For(self, node): if name in liveins: assert _is_triton_tensor(self.local_defs[name]), f'{name} is not tensor' assert _is_triton_tensor(liveins[name]) - assert self.local_defs[name].type == liveins[name].type,\ + assert self.local_defs[name].type == liveins[name].type, \ f'Loop-carried variable {name} has initial type {liveins[name].type} '\ f'but is re-assigned to {self.local_defs[name].type} in loop! '\ f'Please make sure that the type stays consistent.' @@ -1061,9 +1069,10 @@ def str_to_ty(name): ty = str_to_ty(name[1:]) return language.pointer_type(ty) tys = { - "fp8e4": language.float8e4, + "fp8e4nv": language.float8e4nv, "fp8e5": language.float8e5, "fp8e4b15": language.float8e4b15, + "fp8e4b15x4": language.float8e4b15x4, "fp16": language.float16, "bf16": language.bfloat16, "fp32": language.float32, @@ -1084,7 +1093,7 @@ def str_to_ty(name): def kernel_suffix(signature, specialization): # suffix format: - # <'c' if equal to 1><'d' if divisible by 16> + # <'c' if equal to 1><'d' if divisible by 16><'e' if divisible by 8> suffix = '' for i, _ in enumerate(signature): suffix += str(i) @@ -1092,6 +1101,8 @@ def kernel_suffix(signature, specialization): suffix += 'c' if i in specialization.divisible_by_16: suffix += 'd' + if i in specialization.divisible_by_8: + suffix += 'e' return suffix @@ -1109,7 +1120,12 @@ def ast_to_ttir(fn, signature, specialization, constants, debug, arch): function_name = '_'.join([fn.__name__, kernel_suffix(signature.values(), specialization)]) tys = list(signature.values()) new_constants = {k: True if k in tys and tys[k] == "i1" else 1 for k in specialization.equal_to_1} - new_attrs = {k: ("multiple_of", 16) for k in specialization.divisible_by_16} + new_attrs = {k: [("tt.divisibility", 16)] for k in specialization.divisible_by_16} + for k in specialization.divisible_by_8: + attr = new_attrs[k] if k in new_attrs else [] + attr.append(("tt.max_divisibility", 8)) + new_attrs[k] = attr + all_constants = constants.copy() all_constants.update(new_constants) arg_types = [str_to_ty(v) for k, v in signature.items() if k not in constants] diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index ab7832258341..ae06583ac5bb 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -11,8 +11,9 @@ from pathlib import Path from typing import Any, Tuple -from .._C.libtriton.triton import (add_external_libs, compile_ptx_to_cubin, - get_shared_memory_size, ir, +from .._C.libtriton.triton import (ClusterInfo, TMAInfos, add_external_libs, + compile_ptx_to_cubin, get_env_vars, get_num_warps, + get_shared_memory_size, ir, runtime, translate_llvmir_to_hsaco, translate_llvmir_to_ptx, translate_triton_gpu_to_llvmir, get_arch_info, get_warp_size) @@ -27,6 +28,8 @@ from ..tools.disasm import extract from .code_generator import ast_to_ttir from .make_launcher import make_stub +from .utils import (InfoFromBackendForTensorMap, TensorMapManager, + get_ids_of_tensormaps, parse_tma_info) CUDA_DEFAULT_WARP_SIZE = 32 @@ -70,18 +73,27 @@ def optimize_ttir(mod, arch): return mod -def ttir_to_ttgir(mod, num_warps, warpsize): +def ttir_to_ttgir(mod, num_warps, warpsize, num_ctas, arch): pm = ir.pass_manager(mod.context) pm.enable_debug() - pm.add_convert_triton_to_tritongpu_pass(num_warps, warpsize) + if is_hip(): + pm.add_convert_triton_to_tritongpu_pass(num_warps, warpsize, num_ctas, 0) + else: + pm.add_convert_triton_to_tritongpu_pass(num_warps, warpsize, num_ctas, arch) pm.run(mod) return mod -def optimize_ttgir(mod, num_stages, arch): +def optimize_ttgir(mod, num_stages, num_warps, num_ctas, arch, + cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue): pm = ir.pass_manager(mod.context) pm.enable_debug() pm.add_tritongpu_coalesce_pass() + # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass + pm.add_plan_cta_pass(cluster_info) + if _is_cuda(arch): + pm.add_tritongpu_rewrite_tensor_pointer_pass(arch) + pm.add_plan_cta_pass(cluster_info) pm.add_tritongpu_remove_layout_conversions_pass() if _is_cuda(arch): pm.add_tritongpu_accelerate_matmul_pass(arch) @@ -90,20 +102,55 @@ def optimize_ttgir(mod, num_stages, arch): matrix_core_version = gpu_matrix_core_version() pm.add_tritongpu_accelerate_matmul_pass(matrix_core_version) pm.add_tritongpu_remove_layout_conversions_pass() + if optimize_epilogue: + pm.add_tritongpu_optimize_epilogue_pass() pm.add_tritongpu_optimize_dot_operands_pass() if num_stages == 0 and is_hip() and gpu_matrix_core_version() != 0: pm.add_tritongpu_stream_pipeline_pass() pm.add_canonicalizer_pass() + ws_enabled = False + # `num_warps` does not mean the total number of warps of a CTA when + # warp specialization is enabled. + # it's the responsibility of the compiler to figure out the exact + # `num_warps` to use. + # TODO: support the case where `num_warps` from user is not 4. + if _is_cuda(arch) and arch // 10 >= 9 and enable_warp_specialization and num_warps == 4: + pm.add_tritongpu_ws_feasibility_checking_pass(arch) + pm.run(mod) + ws_enabled = ir.is_ws_supported(mod) + pm = ir.pass_manager(mod.context) + pm.enable_debug() + if ws_enabled: + pm.add_tritongpu_wsdecomposing_pass(arch) + pm.add_tritongpu_wspipeline_pass( + num_stages, num_warps, arch) + pm.add_tritongpu_wsmutex_pass(arch) + pm.add_tritongpu_wsmaterialization_pass(arch) + pm.add_cse_pass() + else: + if is_hip(): + pm.add_tritongpu_pipeline_pass( + num_stages, num_warps, num_ctas, 0) + else: + pm.add_tritongpu_pipeline_pass( + num_stages, num_warps, num_ctas, arch) + if is_hip(): + pm.add_tritongpu_materialize_load_store_pass(num_warps, 0) else: - pm.add_tritongpu_pipeline_pass(num_stages) - pm.add_tritongpu_prefetch_pass() + pm.add_tritongpu_materialize_load_store_pass(num_warps, arch) + if _is_cuda(arch) and arch // 10 <= 8: + pm.add_tritongpu_prefetch_pass() pm.add_tritongpu_optimize_dot_operands_pass() pm.add_tritongpu_remove_layout_conversions_pass() pm.add_tritongpu_decompose_conversions_pass() + pm.add_tritongpu_ws_fixup_missing_attrs_pass() if num_stages != 0: pm.add_tritongpu_reorder_instructions_pass() pm.add_cse_pass() pm.add_symbol_dce_pass() + if _is_cuda(arch) and arch // 10 >= 9: + pm.add_tritongpu_fence_insertion_pass() + pm.add_tritongpu_ws_fixup_missing_attrs_pass() pm.run(mod) return mod @@ -115,14 +162,14 @@ def _add_external_libs(mod, libs): add_external_libs(mod, list(libs.keys()), list(libs.values())) -def ttgir_to_llir(mod, extern_libs, arch): +def ttgir_to_llir(mod, extern_libs, arch, tma_infos, waves_per_eu=0): if extern_libs: _add_external_libs(mod, extern_libs) # TODO: separate tritongpu_to_llvmir for different backends if _is_cuda(arch): - return translate_triton_gpu_to_llvmir(mod, arch, False) + return translate_triton_gpu_to_llvmir(mod, arch, tma_infos, runtime.TARGET.NVVM, waves_per_eu) else: - return translate_triton_gpu_to_llvmir(mod, 0, True) + return translate_triton_gpu_to_llvmir(mod, 0, TMAInfos(), runtime.TARGET.ROCDL, waves_per_eu) # PTX translation @@ -201,7 +248,7 @@ def get_amdgcn_bitcode_paths(arch): def get_amdgpu_arch_fulldetails(): """ - get the amdgpu fulll ISA details for compiling: + get the amdgpu full ISA details for compiling: i.e., arch_triple: amdgcn-amd-amdhsa; arch_name: gfx906; arch_features: sramecc+:xnack- """ try: @@ -215,7 +262,8 @@ def get_amdgpu_arch_fulldetails(): arch_features = "" return [arch_triple, arch_name, arch_features, warp_size] - except BaseException: + except BaseException as e: + print("Error: Attempting to get amgpu ISA Details {}".format(e)) return None @@ -253,18 +301,23 @@ def convert_type_repr(x): return x -def make_hash(fn, arch, **kwargs): +def make_hash(fn, arch, env_vars, **kwargs): if isinstance(fn, JITFunction): configs = kwargs["configs"] signature = kwargs["signature"] constants = kwargs.get("constants", dict()) num_warps = kwargs.get("num_warps", 4) + num_ctas = kwargs.get("num_ctas", 1) num_stages = kwargs.get("num_stages", 3) + waves_per_eu = kwargs.get("waves_per_eu", 0) + enable_warp_specialization = kwargs.get("enable_warp_specialization", False) + enable_persistent = kwargs.get("enable_persistent", False) debug = kwargs.get("debug", False) # Get unique key for the compiled code - get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1)) + get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1), sorted(conf.ids_of_folded_args), sorted(conf.divisible_by_8)) configs_key = [get_conf_key(conf) for conf in configs] - key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{debug}-{arch}" + env_vars_list = [f"{env_vars[k]}" for k in sorted(env_vars.keys())] + key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{waves_per_eu}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{arch}-{env_vars_list}" return hashlib.md5(key.encode("utf-8")).hexdigest() assert isinstance(fn, str) return hashlib.md5((Path(fn).read_text() + version_key()).encode("utf-8")).hexdigest() @@ -317,7 +370,7 @@ def parse_mlir_module(path, context): return module -instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"], defaults=[set(), set()]) +instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"], defaults=[set(), set(), set(), set()]) # TODO: architecture descriptor class @@ -348,6 +401,32 @@ def get_architecture_descriptor(capability): return capability +def get_arch_default_num_warps(device_type): + if device_type in ["cuda", "hip"]: + num_warps = 4 + else: + _device_backend = get_backend(device_type) + assert _device_backend + arch = _device_backend.get_architecture_descriptor() + num_warps = arch["num_warps"] + + return num_warps + + +def get_arch_default_num_stages(device_type, capability=None): + if device_type in ["cuda", "hip"]: + arch = get_architecture_descriptor(capability) + is_cuda = device_type == "cuda" and _is_cuda(arch) + num_stages = 3 if is_cuda and arch >= 75 else 2 + else: + _device_backend = get_backend(device_type) + assert _device_backend + arch = _device_backend.get_architecture_descriptor() + num_stages = arch["num_stages"] + + return num_stages + + def add_rocm_stages(arch, extern_libs, stages): extern_libs.update(get_amdgcn_bitcode_paths(arch)) @@ -377,9 +456,10 @@ def compile(fn, **kwargs): # Get device type to decide which backend should be used device_type = kwargs.get("device_type", "cuda") _device_backend = get_backend(device_type) + capability = kwargs.get("cc", None) if device_type in ["cuda", "hip"]: - arch = get_architecture_descriptor(kwargs.get("cc", None)) + arch = get_architecture_descriptor(capability) else: _device_backend = get_backend(device_type) assert _device_backend @@ -390,27 +470,48 @@ def compile(fn, **kwargs): warp_size = CUDA_DEFAULT_WARP_SIZE if _is_cuda(arch) else arch[3] context = ir.context() constants = kwargs.get("constants", dict()) - num_warps = kwargs.get("num_warps", 4) - num_stages = kwargs.get("num_stages", 3 if is_cuda and arch >= 75 else (1 if is_hip else 2)) + num_warps = kwargs.get("num_warps", get_arch_default_num_warps(device_type)) + assert num_warps > 0 and (num_warps & (num_warps - 1)) == 0, "num_warps must be a power of 2" + num_ctas = kwargs.get("num_ctas", 1) + num_stages = kwargs.get("num_stages", get_arch_default_num_stages(device_type, capability=capability)) + waves_per_eu = kwargs.get("waves_per_eu", 0) + # TODO[shuhaoj]: Default should be to enable warp specialization once possible + enable_warp_specialization = kwargs.get("enable_warp_specialization", False) + # TODO[shuhaoj]: persistent can be decoupled with warp specialization + enable_persistent = kwargs.get("enable_persistent", enable_warp_specialization) extern_libs = kwargs.get("extern_libs", dict()) if extern_libs is None: extern_libs = dict() debug = kwargs.get("debug", False) - + # Flag to control whether to store mma layout directly + optimize_epilogue = False + if os.environ.get('OPTIMIZE_EPILOGUE', '') == '1': + optimize_epilogue = True + # + cluster_info = ClusterInfo() + if "clusterDims" in kwargs: + cluster_info.clusterDimX = kwargs["clusterDims"][0] + cluster_info.clusterDimY = kwargs["clusterDims"][1] + cluster_info.clusterDimZ = kwargs["clusterDims"][2] + tma_infos = TMAInfos() # build compilation stages stages = dict() stages["ast"] = (lambda path: fn, None) stages["ttir"] = (lambda path: parse_mlir_module(path, context), lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, arch=arch), arch)) stages["ttgir"] = (lambda path: parse_mlir_module(path, context), - lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, warp_size), num_stages, arch)) + lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, warp_size, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue)) stages["llir"] = (lambda path: Path(path).read_text(), - lambda src: ttgir_to_llir(src, extern_libs, arch)) + lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos, waves_per_eu)) if is_cuda: add_cuda_stages(arch, extern_libs, stages) elif is_hip: add_rocm_stages(arch, extern_libs, stages) else: + # pass the user's configuration to the backend device. + arch["num_warps"] = num_warps + arch["num_stages"] = num_stages + arch["num_ctas"] = num_ctas _device_backend.add_stages(arch, extern_libs, stages) # find out the signature of the function @@ -443,14 +544,8 @@ def compile(fn, **kwargs): signature = {k: v for k, v in enumerate(param_tys)} first_stage = list(stages.keys()).index(ir_name) - # cache manager - if is_cuda or is_hip: - so_path = make_stub(name, signature, constants) - else: - so_path = _device_backend.make_launcher_stub(name, signature, constants) - # create cache manager - fn_cache_manager = get_cache_manager(make_hash(fn, arch, **kwargs)) + fn_cache_manager = get_cache_manager(make_hash(fn, arch, get_env_vars(), **kwargs)) # determine name and extension type of provided function if isinstance(fn, JITFunction): name, ext = fn.__name__, "ast" @@ -471,13 +566,21 @@ def compile(fn, **kwargs): if metadata_path is not None: with open(metadata_path) as f: metadata = json.load(f) + if 'tensormaps_info' in metadata: + metadata['tensormaps_info'] = [ + InfoFromBackendForTensorMap(e) for e in metadata['tensormaps_info']] else: metadata = {"num_warps": num_warps, "warp_size": warp_size, + "num_ctas": num_ctas, "num_stages": num_stages, + "waves_per_eu": waves_per_eu, + "enable_warp_specialization": enable_warp_specialization, + "enable_persistent": enable_persistent, "constants": _get_jsonable_constants(constants), "debug": debug, "arch": arch, } + metadata.update(get_env_vars()) if ext == "ptx": assert "shared" in kwargs, "ptx compilation must provide shared memory size" metadata["shared"] = kwargs["shared"] @@ -522,6 +625,10 @@ def compile(fn, **kwargs): asm[ir_name] = str(next_module) if ir_name == "llir" and "shared" not in metadata: metadata["shared"] = get_shared_memory_size(module) + if ir_name == "ttgir": + metadata["enable_warp_specialization"] = ir.is_ws_supported(next_module) + if metadata["enable_warp_specialization"]: + metadata["num_warps"] = get_num_warps(next_module) if ir_name == "ptx": metadata["name"] = get_kernel_name(next_module, pattern='// .globl') if ir_name == "amdgcn": @@ -530,10 +637,36 @@ def compile(fn, **kwargs): if not is_cuda and not is_hip: _device_backend.add_meta_info(ir_name, module, next_module, metadata, asm) module = next_module + + ids_of_folded_args = tuple([int(k) for k in configs[0].ids_of_folded_args]) if isinstance(fn, JITFunction) else () + if "clusterDims" not in metadata: + metadata["clusterDims"] = [ + cluster_info.clusterDimX, + cluster_info.clusterDimY, + cluster_info.clusterDimZ] + + if len(tma_infos) > 0: + metadata["tensormaps_info"] = parse_tma_info(tma_infos, ids_of_folded_args) + # set constant + if "tensormaps_info" in metadata: + for i, _ in enumerate(metadata["tensormaps_info"]): + metadata["tensormaps_info"][i].ids_of_folded_args = ids_of_folded_args + + ids_of_tensormaps = get_ids_of_tensormaps(metadata.get("tensormaps_info", None)) + if isinstance(fn, JITFunction) and "tensormaps_info" in metadata: + fn.tensormaps_info = metadata["tensormaps_info"] + + ids_of_const_exprs = tuple(fn.constexprs) if isinstance(fn, JITFunction) else () + ids = {"ids_of_tensormaps": ids_of_tensormaps, "ids_of_folded_args": ids_of_folded_args, "ids_of_const_exprs": ids_of_const_exprs} + # cache manager + if is_cuda or is_hip: + so_path = make_stub(name, signature, constants, ids, enable_warp_specialization=enable_warp_specialization) + else: + so_path = _device_backend.make_launcher_stub(name, signature, constants, ids) # write-back metadata, if it didn't come from the cache if metadata_path is None: - metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata), metadata_filename, binary=False) - fn_cache_manager.put_group(metadata_filename, metadata_group) + metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, binary=False) + fn_cache_manager.put_group(metadata_filename, metadata_group) # return handle to compiled kernel return CompiledKernel(fn, so_path, metadata, asm) @@ -544,6 +677,7 @@ class CompiledKernel: # Hooks for external tools to monitor the execution of triton kernels launch_enter_hook = None launch_exit_hook = None + tensormap_manager = TensorMapManager() def __init__(self, fn, so_path, metadata, asm): # initialize launcher @@ -554,10 +688,15 @@ def __init__(self, fn, so_path, metadata, asm): spec.loader.exec_module(mod) self.c_wrapper = getattr(mod, "launch") # initialize metadata - self.shared = metadata["shared"] if "shared" in metadata else 0 + self.shared = metadata["shared"] self.num_warps = metadata["num_warps"] self.warp_size = metadata["warp_size"] + self.num_ctas = metadata["num_ctas"] self.num_stages = metadata["num_stages"] + self.waves_per_eu = metadata["waves_per_eu"] + self.clusterDims = metadata["clusterDims"] + if "tensormaps_info" in metadata: + self.tensormaps_info = metadata["tensormaps_info"] self.constants = metadata["constants"] self.device_type = metadata["device_type"] self.device_backend = get_backend(self.device_type) if self.device_type not in ["cuda", "hip"] else None @@ -604,17 +743,29 @@ def __getattribute__(self, name): self._init_handles() return super().__getattribute__(name) + # capture args and expand args with cutensormap* + def assemble_tensormap_to_arg(self, args): + args_with_tma = list(args) + if hasattr(self, 'tensormaps_info'): + # tuple for hashable + args_ptr = tuple([arg.data_ptr() if hasattr(arg, 'data_ptr') else arg for arg in args]) + for i, e in enumerate(self.tensormaps_info): + args_with_tma.append(CompiledKernel.tensormap_manager[(e, args_ptr)]) + return args_with_tma + def __getitem__(self, grid): self._init_handles() def runner(*args, stream=None): + args_expand = self.assemble_tensormap_to_arg(args) if stream is None: if self.device_type in ["cuda", "hip"]: stream = get_cuda_stream() else: stream = get_backend(self.device_type).get_stream(None) - self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function, - CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args) + self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.num_ctas, self.clusterDims[0], + self.clusterDims[1], self.clusterDims[2], self.shared, stream, self.cu_function, + CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args_expand) return runner def get_sass(self, fun=None): diff --git a/python/triton/compiler/make_launcher.py b/python/triton/compiler/make_launcher.py index 73a7adbf3ead..297730723bb7 100644 --- a/python/triton/compiler/make_launcher.py +++ b/python/triton/compiler/make_launcher.py @@ -5,6 +5,7 @@ from ..common import _build from ..runtime.cache import get_cache_manager from ..runtime.jit import version_key +from .utils import generate_cu_signature def is_hip(): @@ -15,24 +16,26 @@ def is_hip(): # ----- stub -------- -def make_so_cache_key(version_hash, signature, constants): +def make_so_cache_key(version_hash, signature, constants, ids, **kwargs): # Get unique key for the compiled code signature = {k: 'ptr' if v[0] == '*' else v for k, v in signature.items()} - key = f"{version_hash}-{''.join(signature.values())}{constants}" + key = f"{version_hash}-{''.join(signature.values())}-{constants}-{ids}" + for kw in kwargs: + key = f"{key}-{kwargs.get(kw)}" key = hashlib.md5(key.encode("utf-8")).hexdigest() return key -def make_stub(name, signature, constants): +def make_stub(name, signature, constants, ids, **kwargs): # name of files that are cached - so_cache_key = make_so_cache_key(version_key(), signature, constants) + so_cache_key = make_so_cache_key(version_key(), signature, constants, ids, **kwargs) so_cache_manager = get_cache_manager(so_cache_key) so_name = f"{name}.so" # retrieve stub from cache if it exists cache_path = so_cache_manager.get_file(so_name) if cache_path is None: with tempfile.TemporaryDirectory() as tmpdir: - src = generate_launcher(constants, signature) + src = generate_launcher(constants, signature, ids) src_path = os.path.join(tmpdir, "main.c") with open(src_path, "w") as f: f.write(src) @@ -64,7 +67,9 @@ def ty_to_cpp(ty): }[ty] -def generate_launcher(constants, signature): +def generate_launcher(constants, signature, ids): + start_desc = len(signature) + signature = generate_cu_signature(constants, signature, ids) arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) def _extracted_type(ty): @@ -95,164 +100,157 @@ def format_of(ty): "int64_t": "L", }[ty] - format = "iiiiiKKOOO" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) + format = "iiiiiiiiiKKOOO" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) # generate glue code if is_hip(): - src = f""" - #define __HIP_PLATFORM_AMD__ - #include - #include - #include - - static inline void gpuAssert(hipError_t code, const char *file, int line) - {{ - if (code != HIP_SUCCESS) - {{ - const char* prefix = "Triton Error [HIP]: "; - const char* str = hipGetErrorString(code); - char err[1024] = {{0}}; - snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str ); - PyErr_SetString(PyExc_RuntimeError, err); - }} - }} + folded_without_constexprs = [c for c in ids['ids_of_folded_args'] if c not in ids['ids_of_const_exprs']] + params = [i for i in signature.keys() if i >= start_desc or (i not in constants and i not in folded_without_constexprs)] + src = f""" +#define __HIP_PLATFORM_AMD__ +#include +#include +#include +#include - #define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }} +static inline void gpuAssert(hipError_t code, const char *file, int line) +{{ + if (code != HIP_SUCCESS) + {{ + const char* prefix = "Triton Error [HIP]: "; + const char* str = hipGetErrorString(code); + char err[1024] = {{0}}; + snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str ); + PyErr_SetString(PyExc_RuntimeError, err); + }} +}} - static int getWarpSize(hipStream_t stream) - {{ - int device_id = hipGetStreamDeviceId(stream); - gpuAssert(device_id >= 0 ? hipSuccess : hipErrorInvalidDevice, __FILE__, __LINE__); - hipDeviceProp_t prop; - HIP_CHECK(hipGetDeviceProperties(&prop, device_id)); - return prop.warpSize; - }} +#define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }} - static void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, hipStream_t stream, hipFunction_t function, {arg_decls}) {{ - void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }}; - if (gridX*gridY*gridZ > 0) {{ - int warp_size = getWarpSize(stream); - HIP_CHECK(hipModuleLaunchKernel(function, gridX, gridY, gridZ, num_warps * warp_size, 1, 1, shared_memory, stream, params, 0)); - }} +static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ + // printf("_launch hip kernel\\n"); + void *params[] = {{ {', '.join(f"&arg{i}" for i in params)} }}; + if (gridX*gridY*gridZ > 0) {{ + HIP_CHECK(hipModuleLaunchKernel(function, gridX, gridY, gridZ, 64*num_warps, 1, 1, shared_memory, stream, params, 0)); }} + }} - typedef struct _DevicePtrInfo {{ - hipDeviceptr_t dev_ptr; - bool valid; - }} DevicePtrInfo; - - static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{ - DevicePtrInfo ptr_info; - ptr_info.dev_ptr = 0; - ptr_info.valid = true; - - if (PyLong_Check(obj)) {{ - ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(obj); - return ptr_info; - }} - - if (obj == Py_None) {{ - // valid nullptr - return ptr_info; - }} - - PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); - - if (ptr) {{ - PyObject *empty_tuple = PyTuple_New(0); - PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); - Py_DECREF(empty_tuple); - Py_DECREF(ptr); - - if (!PyLong_Check(ret)) {{ - PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); - ptr_info.valid = false; - return ptr_info; - }} - - ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret); - - if (!ptr_info.dev_ptr) - return ptr_info; - - uint64_t dev_ptr; - hipError_t status = hipPointerGetAttribute(&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr); - if (status == hipErrorInvalidValue) {{ - PyErr_Format(PyExc_ValueError, - "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx); - ptr_info.valid = false; - }} - - ptr_info.dev_ptr = (hipDeviceptr_t)dev_ptr; - return ptr_info; - }} +typedef struct _DevicePtrInfo {{ + hipDeviceptr_t dev_ptr; + bool valid; +}} DevicePtrInfo; - PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); +static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{ + DevicePtrInfo ptr_info; + ptr_info.dev_ptr = 0; + ptr_info.valid = true; + if (PyLong_Check(obj)) {{ + ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(obj); + return ptr_info; + }} + if (obj == Py_None) {{ + // valid nullptr + return ptr_info; + }} + PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); + if(ptr){{ + PyObject *empty_tuple = PyTuple_New(0); + PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); + Py_DECREF(empty_tuple); + Py_DECREF(ptr); + if (!PyLong_Check(ret)) {{ + PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); + ptr_info.valid = false; + return ptr_info; + }} + ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret); + if(!ptr_info.dev_ptr) return ptr_info; + uint64_t dev_ptr; + hipError_t status = hipPointerGetAttribute(&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr); + if (status == hipErrorInvalidValue) {{ + PyErr_Format(PyExc_ValueError, + "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx); + ptr_info.valid = false; }} + ptr_info.dev_ptr = (hipDeviceptr_t)dev_ptr; + Py_DECREF(ret); + return ptr_info; + }} + PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); + return ptr_info; +}} + +static PyObject* launch(PyObject* self, PyObject* args) {{ + // printf("launch\\n"); + int gridX, gridY, gridZ; + uint64_t _stream; + uint64_t _function; + int num_warps; + int num_ctas; + int clusterDimX; + int clusterDimY; + int clusterDimZ; + int shared_memory; + PyObject *launch_enter_hook = NULL; + PyObject *launch_exit_hook = NULL; + PyObject *compiled_kernel = NULL; + {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} + if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &num_ctas, &clusterDimX, &clusterDimY, &clusterDimZ, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel{', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''})) {{ + return NULL; + }} - static PyObject* launch(PyObject* self, PyObject* args) {{ + if (launch_enter_hook != Py_None) {{ + PyObject_CallObject(launch_enter_hook, args); + }} - int gridX, gridY, gridZ; - uint64_t _stream; - uint64_t _function; - int num_warps; - int shared_memory; - PyObject *launch_enter_hook = NULL; - PyObject *launch_exit_hook = NULL; - PyObject *compiled_kernel = NULL; - {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} - if (!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{ - return NULL; - }} + // raise exception asap + {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; + _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items()) if len(signature) > 0 else ''}); - if (launch_enter_hook != Py_None) {{ - PyObject_CallObject(launch_enter_hook, args); - }} + if (launch_exit_hook != Py_None) {{ + PyObject_CallObject(launch_exit_hook, args); + }} - // raise exception asap - {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; - _launch(gridX, gridY, gridZ, num_warps, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function, {', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}" for i, ty in signature.items())}); - if (launch_exit_hook != Py_None) {{ - PyObject_CallObject(launch_exit_hook, args); - }} - if (PyErr_Occurred()) {{ - return NULL; - }} + if(PyErr_Occurred()) {{ + return NULL; + }} + // return None + Py_INCREF(Py_None); + return Py_None; +}} - // return None - Py_INCREF(Py_None); - return Py_None; - }} +static PyMethodDef ModuleMethods[] = {{ + {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, + {{NULL, NULL, 0, NULL}} // sentinel +}}; - static PyMethodDef ModuleMethods[] = {{ - {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, - {{NULL, NULL, 0, NULL}} // sentinel - }}; - - static struct PyModuleDef ModuleDef = {{ - PyModuleDef_HEAD_INIT, - \"__triton_launcher\", - NULL, //documentation - -1, //size - ModuleMethods - }}; - - PyMODINIT_FUNC PyInit___triton_launcher(void) {{ - PyObject *m = PyModule_Create(&ModuleDef); - if(m == NULL) {{ - return NULL; - }} - PyModule_AddFunctions(m, ModuleMethods); - return m; - }} - """ +static struct PyModuleDef ModuleDef = {{ + PyModuleDef_HEAD_INIT, + \"__triton_launcher\", + NULL, //documentation + -1, //size + ModuleMethods +}}; + +PyMODINIT_FUNC PyInit___triton_launcher(void) {{ + PyObject *m = PyModule_Create(&ModuleDef); + if(m == NULL) {{ + return NULL; + }} + PyModule_AddFunctions(m, ModuleMethods); + return m; +}} +""" else: + folded_without_constexprs = [c for c in ids['ids_of_folded_args'] if c not in ids['ids_of_const_exprs']] + params = [i for i in signature.keys() if i >= start_desc or (i not in constants and i not in folded_without_constexprs)] src = f""" #include \"cuda.h\" #include #include +#include static inline void gpuAssert(CUresult code, const char *file, int line) {{ @@ -270,10 +268,57 @@ def format_of(ty): #define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }} -static void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, CUstream stream, CUfunction function, {arg_decls}) {{ - void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }}; - if(gridX*gridY*gridZ > 0){{ - CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, num_warps * 32, 1, 1, shared_memory, stream, params, 0)); +typedef CUresult (*cuLaunchKernelEx_t)(const CUlaunchConfig* config, CUfunction f, void** kernelParams, void** extra); + +static cuLaunchKernelEx_t getLaunchKernelExHandle() {{ + // Open the shared library + void* handle = dlopen("libcuda.so", RTLD_LAZY); + if (!handle) {{ + PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so"); + return NULL; + }} + // Clear any existing error + dlerror(); + cuLaunchKernelEx_t cuLaunchKernelExHandle = (cuLaunchKernelEx_t)dlsym(handle, "cuLaunchKernelEx"); + // Check for errors + const char *dlsym_error = dlerror(); + if (dlsym_error) {{ + PyErr_SetString(PyExc_RuntimeError, "Failed to retrieve cuLaunchKernelEx from libcuda.so"); + return NULL; + }} + return cuLaunchKernelExHandle; +}} + +static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ + void *params[] = {{ {', '.join(f"&arg{i}" for i in params)} }}; + if (gridX*gridY*gridZ > 0) {{ + if (num_ctas == 1) {{ + CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0)); + }} else {{ + CUlaunchAttribute launchAttr[2]; + launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; + launchAttr[0].value.clusterDim.x = clusterDimX; + launchAttr[0].value.clusterDim.y = clusterDimY; + launchAttr[0].value.clusterDim.z = clusterDimZ; + launchAttr[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE; + launchAttr[1].value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD; + CUlaunchConfig config; + config.gridDimX = gridX * clusterDimX; + config.gridDimY = gridY * clusterDimY; + config.gridDimZ = gridZ * clusterDimZ; + config.blockDimX = 32 * num_warps; + config.blockDimY = 1; + config.blockDimZ = 1; + config.sharedMemBytes = shared_memory; + config.hStream = stream; + config.attrs = launchAttr; + config.numAttrs = 2; + static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL; + if (cuLaunchKernelExHandle == NULL) {{ + cuLaunchKernelExHandle = getLaunchKernelExHandle(); + }} + CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0)); + }} }} }} @@ -320,6 +365,7 @@ def format_of(ty): return ptr_info; }} PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); + ptr_info.valid = false; return ptr_info; }} @@ -328,31 +374,34 @@ def format_of(ty): uint64_t _stream; uint64_t _function; int num_warps; + int num_ctas; + int clusterDimX; + int clusterDimY; + int clusterDimZ; int shared_memory; PyObject *launch_enter_hook = NULL; PyObject *launch_exit_hook = NULL; PyObject *compiled_kernel = NULL; {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} - if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{ + if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &num_ctas, &clusterDimX, &clusterDimY, &clusterDimZ, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel{', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''})) {{ return NULL; }} - if (launch_enter_hook != Py_None) {{ - PyObject_CallObject(launch_enter_hook, args); + if (launch_enter_hook != Py_None && !PyObject_CallObject(launch_enter_hook, args)) {{ + return NULL; }} // raise exception asap {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; - _launch(gridX, gridY, gridZ, num_warps, shared_memory, (CUstream)_stream, (CUfunction)_function, {', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())}); - - if (launch_exit_hook != Py_None) {{ - PyObject_CallObject(launch_exit_hook, args); - }} + Py_BEGIN_ALLOW_THREADS; + _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items()) if len(signature) > 0 else ''}); + Py_END_ALLOW_THREADS; - if(PyErr_Occurred()) {{ + if (launch_exit_hook != Py_None && !PyObject_CallObject(launch_exit_hook, args)) {{ return NULL; }} + // return None Py_INCREF(Py_None); return Py_None; diff --git a/python/triton/compiler/utils.py b/python/triton/compiler/utils.py new file mode 100644 index 000000000000..cb4f1f3ab832 --- /dev/null +++ b/python/triton/compiler/utils.py @@ -0,0 +1,297 @@ +# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files +# (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, +# publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +from __future__ import annotations + +from ..runtime import driver + + +def generate_cu_signature(constants, signature, ids): + # CUtensorMap*s are always the last arguments + if ids["ids_of_tensormaps"] is not None: + signature = signature.copy() + num_signature = len(signature) + for i, _ in enumerate(ids["ids_of_tensormaps"]): + signature[num_signature + i] = '*CUtensorMap' + return signature + + +def dummy_tensormaps_info(n=2): + ret = [] + for i in range(n): + ret.append(InfoFromBackendForTensorMap(dummy=True)) + return ret + + +def parse_tma_info(infos, ids_of_folded_args): + ret = [] + for info in infos: + e = InfoFromBackendForTensorMap(infos=info) + e.ids_of_folded_args = ids_of_folded_args + ret.append(e) + return ret + + +def get_tma_mapping(tensormaps_info): + ret = {} + if tensormaps_info is not None: + for i, e in enumerate(tensormaps_info): + ret.update(e.get_address_tma_mapping()) + else: + ret = None + return ret + + +def get_ids_of_tensormaps(tensormaps_info): + ret = None + # order is not relevant + if tensormaps_info is not None: + ret = [e.get_id_of_tensormap() for e in tensormaps_info] + return ret + + +# decouple information for tensormap from backend +# please ignore the naming style, xx_yy is compiler.py style, xxYy is to comply with cuda tensormap style +# mixing style is for readability +class InfoFromBackendForTensorMap: + N = 2 + n = 0 + ntma = 0 + + def __init__(self, infos=None, dummy=False): + self.dummy = dummy + self.ids_of_folded_args = () + if not dummy and not isinstance(infos, dict): + self._extract_info_from_backend(infos) + elif not dummy and isinstance(infos, dict): + self._extract_info_from_dict(infos) + elif dummy: + self._dummy() + + def _dummy(self): + assert InfoFromBackendForTensorMap.n < InfoFromBackendForTensorMap.N + if InfoFromBackendForTensorMap.n == 0: + self.tensorDataType = driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT16"] + self.tensorRank = 4 + self.globalAddressArgIdx = 0 + self.globalStridesArgIdx = [7, 6, -1, -1] + self.globalDimsArgIdx = [5, 3, -1, -1] + self.boxDims = [16, 64, 1, 1] + self.elementStrides = [1, 1, 1, 1] + self.interleave = driver.utils.CUtensorMapInterleave["CU_TENSOR_MAP_INTERLEAVE_NONE"] + self.swizzle = driver.utils.CUtensorMapSwizzle["CU_TENSOR_MAP_SWIZZLE_32B"] + self.l2Promotion = driver.utils.CUtensorMapL2promotion["CU_TENSOR_MAP_L2_PROMOTION_L2_128B"] + self.TMADescArgIdx = 11 + self.oobFill = driver.utils.CUtensorMapFloatOOBfill["CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE"] + InfoFromBackendForTensorMap.n += 1 + return + if InfoFromBackendForTensorMap.n == 1: + self.tensorDataType = driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT16"] + self.tensorRank = 4 + self.globalAddressArgIdx = 1 + self.globalStridesArgIdx = [7, 6, -1, -1] + self.globalDimsArgIdx = [5, 3, -1, -1] + self.boxDims = [16, 64, 1, 1] + self.elementStrides = [1, 1, 1, 1] + self.interleave = driver.utils.CUtensorMapInterleave["CU_TENSOR_MAP_INTERLEAVE_NONE"] + self.swizzle = driver.utils.CUtensorMapSwizzle["CU_TENSOR_MAP_SWIZZLE_32B"] + self.l2Promotion = driver.utils.CUtensorMapL2promotion["CU_TENSOR_MAP_L2_PROMOTION_L2_128B"] + self.TMADescArgIdx = 12 + self.oobFill = driver.utils.CUtensorMapFloatOOBfill["CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE"] + InfoFromBackendForTensorMap.n += 1 + return + + def _extract_info_from_backend(self, infos): + self.tensorDataType = infos.tensorDataType + self.tensorRank = infos.tensorRank + self.globalAddressArgIdx = infos.globalAddressArgIdx + self.globalStridesArgIdx = infos.globalStridesArgIdx + self.globalDimsArgIdx = infos.globalDimsArgIdx + self.boxDims = infos.boxDims + self.elementStrides = infos.elementStrides + self.interleave = infos.interleave + self.swizzle = infos.swizzle + self.l2Promotion = infos.l2Promotion + self.oobFill = infos.oobFill + self.TMADescArgIdx = infos.TMADescArgIdx + + # dict could be from cached metadata json + def _extract_info_from_dict(self, infos: dict): + self.tensorDataType = infos['tensorDataType'] + self.tensorRank = infos['tensorRank'] + self.globalAddressArgIdx = infos['globalAddressArgIdx'] + self.globalStridesArgIdx = infos['globalStridesArgIdx'] + self.globalDimsArgIdx = infos['globalDimsArgIdx'] + self.boxDims = infos['boxDims'] + self.elementStrides = infos['elementStrides'] + self.interleave = infos['interleave'] + self.swizzle = infos['swizzle'] + self.l2Promotion = infos['l2Promotion'] + self.oobFill = infos['oobFill'] + self.TMADescArgIdx = infos['TMADescArgIdx'] + + def get_address_tma_mapping(self): + return {self.globalAddressArgIdx: self.TMADescArgIdx + len(self.ids_of_folded_args)} + + def get_id_of_tensormap(self): + return self.TMADescArgIdx + len(self.ids_of_folded_args) + + def getTMADescArgIdx(self): + return self.TMADescArgIdx + + # dtype:cuda.CUtensorMapDataType | int + def bytes_from_type(self, dtype): + return {driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT8"]: 1, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT16"]: 2, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT32"]: 4, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT32"]: 4, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT64"]: 8, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT64"]: 8, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT16"]: 2, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32"]: 4, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT64"]: 8, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_BFLOAT16"]: 2, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ"]: 4, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32"]: 4, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ"]: 4}[dtype] + + def getTensorMapDataType(self): + return self.tensorDataType + + def getInterleave(self): + return self.interleave + + def getSwizzle(self): + return self.swizzle + + def getL2Promotion(self): + return self.l2Promotion + + def getOobFill(self): + return self.oobFill + + def getTensorRank(self): + return self.tensorRank + + def getBoxDims(self): + return self.boxDims + + def getElementStrides(self): + return self.elementStrides + + def getGlobalAddress(self, args): + idx = self.getOriginArgIdx(self.globalAddressArgIdx, args) + return args[idx] + + # args, captured kernel args in runtime + def getGlobalDims(self, args): + shape = [] + for e in self.globalDimsArgIdx: + t = 1 + # < 0 means folded arg or constant (-1 - value) + # -1 means extended dim which is 1, -2 means folded arg with constant 1 (-1 - value) + if e == -1: + t = 1 + elif e < 0 and e != -1: + t = -e - 1 + else: + idx = self.getOriginArgIdx(e, args) + t = args[idx] + shape.append(t) + return shape + + def getGlobalStrides(self, args): + t_globalDims = [int(e) for e in self.getGlobalDims(args)] + t_globalStridesArgIdx = self.globalStridesArgIdx.copy() + strides_in_elements = [] + # todo: get all stride from backend even in extended mode + for i in range(self.tensorRank): + t = 1 + if t_globalStridesArgIdx[i] == -1: + for ii in range(i): + t *= t_globalDims[ii] + # -2 means the sride in arguments is folded constant 1, we don't use 1 because it can not be distinguished from index 1 + elif t_globalStridesArgIdx[i] == -2: + t = 1 + else: + new_idx = self.getOriginArgIdx(t_globalStridesArgIdx[i], args) + t = args[new_idx] + + strides_in_elements.append(t) + + strides_in_elements = strides_in_elements[1:] + strides_in_bytes = [e * self.bytes_from_type(self.tensorDataType) for e in strides_in_elements] + return strides_in_bytes + + def getOriginArgIdx(self, idx, args): + if self.ids_of_folded_args: + ids_before_folding_arg = [i for i in range(len(args)) if i not in self.ids_of_folded_args] + return ids_before_folding_arg[idx] + else: + return idx + + def tensormap(self, args): + return driver.utils.cuTensorMapEncodeTiled( + self.getTensorMapDataType(), + self.getTensorRank(), + self.getGlobalAddress(args), + self.getGlobalDims(args), + self.getGlobalStrides(args), + self.getBoxDims(), + self.getElementStrides(), + self.getInterleave(), + self.getSwizzle(), + self.getL2Promotion(), + self.getOobFill() + ) + + # make hashable to use as partial key in cache + def __hash__(self): + return hash((self.ids_of_folded_args, self.globalAddressArgIdx, tuple(self.globalDimsArgIdx), tuple(self.globalStridesArgIdx), self.tensorDataType, + self.tensorRank, tuple(self.boxDims), tuple(self.elementStrides), self.interleave, self.swizzle, self.l2Promotion, self.oobFill)) + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return False + return (self.ids_of_folded_args, self.globalAddressArgIdx, self.globalDimsArgIdx, self.globalStridesArgIdx, self.tensorDataType, self.tensorRank, self.boxDims, self.elementStrides, self.interleave, self.swizzle, self.l2Promotion, self.oobFill) == ( + other.ids_of_folded_args, other.globalAddressArgIdx, other.globalDimsArgIdx, other.globalStridesArgIdx, other.tensorDataType, other.tensorRank, other.boxDims, other.elementStrides, other.interleave, other.swizzle, other.l2Promotion, other.oobFill) + + +class TensorMapManager: + def __init__(self): + self.tensormaps_device = {} + + def __getitem__(self, key: tuple): + if key in self.tensormaps_device: + return int(self.tensormaps_device[key]) + else: + (e, args) = key + t_tensormap = e.tensormap(args) + TENSORMAP_SIZE_IN_BYTES = 128 + t_tensormap_device = driver.utils.cuMemAlloc(TENSORMAP_SIZE_IN_BYTES) + driver.utils.cuMemcpyHtoD( + t_tensormap_device, t_tensormap, TENSORMAP_SIZE_IN_BYTES) + self.tensormaps_device[key] = t_tensormap_device + return int(self.tensormaps_device[key]) + + def __del__(self): + for _, v in self.tensormaps_device.items(): + driver.utils.cuMemFree(v) diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index 32fc40d1d765..6b719c684cfd 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -4,11 +4,21 @@ from . import math from . import extra from .standard import ( + argmax, + argmin, cdiv, + cumprod, + cumsum, + max, + maximum, + min, + minimum, sigmoid, softmax, + sum, ravel, swizzle2d, + xor_sum, zeros, zeros_like, ) @@ -17,8 +27,6 @@ abs, advance, arange, - argmin, - argmax, associative_scan, atomic_add, atomic_and, @@ -35,8 +43,6 @@ cat, constexpr, cos, - cumprod, - cumsum, debug_barrier, device_assert, device_print, @@ -50,9 +56,11 @@ float32, float64, float8e4b15, - float8e4, + float8e4b15x4, + float8e4nv, float8e5, function_type, + inline_asm_elementwise, int1, int16, int32, @@ -61,12 +69,8 @@ load, log, make_block_ptr, - max, max_constancy, max_contiguous, - maximum, - min, - minimum, multiple_of, num_programs, pi32_t, @@ -79,7 +83,6 @@ static_assert, static_print, store, - sum, static_range, tensor, trans, @@ -92,7 +95,6 @@ view, void, where, - xor_sum, ) from .random import ( pair_uniform_to_normal, @@ -148,10 +150,12 @@ "float32", "float64", "float8e4b15", - "float8e4", + "float8e4b15x4", + "float8e4nv", "float8e5", "full", "function_type", + "inline_asm_elementwise", "int1", "int16", "int32", diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 48f118d21de0..0c73c1f8bdf9 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -6,8 +6,7 @@ from typing import Callable, List, Sequence, TypeVar from .._C.libtriton.triton import ir -from ..runtime.jit import jit -from . import math, semantic +from . import semantic T = TypeVar('T') @@ -76,7 +75,7 @@ def _to_tensor(x, builder): class dtype: SINT_TYPES = ['int8', 'int16', 'int32', 'int64'] UINT_TYPES = ['int1', 'uint8', 'uint16', 'uint32', 'uint64'] - FP_TYPES = ['fp8e4b15', 'fp8e4', 'fp8e5', 'fp16', 'bf16', 'fp32', 'fp64'] + FP_TYPES = ['fp8e4b15', 'fp8e4b15x4', 'fp8e4nv', 'fp8e5', 'fp16', 'bf16', 'fp32', 'fp64'] STANDARD_FP_TYPES = ['fp16', 'bf16', 'fp32', 'fp64'] OTHER_TYPES = ['void'] @@ -100,7 +99,11 @@ def __init__(self, name): self.fp_mantissa_width = 3 self.primitive_bitwidth = 8 self.exponent_bias = 15 - elif name == 'fp8e4': + elif name == 'fp8e4b15x4': + self.fp_mantissa_width = 3 + self.primitive_bitwidth = 8 + self.exponent_bias = 15 + elif name == 'fp8e4nv': self.fp_mantissa_width = 3 self.primitive_bitwidth = 8 self.exponent_bias = 7 @@ -132,12 +135,18 @@ def __init__(self, name): def is_fp8(self): return 'fp8' in self.name - def is_fp8e4(self): - return self.name == 'fp8e4' + def is_fp8e4nv(self): + return self.name == 'fp8e4nv' def is_fp8e4b15(self): return self.name == 'fp8e4b15' + def is_fp8e4b15x4(self): + return self.name == 'fp8e4b15x4' + + def is_fp8e5(self): + return self.name == 'fp8e5' + def is_fp16(self): return self.name == 'fp16' @@ -195,6 +204,10 @@ def is_int(self): def is_bool(self): return self.is_int1() + @staticmethod + def is_dtype(type_str): + return type_str in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES + @staticmethod def is_void(): raise RuntimeError("Not implemented") @@ -237,10 +250,12 @@ def to_ir(self, builder: ir.builder) -> ir.type: return builder.get_int64_ty() elif self.name == 'fp8e5': return builder.get_fp8e5_ty() - elif self.name == 'fp8e4': - return builder.get_fp8e4_ty() + elif self.name == 'fp8e4nv': + return builder.get_fp8e4nv_ty() elif self.name == 'fp8e4b15': return builder.get_fp8e4b15_ty() + elif self.name == 'fp8e4b15x4': + return builder.get_fp8e4b15x4_ty() elif self.name == 'fp16': return builder.get_half_ty() elif self.name == 'bf16': @@ -373,8 +388,9 @@ def to_ir(self, builder: ir.builder): uint32 = dtype('uint32') uint64 = dtype('uint64') float8e5 = dtype('fp8e5') -float8e4 = dtype('fp8e4') +float8e4nv = dtype('fp8e4nv') float8e4b15 = dtype('fp8e4b15') +float8e4b15x4 = dtype('fp8e4b15x4') float16 = dtype('fp16') bfloat16 = dtype('bf16') float32 = dtype('fp32') @@ -1392,180 +1408,6 @@ def _reduce_with_indices(input, axis, combine_fn, _builder=None, _generator=None return rvalue, rindices -@jit -def minimum(x, y): - """ - Computes the element-wise minimum of :code:`x` and :code:`y`. - - :param input: the first input tensor - :type input: Block - :param other: the second input tensor - :type other: Block - """ - return where(x < y, x, y) - - -@jit -def maximum(x, y): - """ - Computes the element-wise maximum of :code:`x` and :code:`y`. - - :param input: the first input tensor - :type input: Block - :param other: the second input tensor - :type other: Block - """ - return where(x > y, x, y) - -# max and argmax - - -@jit -def _argmax_combine(value1, index1, value2, index2, tie_break_left): - if tie_break_left: - tie = value1 == value2 and index1 < index2 - else: - tie = False - gt = value1 > value2 or tie - v_ret = where(gt, value1, value2) - i_ret = where(gt, index1, index2) - return v_ret, i_ret - - -@jit -def _argmax_combine_tie_break_left(value1, index1, value2, index2): - return _argmax_combine(value1, index1, value2, index2, True) - - -@jit -def _argmax_combine_tie_break_fast(value1, index1, value2, index2): - return _argmax_combine(value1, index1, value2, index2, False) - - -@jit -def _fast_max(x, y): - return math.max(x, y) - - -@jit -@_add_reduction_docstr("maximum", - return_indices_arg="return_indices", - tie_break_arg="return_indices_tie_break_left") -def max(input, axis=None, return_indices=False, return_indices_tie_break_left=True): - input = _promote_reduction_input(input) - if return_indices: - if return_indices_tie_break_left: - return _reduce_with_indices(input, axis, _argmax_combine_tie_break_left) - else: - return _reduce_with_indices(input, axis, _argmax_combine_tie_break_fast) - else: - if constexpr(input.dtype.primitive_bitwidth) < 32: - if constexpr(input.dtype.is_floating()): - input = input.to(float32) - else: - assert input.dtype.is_integer_type() - input = input.to(int32) - return reduce(input, axis, _fast_max) - - -@jit -@_add_reduction_docstr("maximum index", tie_break_arg="tie_break_left") -def argmax(input, axis, tie_break_left=True): - (_, ret) = max(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left) - return ret - -# min and argmin - - -@jit -def _argmin_combine(value1, index1, value2, index2, tie_break_left): - if tie_break_left: - tie = value1 == value2 and index1 < index2 - else: - tie = False - lt = value1 < value2 or tie - value_ret = where(lt, value1, value2) - index_ret = where(lt, index1, index2) - return value_ret, index_ret - - -@jit -def _argmin_combine_tie_break_left(value1, index1, value2, index2): - return _argmin_combine(value1, index1, value2, index2, True) - - -@jit -def _argmin_combine_tie_break_fast(value1, index1, value2, index2): - return _argmin_combine(value1, index1, value2, index2, False) - - -@jit -def _fast_min(x, y): - return math.min(x, y) - - -@jit -@_add_reduction_docstr("minimum", - return_indices_arg="return_indices", - tie_break_arg="return_indices_tie_break_left") -def min(input, axis=None, return_indices=False, return_indices_tie_break_left=True): - input = _promote_reduction_input(input) - if return_indices: - if return_indices_tie_break_left: - return _reduce_with_indices(input, axis, _argmin_combine_tie_break_left) - else: - return _reduce_with_indices(input, axis, _argmin_combine_tie_break_fast) - else: - if constexpr(input.dtype.primitive_bitwidth) < 32: - if constexpr(input.dtype.is_floating()): - input = input.to(float32) - else: - assert input.dtype.is_integer_type() - input = input.to(int32) - return reduce(input, axis, _fast_min) - - -@jit -@_add_reduction_docstr("minimum index", - tie_break_arg="tie_break_left") -def argmin(input, axis, tie_break_left=True): - _, ret = min(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left) - return ret - - -@jit -def _sum_combine(a, b): - return a + b - -# sum - - -@jit -@_add_reduction_docstr("sum") -def sum(input, axis=None): - input = _promote_reduction_input(input) - return reduce(input, axis, _sum_combine) - - -@jit -def _xor_combine(a, b): - return a ^ b - - -# xor sum - -@builtin -@_add_reduction_docstr("xor sum") -def xor_sum(input, axis=None, _builder=None, _generator=None): - scalar_ty = input.type.scalar - if not scalar_ty.is_int(): - raise ValueError("xor_sum only supported for integers") - - input = _promote_reduction_input(input, _builder=_builder) - return reduce(input, axis, _xor_combine, - _builder=_builder, _generator=_generator) - - # ----------------------- # Scans # ----------------------- @@ -1616,31 +1458,6 @@ def make_combine_region(scan_op): axis = _constexpr_to_value(axis) return semantic.associative_scan(input, axis, make_combine_region, _builder) -# cumsum - - -@jit -@_add_scan_docstr("cumsum") -def cumsum(input, axis=0): - # todo rename this to a generic function name - input = _promote_reduction_input(input) - return associative_scan(input, axis, _sum_combine) - -# cumprod - - -@jit -def _prod_combine(a, b): - return a * b - - -@jit -@_add_scan_docstr("cumprod") -def cumprod(input, axis=0): - # todo rename this to a generic function name - input = _promote_reduction_input(input) - return associative_scan(input, axis, _prod_combine) - # ----------------------- # Compiler Hint Ops # ----------------------- @@ -1813,6 +1630,47 @@ def device_assert(cond, msg="", _builder=None): return semantic.device_assert(_to_tensor(cond, _builder), msg, file_name, func_name, lineno, _builder) +@builtin +def inline_asm_elementwise(asm: str, constraints: str, args: list, dtype, is_pure: bool, pack: int, _builder=None): + ''' + Execute the inline assembly to a packed of elements of the tensor + :param asm: assembly to be inlined, it has to match the target assembly format + :param constraints: string representing the mapping of operands to register + :param args: the arguments of the operation + :param dtype: the element type of the returned variable + :param is_pure: whether the operation is pure + :param pack: the number of elements to be processed by one instance of inline assembly + :param _builder: the builder + :return: the return value of the function + ''' + dispatch_args = args.copy() + asm = _constexpr_to_value(asm) + constraints = _constexpr_to_value(constraints) + pack = _constexpr_to_value(pack) + is_pure = _constexpr_to_value(is_pure) + ret_shape = None + arg_types = [] + res_ty = dtype + for i in range(len(dispatch_args)): + dispatch_args[i] = _to_tensor(dispatch_args[i], _builder) + arg_types.append(dispatch_args[i].dtype) + if len(arg_types) > 0: + arg_types = tuple(arg_types) + broadcast_arg = dispatch_args[0] + # Get the broadcast shape over all the arguments + for i, item in enumerate(dispatch_args): + _, broadcast_arg = semantic.binary_op_type_checking_impl( + item, broadcast_arg, _builder, arithmetic_check=False) + # Change the shape of each argument based on the broadcast shape + for i in range(len(dispatch_args)): + dispatch_args[i], _ = semantic.binary_op_type_checking_impl( + dispatch_args[i], broadcast_arg, _builder, arithmetic_check=False) + ret_shape = broadcast_arg.shape + res_ty = block_type(dtype, ret_shape) + call = _builder.create_inline_asm(asm, constraints, [t.handle for t in args], res_ty.to_ir(_builder), is_pure, pack) + return tensor(call, res_ty) + + # ----------------------- # Iterators # ----------------------- @@ -1946,6 +1804,16 @@ def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, is_pure, _builder) +def binary_op_type_legalization(lhs, rhs, builder): + ''' + Convert both operands to a single common type + :param lhs: the left operand + :param rhs: the right operand + :param builder: the builder + ''' + return semantic.binary_op_type_checking_impl(lhs, rhs, builder) + + def extern(fn): """A decorator for external functions.""" return builtin(fn) diff --git a/python/triton/language/cuda2gcn.bc b/python/triton/language/cuda2gcn.bc deleted file mode 100755 index fadc5b4c11c5..000000000000 Binary files a/python/triton/language/cuda2gcn.bc and /dev/null differ diff --git a/python/triton/language/cuda2gcn.patch b/python/triton/language/cuda2gcn.patch deleted file mode 100644 index 3d8f633cb169..000000000000 --- a/python/triton/language/cuda2gcn.patch +++ /dev/null @@ -1,24 +0,0 @@ -diff --git a/CMakeLists.txt b/CMakeLists.txt -index b65f1b5..19cc5a9 100644 ---- a/CMakeLists.txt -+++ b/CMakeLists.txt -@@ -62,12 +62,13 @@ include(OCL) - set(AMDGCN_LIB_LIST) - set(AMDGCN_DEP_LIST) - add_subdirectory(irif) --add_subdirectory(oclc) --add_subdirectory(ocml) --add_subdirectory(ockl) --add_subdirectory(opencl) --add_subdirectory(hip) --add_subdirectory(asanrtl) -+#add_subdirectory(oclc) -+#add_subdirectory(ocml) -+#add_subdirectory(ockl) -+#add_subdirectory(opencl) -+#add_subdirectory(hip) -+#add_subdirectory(asanrtl) -+add_subdirectory(cuda2gcn) - - enable_testing() - add_subdirectory(test/compile) diff --git a/python/triton/language/extra/cuda.bc b/python/triton/language/extra/cuda.bc deleted file mode 100644 index 4538ac35446a..000000000000 Binary files a/python/triton/language/extra/cuda.bc and /dev/null differ diff --git a/python/triton/language/extra/cuda.py b/python/triton/language/extra/cuda.py index 92df37a67c77..d69120938185 100644 --- a/python/triton/language/extra/cuda.py +++ b/python/triton/language/extra/cuda.py @@ -1,19 +1,15 @@ -import os - from .. import core -__path__ = os.path.dirname(os.path.abspath(__file__)) - @core.extern def globaltimer(_builder=None): - return core.extern_elementwise("cuda", os.path.join(__path__, "cuda.bc"), [], - {tuple(): ("globaltimer", core.dtype("int64")), - }, is_pure=False, _builder=_builder) + return core.inline_asm_elementwise("mov.u64 $0, %globaltimer;", "=l", [], + dtype=core.int64, is_pure=False, + pack=1, _builder=_builder) @core.extern def smid(_builder=None): - return core.extern_elementwise("cuda", os.path.join(__path__, "cuda.bc"), [], - {tuple(): ("smid", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.inline_asm_elementwise("mov.u32 $0, %smid;", "=r", [], + dtype=core.int32, is_pure=True, + pack=1, _builder=_builder) diff --git a/python/triton/language/make_cuda2gcn.sh b/python/triton/language/make_cuda2gcn.sh deleted file mode 100755 index a3fc04cb7476..000000000000 --- a/python/triton/language/make_cuda2gcn.sh +++ /dev/null @@ -1,17 +0,0 @@ -#!/bin/bash - -set -e - -pushd . - -git clone https://github.com/dfukalov/ROCm-Device-Libs.git -cd ROCm-Device-Libs -git apply ../cuda2gcn.patch -mkdir build -cd build -cmake .. -DCMAKE_PREFIX_PATH=$HOME/.triton/llvm/clang+llvm-14.0.0-x86_64-linux-gnu-ubuntu-18.04 -make -j4 - -popd -cp ROCm-Device-Libs/build/amdgcn/bitcode/cuda2gcn.bc . -rm -rf ROCm-Device-Libs diff --git a/python/triton/language/math.py b/python/triton/language/math.py index 3fc01bf56bbd..0485de259fd0 100644 --- a/python/triton/language/math.py +++ b/python/triton/language/math.py @@ -40,26 +40,34 @@ def byte_perm(arg0, arg1, arg2, _builder=None): @core.extern def min(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("int32"), core.dtype("int32"),): ("__nv_min", core.dtype("int32")), - (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_umin", core.dtype("uint32")), - (core.dtype("int64"), core.dtype("int64"),): ("__nv_llmin", core.dtype("int64")), - (core.dtype("uint64"), core.dtype("uint64"),): ("__nv_ullmin", core.dtype("uint64")), - (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fminf", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fmin", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + arg0 = core._to_tensor(arg0, _builder) + arg1 = core._to_tensor(arg1, _builder) + arg0, arg1 = core.binary_op_type_legalization(arg0, arg1, _builder) + dtype = arg0.dtype + if dtype.is_floating(): + return core.tensor(_builder.create_minf(arg0.handle, arg1.handle), arg0.type) + elif dtype.is_int_signed(): + return core.tensor(_builder.create_minsi(arg0.handle, arg1.handle), arg0.type) + elif dtype.is_int_unsigned(): + return core.tensor(_builder.create_minui(arg0.handle, arg1.handle), arg0.dtype) + else: + assert False, f"Unexpected dtype {dtype}" @core.extern def max(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("int32"), core.dtype("int32"),): ("__nv_max", core.dtype("int32")), - (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_umax", core.dtype("uint32")), - (core.dtype("int64"), core.dtype("int64"),): ("__nv_llmax", core.dtype("int64")), - (core.dtype("uint64"), core.dtype("uint64"),): ("__nv_ullmax", core.dtype("uint64")), - (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaxf", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fmax", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + arg0 = core._to_tensor(arg0, _builder) + arg1 = core._to_tensor(arg1, _builder) + arg0, arg1 = core.binary_op_type_legalization(arg0, arg1, _builder) + dtype = arg0.dtype + if dtype.is_floating(): + return core.tensor(_builder.create_maxf(arg0.handle, arg1.handle), arg0.type) + elif dtype.is_int_signed(): + return core.tensor(_builder.create_maxsi(arg0.handle, arg1.handle), arg0.type) + elif dtype.is_int_unsigned(): + return core.tensor(_builder.create_maxui(arg0.handle, arg1.handle), arg0.dtype) + else: + assert False, f"Unexpected dtype {dtype}" @core.extern diff --git a/python/triton/language/random.py b/python/triton/language/random.py index ed1993e33c40..7af60855b040 100644 --- a/python/triton/language/random.py +++ b/python/triton/language/random.py @@ -1,5 +1,6 @@ from ..runtime.jit import jit from . import core as tl +from . import standard PHILOX_KEY_A: tl.constexpr = 0x9E3779B9 PHILOX_KEY_B: tl.constexpr = 0xBB67AE85 @@ -55,7 +56,7 @@ def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): using `randint4x` is likely to be faster than calling `randint` 4 times. :param seed: The seed for generating random numbers. - :param offsets: The offsets to generate random numbers for. + :param offset: The offsets to generate random numbers for. """ ret, _, _, _ = randint4x(seed, offset, n_rounds) return ret @@ -120,7 +121,7 @@ def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): """ Given a :code:`seed` scalar and an :code:`offsets` block, - returns a 4 blocks of random :code:`float32` in :math:`U(0, 1)`. + returns 4 blocks of random :code:`float32` in :math:`U(0, 1)`. :param seed: The seed for generating random numbers. :param offsets: The offsets to generate random numbers for. @@ -141,7 +142,7 @@ def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): @jit def pair_uniform_to_normal(u1, u2): """Box-Muller transform""" - u1 = tl.maximum(1.0e-7, u1) + u1 = standard.maximum(1.0e-7, u1) th = 6.283185307179586 * u2 r = tl.sqrt(-2.0 * tl.log(u1)) return r * tl.cos(th), r * tl.sin(th) @@ -167,7 +168,7 @@ def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): """ Given a :code:`seed` scalar and an :code:`offset` block, - returns a 4 blocks of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`. + returns 4 blocks of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`. :param seed: The seed for generating random numbers. :param offsets: The offsets to generate random numbers for. diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 94067d8efc9a..43382636b4a7 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1,6 +1,5 @@ from __future__ import annotations # remove after python 3.11 -import warnings from functools import wraps from typing import List, Optional, Sequence, Tuple, TypeVar @@ -136,6 +135,8 @@ def add(input: tl.tensor, # ptr + offset if other_scalar_ty.is_ptr() and not input_scalar_ty.is_ptr(): input, other = other, input + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar if input_scalar_ty.is_ptr(): return tl.tensor(builder.create_addptr(input.handle, other.handle), input.type) # float + float @@ -195,7 +196,7 @@ def truediv(input: tl.tensor, elif input_scalar_ty.is_int() and other_scalar_ty.is_int(): input = cast(input, tl.float32, builder) other = cast(other, tl.float32, builder) - # float / float (cast to highest exponent type) + # float / float (cast to the highest exponent type) elif input_scalar_ty.is_floating() and other_scalar_ty.is_floating(): if input_scalar_ty.fp_mantissa_width > other_scalar_ty.fp_mantissa_width: other = cast(other, input_scalar_ty, builder) @@ -692,9 +693,8 @@ def cast(input: tl.tensor, dst_sca_ty = dst_ty.scalar if _is_cuda(builder.arch) and builder.arch < 89 and \ - (src_sca_ty.is_fp8e4() or dst_sca_ty.is_fp8e4()): - warnings.warn("Standard tl.float8e4 format will be deprecated on SM < 89. " - "Please use tl.float8e4b15.", DeprecationWarning) + (src_sca_ty.is_fp8e4nv() or dst_sca_ty.is_fp8e4nv()): + assert False, "fp8e4nv data type is not supported on CUDA arch < 89" # Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64 if (src_sca_ty.is_fp8() and dst_sca_ty.is_floating()) or \ @@ -988,8 +988,8 @@ def _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builde if not val.type.is_block(): val = broadcast_impl_shape(val, block_shape, builder) assert val.type.is_block(), "Value argument must be block type or a scalar" - assert block_shape == val.type.get_block_shapes(), "Block shape and value shape mismatch" - assert ptr.type.element_ty.element_ty == val.type.element_ty, "Block element type and value element type mismatch" + assert block_shape == val.type.get_block_shapes(), f"Block shape({block_shape}) and value shape({val.type.get_block_shapes()}) mismatch" + assert ptr.type.element_ty.element_ty == val.type.element_ty, f"Block element type({ptr.type.element_ty.element_ty}) and value element type({val.type.element_ty}) mismatch" elt_ty = ptr.type.element_ty.element_ty assert elt_ty != tl.int1, "`tl.int1` should be rewrited in `tl.make_block_ptr`" @@ -1274,13 +1274,18 @@ def gpu_matrix_core_version() -> int: gpu_name = gfx_arch_details[1].split(':')[0] if gpu_name in ['gfx908']: return 1 - if gpu_name in ['gfx90a', 'gfx940', 'gfx941']: + if gpu_name in ['gfx90a', 'gfx940', 'gfx941', 'gfx942']: return 2 return 0 def mfma_supported_granularity(m, n, k) -> bool: granularity_mn = 32 granularity_k = 8 + import os + if "MFMA_TYPE" in os.environ and os.environ["MFMA_TYPE"] == "16": + granularity_mn = 16 + granularity_k = 16 + if m % granularity_mn != 0 or n % granularity_mn != 0: return False if k % granularity_k != 0: @@ -1300,18 +1305,59 @@ def dot(lhs: tl.tensor, allow_tf32: bool, out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor: + def assert_dtypes_valid(lhs_dtype, rhs_dtype, arch): + if is_hip(): + assert lhs.dtype == rhs.dtype or (lhs.type.scalar.is_fp8() and rhs.type.scalar.is_fp16()) or \ + (lhs.type.scalar.is_fp16() and rhs.type.scalar.is_fp8()), \ + f"First input ({lhs.dtype}) and second input ({rhs.dtype}) must have the same dtype!" + + return + + # Checks for non-cuda archs + if _is_cuda(builder.arch): + # Checks for cuda arch + if arch < 90: + assert not lhs_dtype.is_fp8e4nv() and not rhs_dtype.is_fp8e4nv(), "Dot op does not support fp8e4nv on CUDA arch < 90" + assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!" + else: + assert not lhs_dtype.is_fp8e4b15() and not rhs_dtype.is_fp8e4b15(), "Dot op does not support fp8e4b15 on CUDA arch >= 90" + assert not lhs_dtype.is_fp8e4b15x4() and not rhs_dtype.is_fp8e4b15x4(), "Dot op does not support fp8e4b15x4 on CUDA arch >= 90" + if lhs_dtype.is_int() or rhs_dtype.is_int(): + assert lhs_dtype == rhs_dtype, f"Both operands must be same type. First operand ({lhs_dtype}) and second operand ({rhs_dtype})" + assert lhs_dtype.is_int8() or lhs_dtype.is_uint8(), f"Both operands must be either int8 or uint8. Operand type ({lhs_dtype})" + elif lhs_dtype.is_fp8() or rhs_dtype.is_fp8(): + assert lhs_dtype.is_fp8e4nv() or lhs_dtype.is_fp8e5(), f"Only supports fp8e4nv or fp8e5. First operand ({lhs_dtype})" + assert rhs_dtype.is_fp8e4nv() or rhs_dtype.is_fp8e5(), f"Only supports fp8e4nv or fp8e5. Second operand ({rhs_dtype})" + else: + assert lhs_dtype.is_fp16() or lhs_dtype.is_bf16() or lhs_dtype.is_fp32() or lhs_dtype.is_int1(), f"Unsupported dtype {lhs_dtype}" + assert rhs_dtype.is_fp16() or rhs_dtype.is_bf16() or rhs_dtype.is_fp32() or rhs_dtype.is_int1(), f"Unsupported dtype {rhs_dtype}" + assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!" + return + + assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!" + return + assert lhs.type.is_block() and rhs.type.is_block() - assert lhs.dtype == rhs.dtype, f"First input ({lhs.dtype}) and second input ({rhs.dtype}) must have the same dtype!" + assert_dtypes_valid(lhs.dtype, rhs.dtype, builder.arch) + assert len(lhs.shape) == 2, f"First input shape ({lhs.shape}) is not two dimensional!" assert len(rhs.shape) == 2, f"Second input shape ({rhs.shape}) is not two dimensional!" assert lhs.shape[1].value == rhs.shape[0].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[1].value}) must be equal to first index of second shape ({rhs.shape[0].value})" assert lhs.shape[0].value >= 16 and lhs.shape[1].value >= 16 \ - and rhs.shape[1].value >= 16,\ + and rhs.shape[1].value >= 16, \ f"All values in both first input shape ({lhs.shape}) and second input shape ({rhs.shape}) must be >= 16!" + + # hip for now converts fp8 to fp16 for mixed input + if is_hip(): + if lhs.type.scalar.is_fp8(): + lhs = cast(lhs, tl.float16, builder) + elif rhs.type.scalar.is_fp8(): + rhs = cast(rhs, tl.float16, builder) + if lhs.type.scalar.is_int(): assert lhs.type.scalar == tl.int8, "only int8 supported!" # TODO: This is CUDA specific, check if ROCm has the same limitation - assert lhs.shape[1].value >= 32, "small blocks not supported!" + assert is_hip() or lhs.shape[1].value >= 32, "small blocks not supported!" _0 = builder.get_int32(0) ret_scalar_ty = tl.int32 elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16(): @@ -1446,7 +1492,7 @@ def wrap_tensor(x, scalar_ty): def _check_dtype(dtypes: List[str]) -> T: """ - We following libdevice's convention to check accepted data types for math functions. + We're following libdevice's convention to check accepted data types for math functions. It is not a good practice to support all data types as accelerators/GPUs don't support many float16 and bfloat16 math operations. We should let the users know that they are using and invoke explicit cast to convert diff --git a/python/triton/language/standard.py b/python/triton/language/standard.py index b997674c91b6..8acc4261585f 100644 --- a/python/triton/language/standard.py +++ b/python/triton/language/standard.py @@ -1,7 +1,7 @@ from __future__ import annotations from ..runtime.jit import jit -from . import core +from . import core, math # ----------------------- # Standard library @@ -14,7 +14,7 @@ def cdiv(x, div): Computes the ceiling division of :code:`x` by :code:`div` :param x: the input number - :type input: Block + :type x: Block :param div: the divisor :param div: Block """ @@ -30,9 +30,9 @@ def sigmoid(x): @jit @core._add_math_1arg_docstr("softmax") def softmax(x, ieee_rounding=False): - z = x - core.max(x, 0) + z = x - max(x, 0) num = core.exp(z) - den = core.sum(num, 0) + den = sum(num, 0) return core.fdiv(num, den, ieee_rounding) @@ -73,7 +73,7 @@ def swizzle2d(i, j, size_i, size_j, size_g): # row-index of the first element of this group off_i = group_id * size_g # last group may have fewer rows - size_g = core.minimum(size_i - off_i, size_g) + size_g = minimum(size_i - off_i, size_g) # new row and column indices new_i = off_i + (ij % size_g) new_j = (ij % size_gj) // size_g @@ -96,3 +96,192 @@ def zeros(shape, dtype): @jit def zeros_like(input): return zeros(input.shape, input.dtype) + + +@jit +def minimum(x, y): + """ + Computes the element-wise minimum of :code:`x` and :code:`y`. + + :param input: the first input tensor + :type input: Block + :param other: the second input tensor + :type other: Block + """ + return math.min(x, y) + + +@jit +def maximum(x, y): + """ + Computes the element-wise maximum of :code:`x` and :code:`y`. + + :param input: the first input tensor + :type input: Block + :param other: the second input tensor + :type other: Block + """ + return math.max(x, y) + +# max and argmax + + +@jit +def _argmax_combine(value1, index1, value2, index2, tie_break_left): + if tie_break_left: + tie = value1 == value2 and index1 < index2 + else: + tie = False + gt = value1 > value2 or tie + v_ret = core.where(gt, value1, value2) + i_ret = core.where(gt, index1, index2) + return v_ret, i_ret + + +@jit +def _argmax_combine_tie_break_left(value1, index1, value2, index2): + return _argmax_combine(value1, index1, value2, index2, True) + + +@jit +def _argmax_combine_tie_break_fast(value1, index1, value2, index2): + return _argmax_combine(value1, index1, value2, index2, False) + + +@jit +@core._add_reduction_docstr("maximum", + return_indices_arg="return_indices", + tie_break_arg="return_indices_tie_break_left") +def max(input, axis=None, return_indices=False, return_indices_tie_break_left=True): + input = core._promote_reduction_input(input) + if return_indices: + if return_indices_tie_break_left: + return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_left) + else: + return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_fast) + else: + if core.constexpr(input.dtype.primitive_bitwidth) < 32: + if core.constexpr(input.dtype.is_floating()): + input = input.to(core.float32) + else: + assert input.dtype.is_integer_type() + input = input.to(core.int32) + return core.reduce(input, axis, maximum) + + +@jit +@core._add_reduction_docstr("maximum index", tie_break_arg="tie_break_left") +def argmax(input, axis, tie_break_left=True): + (_, ret) = max(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left) + return ret + +# min and argmin + + +@jit +def _argmin_combine(value1, index1, value2, index2, tie_break_left): + if tie_break_left: + tie = value1 == value2 and index1 < index2 + else: + tie = False + lt = value1 < value2 or tie + value_ret = core.where(lt, value1, value2) + index_ret = core.where(lt, index1, index2) + return value_ret, index_ret + + +@jit +def _argmin_combine_tie_break_left(value1, index1, value2, index2): + return _argmin_combine(value1, index1, value2, index2, True) + + +@jit +def _argmin_combine_tie_break_fast(value1, index1, value2, index2): + return _argmin_combine(value1, index1, value2, index2, False) + + +@jit +@core._add_reduction_docstr("minimum", + return_indices_arg="return_indices", + tie_break_arg="return_indices_tie_break_left") +def min(input, axis=None, return_indices=False, return_indices_tie_break_left=True): + input = core._promote_reduction_input(input) + if return_indices: + if return_indices_tie_break_left: + return core._reduce_with_indices(input, axis, _argmin_combine_tie_break_left) + else: + return core._reduce_with_indices(input, axis, _argmin_combine_tie_break_fast) + else: + if core.constexpr(input.dtype.primitive_bitwidth) < 32: + if core.constexpr(input.dtype.is_floating()): + input = input.to(core.float32) + else: + assert input.dtype.is_integer_type() + input = input.to(core.int32) + return core.reduce(input, axis, minimum) + + +@jit +@core._add_reduction_docstr("minimum index", + tie_break_arg="tie_break_left") +def argmin(input, axis, tie_break_left=True): + _, ret = min(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left) + return ret + + +@jit +def _sum_combine(a, b): + return a + b + +# sum + + +@jit +@core._add_reduction_docstr("sum") +def sum(input, axis=None): + input = core._promote_reduction_input(input) + return core.reduce(input, axis, _sum_combine) + + +@jit +def _xor_combine(a, b): + return a ^ b + +# xor sum + + +@core.builtin +@core._add_reduction_docstr("xor sum") +def xor_sum(input, axis=None, _builder=None, _generator=None): + scalar_ty = input.type.scalar + if not scalar_ty.is_int(): + raise ValueError("xor_sum only supported for integers") + + input = core._promote_reduction_input(input, _builder=_builder) + return core.reduce(input, axis, _xor_combine, + _builder=_builder, _generator=_generator) + +# cumsum + + +@jit +@core._add_scan_docstr("cumsum") +def cumsum(input, axis=0): + # todo rename this to a generic function name + input = core._promote_reduction_input(input) + return core.associative_scan(input, axis, _sum_combine) + +# cumprod + + +@jit +def _prod_combine(a, b): + return a * b + + +@jit +@core._add_scan_docstr("cumprod") +def cumprod(input, axis=0): + # todo rename this to a generic function name + input = core._promote_reduction_input(input) + return core.associative_scan(input, axis, _prod_combine) diff --git a/python/triton/ops/blocksparse/matmul.py b/python/triton/ops/blocksparse/matmul.py index 5366302068a6..eaf4f2f40dee 100644 --- a/python/triton/ops/blocksparse/matmul.py +++ b/python/triton/ops/blocksparse/matmul.py @@ -395,8 +395,8 @@ def backward(ctx, dc): a, dc, not ctx.trans_a, ctx.trans_c, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut, ctx.db_width, ) dout = dc if ctx.has_out else None - return da, db, None, None, None,\ - None, None, None, None,\ + return da, db, None, None, None, \ + None, None, None, None, \ None, None, None, None, None, dout diff --git a/python/triton/ops/flash_attention.py b/python/triton/ops/flash_attention.py index b35b0fdb506d..a946fcf64404 100644 --- a/python/triton/ops/flash_attention.py +++ b/python/triton/ops/flash_attention.py @@ -346,7 +346,7 @@ def backward(ctx, do): dk = torch.empty_like(k) dv = torch.empty_like(v) delta = torch.empty_like(L) - _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )]( + _bwd_preprocess[(cdiv(q.shape[2], BLOCK) * ctx.grid[1], )]( o, do, delta, BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL, diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py index 2c57ccd62074..e7c27ec40d51 100644 --- a/python/triton/ops/matmul.py +++ b/python/triton/ops/matmul.py @@ -81,8 +81,9 @@ def _kernel(A, B, C, M, N, K, stride_bk, stride_bn, stride_cm, stride_cn, dot_out_dtype: tl.constexpr, + allow_tf32: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, + GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr ): # matrix multiplication pid = tl.program_id(0) @@ -114,9 +115,10 @@ def _kernel(A, B, C, M, N, K, _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty) a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0) b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0) - a = a.to(C.dtype.element_ty) - b = b.to(C.dtype.element_ty) - acc += tl.dot(a, b, out_dtype=dot_out_dtype) + if AB_DTYPE: + a = a.to(C.dtype.element_ty) + b = b.to(C.dtype.element_ty) + acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32) A += BLOCK_K * SPLIT_K * stride_ak B += BLOCK_K * SPLIT_K * stride_bk acc = acc.to(C.dtype.element_ty) @@ -138,7 +140,7 @@ class _matmul(torch.autograd.Function): _locks = {} @staticmethod - def _call(a, b, dot_out_dtype): + def _call(a, b, dot_out_dtype, allow_tf32): device = a.device # handle non-contiguous inputs if necessary if a.stride(0) > 1 and a.stride(1) > 1: @@ -150,8 +152,8 @@ def _call(a, b, dot_out_dtype): M, K = a.shape _, N = b.shape # allocates output - if a.dtype in [tl.float8e4, tl.float8e4b15, tl.float8e5] or\ - b.dtype in [tl.float8e4, tl.float8e4b15, tl.float8e5]: + if a.dtype in [tl.float8e4nv, tl.float8e4b15, tl.float8e5] or\ + b.dtype in [tl.float8e4nv, tl.float8e4b15, tl.float8e5]: c_dtype = torch.float16 else: c_dtype = get_higher_dtype(a.dtype, b.dtype) @@ -169,6 +171,9 @@ def _call(a, b, dot_out_dtype): dot_out_dtype = tl.float32 else: dot_out_dtype = tl.int32 + ab_dtype = True + if a.dtype in [tl.float8e4nv, tl.float8e5] and b.dtype in [tl.float8e4nv, tl.float8e5]: + ab_dtype = False # launch kernel grid = lambda META: (cdiv(M, META['BLOCK_M']) * cdiv(N, META['BLOCK_N']), META['SPLIT_K']) _kernel[grid](a, b, c, M, N, K, @@ -176,12 +181,13 @@ def _call(a, b, dot_out_dtype): b.stride(0), b.stride(1), c.stride(0), c.stride(1), dot_out_dtype=dot_out_dtype, - GROUP_M=8) + allow_tf32=allow_tf32, + GROUP_M=8, AB_DTYPE=ab_dtype) return c @staticmethod - def forward(ctx, a, b, dot_out_dtype=None): - return _matmul._call(a, b, dot_out_dtype=dot_out_dtype) + def forward(ctx, a, b, dot_out_dtype=None, allow_tf32=True): + return _matmul._call(a, b, dot_out_dtype=dot_out_dtype, allow_tf32=allow_tf32) matmul = _matmul.apply diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py index e5e9d6b2c27b..15bcf53063a4 100644 --- a/python/triton/runtime/autotuner.py +++ b/python/triton/runtime/autotuner.py @@ -33,7 +33,7 @@ def __init__(self, fn, arg_names, configs, key, verbose, reset_to_zero, prune_co 'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs. ''' if not configs: - self.configs = [Config({}, num_warps=4, num_stages=2)] + self.configs = [Config({}, num_warps=4, num_stages=2, num_ctas=1)] else: self.configs = configs self.key_idx = [arg_names.index(k) for k in key] @@ -79,7 +79,11 @@ def kernel_call(): if config.pre_hook: config.pre_hook(full_nargs) self.hook(args) - self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current) + self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, + num_ctas=config.num_ctas, + enable_warp_specialization=config.enable_warp_specialization, + # enable_persistent=False, + **current) try: return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8)) except OutOfResources: @@ -96,7 +100,7 @@ def get_best_config(self, *args, **kwargs): key_values.append(kwargs[name]) key = tuple(key_values) - return self.cache[key] if key in self.cache else Config() + return self.best_config def run(self, *args, **kwargs): @@ -125,12 +129,12 @@ def run(self, *args, **kwargs): else: config = self.configs[0] self.best_config = config + full_nargs = {**self.nargs, **kwargs, **self.best_config.kwargs} if config.pre_hook is not None: - full_nargs = {**self.nargs, **kwargs, **self.best_config.kwargs} config.pre_hook(full_nargs) - ret = self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs) - self.nargs = None - return ret + return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, + num_ctas=config.num_ctas, + enable_warp_specialization=config.enable_warp_specialization, **kwargs, **config.kwargs) def prune_configs(self, kwargs): pruned_configs = self.configs @@ -143,10 +147,16 @@ def prune_configs(self, kwargs): if len(pruned_configs) > top_k: est_timing = { config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, - num_warps=config.num_warps) + num_warps=config.num_warps, + num_ctas=config.num_ctas, + enable_warp_specialization=config.enable_warp_specialization, + enable_persistent=config.enable_persistent) for config in pruned_configs } - pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] + pruned_configs = sorted( + est_timing.keys(), + key=lambda x: est_timing[x])[ + :top_k] return pruned_configs def warmup(self, *args, **kwargs): @@ -155,7 +165,10 @@ def warmup(self, *args, **kwargs): self.fn.warmup( *args, num_warps=config.num_warps, + num_ctas=config.num_ctas, num_stages=config.num_stages, + enable_warp_specialization=config.enable_warp_specialization, + enable_persistent=config.enable_persistent, **kwargs, **config.kwargs, ) @@ -174,15 +187,20 @@ class Config: :type num_warps: int :ivar num_stages: the number of stages that the compiler should use when software-pipelining loops. Mostly useful for matrix multiplication workloads on SM80+ GPUs. - :type num_stages: int + :type enable_warp_specialization: bool + :ivar enable_warp_specialization: enable specialization (spatial partitioning) or not. See https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#spatial-partitioning-also-known-as-warp-specialization :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this function are args. """ - def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None): + def __init__(self, kwargs, num_warps=4, num_stages=2, num_ctas=1, enable_warp_specialization=False, pre_hook=None): self.kwargs = kwargs self.num_warps = num_warps + self.num_ctas = num_ctas self.num_stages = num_stages + self.enable_warp_specialization = enable_warp_specialization + # TODO[shuhaoj]: May make enable_persistent configurable in future if necessay. + self.enable_persistent = False self.pre_hook = pre_hook def __str__(self): @@ -190,7 +208,11 @@ def __str__(self): for k, v in self.kwargs.items(): res.append(f'{k}: {v}') res.append(f'num_warps: {self.num_warps}') + res.append(f'num_ctas: {self.num_ctas}') res.append(f'num_stages: {self.num_stages}') + res.append( + f'enable_warp_specialization: {self.enable_warp_specialization}') + res.append(f'enable_persistent: {self.enable_persistent}') return ', '.join(res) diff --git a/python/triton/runtime/backends/cuda.c b/python/triton/runtime/backends/cuda.c index 009d52679830..622588e8d0ff 100644 --- a/python/triton/runtime/backends/cuda.c +++ b/python/triton/runtime/backends/cuda.c @@ -1,4 +1,5 @@ #include "cuda.h" +#include #define PY_SSIZE_T_CLEAN #include @@ -16,11 +17,172 @@ static inline void gpuAssert(CUresult code, const char *file, int line) { #define CUDA_CHECK(ans) \ { \ - gpuAssert((ans), __FILE__, __LINE__); \ - if (PyErr_Occurred()) \ + { gpuAssert((ans), __FILE__, __LINE__); } \ + } + +#define ADD_ENUM_ITEM(value) \ + do { \ + PyObject *py_value = PyLong_FromLong(value); \ + PyDict_SetItemString(enum_dict, #value, py_value); \ + } while (0) + +#define ADD_ENUM_ITEM_0() +#define ADD_ENUM_ITEM_1(v1) ADD_ENUM_ITEM(v1) +#define ADD_ENUM_ITEM_2(v1, v2) \ + ADD_ENUM_ITEM(v1); \ + ADD_ENUM_ITEM(v2); +#define ADD_ENUM_ITEM_3(v1, v2, v3) \ + ADD_ENUM_ITEM(v1); \ + ADD_ENUM_ITEM(v2); \ + ADD_ENUM_ITEM(v3); +#define ADD_ENUM_ITEM_4(v1, v2, v3, v4) \ + ADD_ENUM_ITEM(v1); \ + ADD_ENUM_ITEM(v2); \ + ADD_ENUM_ITEM(v3); \ + ADD_ENUM_ITEM(v4); +#define ADD_ENUM_ITEM_5(v1, v2, v3, v4, v5) \ + ADD_ENUM_ITEM_2(v1, v2); \ + ADD_ENUM_ITEM_3(v3, v4, v5); +#define ADD_ENUM_ITEM_6(v1, v2, v3, v4, v5, v6) \ + ADD_ENUM_ITEM_2(v1, v2); \ + ADD_ENUM_ITEM_4(v3, v4, v5, v6); +#define ADD_ENUM_ITEM_7(v1, v2, v3, v4, v5, v6, v7) \ + ADD_ENUM_ITEM_3(v1, v2, v3); \ + ADD_ENUM_ITEM_4(v4, v5, v6, v7); +#define ADD_ENUM_ITEM_8(v1, v2, v3, v4, v5, v6, v7, v8) \ + ADD_ENUM_ITEM_4(v1, v2, v3, v4); \ + ADD_ENUM_ITEM_4(v5, v6, v7, v8); +#define ADD_ENUM_ITEM_9(v1, v2, v3, v4, v5, v6, v7, v8, v9) \ + ADD_ENUM_ITEM_5(v1, v2, v3, v4, v5); \ + ADD_ENUM_ITEM_4(v6, v7, v8, v9); +#define ADD_ENUM_ITEM_10(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10) \ + ADD_ENUM_ITEM_5(v1, v2, v3, v4, v5); \ + ADD_ENUM_ITEM_5(v6, v7, v8, v9, v10); +#define ADD_ENUM_ITEM_11(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11) \ + ADD_ENUM_ITEM_6(v1, v2, v3, v4, v5, v6); \ + ADD_ENUM_ITEM_5(v7, v8, v9, v10, v11); +#define ADD_ENUM_ITEM_12(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12) \ + ADD_ENUM_ITEM_6(v1, v2, v3, v4, v5, v6); \ + ADD_ENUM_ITEM_6(v7, v8, v9, v10, v11, v12); +#define ADD_ENUM_ITEM_13(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, \ + v13) \ + ADD_ENUM_ITEM_7(v1, v2, v3, v4, v5, v6, v7); \ + ADD_ENUM_ITEM_6(v8, v9, v10, v11, v12, v13); +#define ADD_ENUM_ITEM_14(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, \ + v13, v14) \ + ADD_ENUM_ITEM_7(v1, v2, v3, v4, v5, v6, v7); \ + ADD_ENUM_ITEM_7(v8, v9, v10, v11, v12, v13, v14); + +#define DISPATCH_ARGS_N(_14, _13, _12, _11, _10, _9, _8, _7, _6, _5, _4, _3, \ + _2, _1, N, ...) \ + ADD_ENUM_ITEM_##N +#define DISPATCH_ARGS(...) \ + DISPATCH_ARGS_N(__VA_ARGS__, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, \ + 0) \ + (__VA_ARGS__) + +#define ADD_ENUM_TO_MODULE(module, enum_name, ...) \ + do { \ + PyObject *enum_dict = PyDict_New(); \ + DISPATCH_ARGS(__VA_ARGS__) \ + if (enum_dict != NULL) { \ + PyObject_SetAttrString(module, #enum_name, enum_dict); \ + } \ + } while (0) + +static void defineEnums(PyObject *self) { + ADD_ENUM_TO_MODULE( + self, CUtensorMapDataType, CU_TENSOR_MAP_DATA_TYPE_UINT8, + CU_TENSOR_MAP_DATA_TYPE_UINT16, CU_TENSOR_MAP_DATA_TYPE_UINT32, + CU_TENSOR_MAP_DATA_TYPE_INT32, CU_TENSOR_MAP_DATA_TYPE_UINT64, + CU_TENSOR_MAP_DATA_TYPE_INT64, CU_TENSOR_MAP_DATA_TYPE_FLOAT16, + CU_TENSOR_MAP_DATA_TYPE_FLOAT32, CU_TENSOR_MAP_DATA_TYPE_FLOAT64, + CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ, + CU_TENSOR_MAP_DATA_TYPE_TFLOAT32, CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ); + + ADD_ENUM_TO_MODULE(self, CUtensorMapInterleave, CU_TENSOR_MAP_INTERLEAVE_NONE, + CU_TENSOR_MAP_INTERLEAVE_16B, + CU_TENSOR_MAP_INTERLEAVE_32B); + + ADD_ENUM_TO_MODULE(self, CUtensorMapSwizzle, CU_TENSOR_MAP_SWIZZLE_NONE, + CU_TENSOR_MAP_SWIZZLE_32B, CU_TENSOR_MAP_SWIZZLE_64B, + CU_TENSOR_MAP_SWIZZLE_128B); + + ADD_ENUM_TO_MODULE( + self, CUtensorMapL2promotion, CU_TENSOR_MAP_L2_PROMOTION_NONE, + CU_TENSOR_MAP_L2_PROMOTION_L2_64B, CU_TENSOR_MAP_L2_PROMOTION_L2_128B, + CU_TENSOR_MAP_L2_PROMOTION_L2_256B); + + ADD_ENUM_TO_MODULE(self, CUtensorMapFloatOOBfill, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA); +} + +typedef struct { + PyObject_HEAD cuuint32_t value; +} PyCUuint32; + +typedef struct { + PyObject_HEAD cuuint64_t value; +} PyCUuint64; + +#define DEFINE_CUUINT_CONSTRUCTOR(NAME, TYPE, FORMAT, VALUE_TYPE) \ + static PyObject *Py##NAME##_New(PyTypeObject *type, PyObject *args, \ + PyObject *kwds) { \ + Py##NAME *self; \ + VALUE_TYPE value; \ + if (!PyArg_ParseTuple(args, FORMAT, &value)) \ return NULL; \ + self = (Py##NAME *)type->tp_alloc(type, 0); \ + if (self != NULL) { \ + self->value = (TYPE)value; \ + } \ + return (PyObject *)self; \ } +DEFINE_CUUINT_CONSTRUCTOR(CUuint32, cuuint32_t, "l", long) +DEFINE_CUUINT_CONSTRUCTOR(CUuint64, cuuint64_t, "L", long long) + +static PyTypeObject PyCUuint32_Type = { + PyVarObject_HEAD_INIT(NULL, 0).tp_name = "cuda_utils.cuuint32_t", + .tp_basicsize = sizeof(PyCUuint32), + .tp_flags = Py_TPFLAGS_DEFAULT, + .tp_new = PyCUuint32_New, +}; + +static PyTypeObject PyCUuint64_Type = { + PyVarObject_HEAD_INIT(NULL, 0).tp_name = "cuda_utils.cuuint64_t", + .tp_basicsize = sizeof(PyCUuint64), + .tp_flags = Py_TPFLAGS_DEFAULT, + .tp_new = PyCUuint64_New, +}; + +static void defineTypes(PyObject *self) { + if (PyType_Ready(&PyCUuint32_Type) < 0) { + PyErr_SetString(PyExc_TypeError, "Failed to ready cuuint32_t type"); + return; + } + Py_INCREF(&PyCUuint32_Type); + if (PyModule_AddObject(self, "cuuint32_t", (PyObject *)&PyCUuint32_Type) < + 0) { + PyErr_SetString(PyExc_RuntimeError, + "Failed to add cuuint32_t type to module"); + return; + } + + if (PyType_Ready(&PyCUuint64_Type) < 0) { + PyErr_SetString(PyExc_TypeError, "Failed to ready cuuint64_t type"); + return; + } + Py_INCREF(&PyCUuint64_Type); + if (PyModule_AddObject(self, "cuuint64_t", (PyObject *)&PyCUuint64_Type) < + 0) { + PyErr_SetString(PyExc_RuntimeError, + "Failed to add cuuint64_t type to module"); + return; + } +} + static PyObject *getDeviceProperties(PyObject *self, PyObject *args) { int device_id; if (!PyArg_ParseTuple(args, "i", &device_id)) @@ -70,6 +232,8 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) { int32_t n_spills = 0; // create driver handles CUcontext pctx = 0; + + Py_BEGIN_ALLOW_THREADS; CUDA_CHECK(cuCtxGetCurrent(&pctx)); if (!pctx) { CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device)); @@ -100,6 +264,7 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) { cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static)); } + Py_END_ALLOW_THREADS; if (PyErr_Occurred()) { return NULL; @@ -108,11 +273,165 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) { n_spills); } +static PyObject *memAlloc(PyObject *self, PyObject *args) { + size_t bytesize; + CUdeviceptr dptr; + CUresult result; + + if (!PyArg_ParseTuple(args, "K", &bytesize)) { + return NULL; // Error parsing arguments + } + + Py_BEGIN_ALLOW_THREADS; + CUDA_CHECK(cuMemAlloc(&dptr, bytesize)); + Py_END_ALLOW_THREADS; + + return PyLong_FromUnsignedLongLong((unsigned long long)dptr); +} + +static PyObject *memcpyHtoD(PyObject *self, PyObject *args) { + unsigned long long dstDevicePtr, srcHostPtr; + size_t byteCount; + CUdeviceptr dstDevice; + const void *srcHost; + CUresult result; + + if (!PyArg_ParseTuple(args, "KKK", &dstDevicePtr, &srcHostPtr, &byteCount)) { + return NULL; // Error parsing arguments + } + + dstDevice = (CUdeviceptr)dstDevicePtr; + srcHost = (const void *)srcHostPtr; + + Py_BEGIN_ALLOW_THREADS; + CUDA_CHECK(cuMemcpyHtoD(dstDevice, srcHost, byteCount)); + Py_END_ALLOW_THREADS; + + Py_RETURN_NONE; +} + +static PyObject *memFree(PyObject *self, PyObject *args) { + CUdeviceptr dptr; + + if (!PyArg_ParseTuple(args, "K", &dptr)) { + return NULL; // Error parsing arguments + } + + Py_BEGIN_ALLOW_THREADS; + CUDA_CHECK(cuMemFree(dptr)); + Py_END_ALLOW_THREADS; + + Py_RETURN_NONE; +} + +// Helper function to convert a Python list to a cuuint64_t array +static cuuint64_t *list_to_cuuint64_array(PyObject *listObj) { + Py_ssize_t len = PyList_Size(listObj); + cuuint64_t *array = malloc(len * sizeof(cuuint64_t)); + for (Py_ssize_t i = 0; i < len; i++) { + PyObject *item = PyList_GetItem(listObj, i); + array[i] = (cuuint64_t)PyLong_AsUnsignedLongLong(item); + } + return array; +} + +// Helper function to convert a Python list to a cuuint32_t array +static cuuint32_t *list_to_cuuint32_array(PyObject *listObj) { + Py_ssize_t len = PyList_Size(listObj); + cuuint32_t *array = malloc(len * sizeof(cuuint32_t)); + for (Py_ssize_t i = 0; i < len; i++) { + PyObject *item = PyList_GetItem(listObj, i); + array[i] = (cuuint32_t)PyLong_AsUnsignedLong(item); + } + return array; +} + +typedef CUresult (*cuTensorMapEncodeTiled_t)( + CUtensorMap *tensorMap, CUtensorMapDataType tensorDataType, + cuuint32_t tensorRank, void *globalAddress, const cuuint64_t *globalDim, + const cuuint64_t *globalStrides, const cuuint32_t *boxDim, + const cuuint32_t *elementStrides, CUtensorMapInterleave interleave, + CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, + CUtensorMapFloatOOBfill oobFill); + +static cuTensorMapEncodeTiled_t getCuTensorMapEncodeTiledHandle() { + // Open the shared library + void *handle = dlopen("libcuda.so", RTLD_LAZY); + if (!handle) { + PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so"); + return NULL; + } + // Clear any existing error + dlerror(); + cuTensorMapEncodeTiled_t cuTensorMapEncodeTiledHandle = + (cuTensorMapEncodeTiled_t)dlsym(handle, "cuTensorMapEncodeTiled"); + // Check for errors + const char *dlsym_error = dlerror(); + if (dlsym_error) { + PyErr_SetString( + PyExc_RuntimeError, + "Failed to retrieve cuTensorMapEncodeTiled from libcuda.so"); + return NULL; + } + return cuTensorMapEncodeTiledHandle; +} + +static PyObject *tensorMapEncodeTiled(PyObject *self, PyObject *args) { + CUtensorMap *tensorMap = (CUtensorMap *)malloc(sizeof(CUtensorMap)); + CUtensorMapDataType tensorDataType; + cuuint32_t tensorRank; + void *globalAddress; + PyObject *globalDimObj, *globalStridesObj, *boxDimObj, *elementStridesObj; + CUtensorMapInterleave interleave; + CUtensorMapSwizzle swizzle; + CUtensorMapL2promotion l2Promotion; + CUtensorMapFloatOOBfill oobFill; + + // Parse arguments + if (!PyArg_ParseTuple(args, "iiKO!O!O!O!iiii", &tensorDataType, &tensorRank, + &globalAddress, &PyList_Type, &globalDimObj, + &PyList_Type, &globalStridesObj, &PyList_Type, + &boxDimObj, &PyList_Type, &elementStridesObj, + &interleave, &swizzle, &l2Promotion, &oobFill)) { + return NULL; // Error parsing arguments + } + + // Convert Python lists to C arrays + cuuint64_t *globalDim = list_to_cuuint64_array(globalDimObj); + cuuint64_t *globalStrides = list_to_cuuint64_array(globalStridesObj); + cuuint32_t *boxDim = list_to_cuuint32_array(boxDimObj); + cuuint32_t *elementStrides = list_to_cuuint32_array(elementStridesObj); + + static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiledHandle = NULL; + if (cuTensorMapEncodeTiledHandle == NULL) { + cuTensorMapEncodeTiledHandle = getCuTensorMapEncodeTiledHandle(); + } + // Call the function + Py_BEGIN_ALLOW_THREADS; + CUDA_CHECK(cuTensorMapEncodeTiledHandle( + tensorMap, tensorDataType, tensorRank, globalAddress, globalDim, + globalStrides, boxDim, elementStrides, interleave, swizzle, l2Promotion, + oobFill)); + Py_END_ALLOW_THREADS; + + // Clean up + free(globalDim); + free(globalStrides); + free(boxDim); + free(elementStrides); + // Return the tensor map as a normal pointer + return PyLong_FromUnsignedLongLong((unsigned long long)tensorMap); +} + static PyMethodDef ModuleMethods[] = { {"load_binary", loadBinary, METH_VARARGS, "Load provided cubin into CUDA driver"}, {"get_device_properties", getDeviceProperties, METH_VARARGS, "Get the properties for a given device"}, + {"cuMemAlloc", memAlloc, METH_VARARGS}, + {"cuMemcpyHtoD", memcpyHtoD, METH_VARARGS}, + {"cuMemFree", memFree, METH_VARARGS}, + {"cuTensorMapEncodeTiled", tensorMapEncodeTiled, METH_VARARGS}, {NULL, NULL, 0, NULL} // sentinel }; @@ -126,6 +445,10 @@ PyMODINIT_FUNC PyInit_cuda_utils(void) { if (m == NULL) { return NULL; } + + defineEnums(m); + defineTypes(m); PyModule_AddFunctions(m, ModuleMethods); + return m; } diff --git a/python/triton/runtime/cache.py b/python/triton/runtime/cache.py index 43e6660a59df..db8f6193e9ac 100644 --- a/python/triton/runtime/cache.py +++ b/python/triton/runtime/cache.py @@ -40,18 +40,20 @@ def __init__(self, key): self.key = key self.lock_path = None # create cache directory if it doesn't exist - self.cache_dir = os.environ.get('TRITON_CACHE_DIR', default_cache_dir()) + self.cache_dir = os.getenv('TRITON_CACHE_DIR', "").strip() or default_cache_dir() if self.cache_dir: self.cache_dir = os.path.join(self.cache_dir, self.key) self.lock_path = os.path.join(self.cache_dir, "lock") os.makedirs(self.cache_dir, exist_ok=True) + else: + raise RuntimeError("Could not create or locate cache dir") def _make_path(self, filename) -> str: return os.path.join(self.cache_dir, filename) - def has_file(self, filename): + def has_file(self, filename) -> bool: if not self.cache_dir: - return False + raise RuntimeError("Could not create or locate cache dir") return os.path.exists(self._make_path(filename)) def get_file(self, filename) -> Optional[str]: @@ -80,16 +82,16 @@ def get_group(self, filename: str) -> Optional[Dict[str, str]]: return result # Note a group of pushed files as being part of a group - def put_group(self, filename: str, group: Dict[str, str]): + def put_group(self, filename: str, group: Dict[str, str]) -> str: if not self.cache_dir: - return + raise RuntimeError("Could not create or locate cache dir") grp_contents = json.dumps({"child_paths": sorted(list(group.keys()))}) grp_filename = f"__grp__{filename}" return self.put(grp_contents, grp_filename, binary=False) def put(self, data, filename, binary=True) -> str: if not self.cache_dir: - return + raise RuntimeError("Could not create or locate cache dir") binary = isinstance(data, bytes) if not binary: data = str(data) diff --git a/python/triton/runtime/driver.py b/python/triton/runtime/driver.py index 3850821536c5..5b778be4dc10 100644 --- a/python/triton/runtime/driver.py +++ b/python/triton/runtime/driver.py @@ -52,6 +52,15 @@ def __init__(self): spec.loader.exec_module(mod) self.load_binary = mod.load_binary self.get_device_properties = mod.get_device_properties + self.CUtensorMapDataType = mod.CUtensorMapDataType + self.CUtensorMapInterleave = mod.CUtensorMapInterleave + self.CUtensorMapSwizzle = mod.CUtensorMapSwizzle + self.CUtensorMapL2promotion = mod.CUtensorMapL2promotion + self.CUtensorMapFloatOOBfill = mod.CUtensorMapFloatOOBfill + self.cuTensorMapEncodeTiled = mod.cuTensorMapEncodeTiled + self.cuMemAlloc = mod.cuMemAlloc + self.cuMemcpyHtoD = mod.cuMemcpyHtoD + self.cuMemFree = mod.cuMemFree class CudaDriver(DriverBase): diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index a27c7e28dbc3..6dc9de5a7c60 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -11,7 +11,9 @@ from typing import (Callable, Generic, Iterable, List, Optional, TypeVar, Union, cast, overload) +from .._C.libtriton.triton import TMAInfos from ..common.backend import get_backend, path_to_ptxas +from ..language.core import dtype TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) TRITON_VERSION = "2.1.0" @@ -59,7 +61,7 @@ class DependenciesFinder(ast.NodeVisitor): def __init__(self, globals, src) -> None: super().__init__() - self.ret = hashlib.md5(src.encode("utf-8")).hexdigest() + self.ret = hashlib.sha1(src.encode("utf-8")).hexdigest() self.globals = globals def visit_Name(self, node): @@ -89,7 +91,7 @@ def visit_Call(self, node): func.hash = finder.ret noinline = str(getattr(func, 'noinline', False)) self.ret = (self.ret + func.hash + noinline).encode("utf-8") - self.ret = hashlib.md5(self.ret).hexdigest() + self.ret = hashlib.sha1(self.ret).hexdigest() # ----------------------------------------------------------------------------- # JITFunction @@ -102,23 +104,29 @@ def version_key(): contents = [] # frontend with open(__file__, "rb") as f: - contents += [hashlib.md5(f.read()).hexdigest()] + contents += [hashlib.sha1(f.read()).hexdigest()] # compiler compiler_path = os.path.join(TRITON_PATH, 'compiler') for lib in pkgutil.iter_modules([compiler_path]): with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: - contents += [hashlib.md5(f.read()).hexdigest()] + contents += [hashlib.sha1(f.read()).hexdigest()] # backend + libtriton_hash = hashlib.sha1() with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f: - contents += [hashlib.md5(f.read()).hexdigest()] + while True: + chunk = f.read(1024 ** 2) + if not chunk: + break + libtriton_hash.update(chunk) + contents.append(libtriton_hash.hexdigest()) # language language_path = os.path.join(TRITON_PATH, 'language') for lib in pkgutil.iter_modules([language_path]): with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: - contents += [hashlib.md5(f.read()).hexdigest()] + contents += [hashlib.sha1(f.read()).hexdigest()] # ptxas version ptxas = path_to_ptxas()[0] - ptxas_version = hashlib.md5(subprocess.check_output([ptxas, "--version"])).hexdigest() + ptxas_version = hashlib.sha1(subprocess.check_output([ptxas, "--version"])).hexdigest() return '-'.join(TRITON_VERSION) + '-' + ptxas_version + '-' + '-'.join(contents) @@ -147,6 +155,11 @@ class JITFunction(KernelInterface[T]): # Hook for inspecting compiled functions and modules cache_hook = None divisibility = 16 + # As Hopper TMA load and store primitive requires the tensor stride to be 16-byte aligned. + # And we only support WGMMA with float16 dtype on Hopper for now. + # So whether the LoadOp and StoreOp will lowering into TMA copy depend on whether the tensor stride is divisible by 8. + # TODO: Make it more reasonable to handle multiple dtypes. + divisibility_8 = 8 @staticmethod def _key_of(arg): @@ -201,10 +214,29 @@ def is_divisible_by_16(x): if x is None: return True return False - divisible_by_16 = {i for i, arg in enumerate(args) if is_divisible_by_16(arg) and i not in self.do_not_specialize} - equal_to_1 = {i for i, arg in enumerate(args) if not isinstance(arg, bool) and isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize} - return namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])(tuple(divisible_by_16), tuple(equal_to_1)) - # return _triton.code_gen.instance_descriptor(divisible_by_16, equal_to_1) + + def is_divisible_by_8(x): + if isinstance(x, int): + return x % JITFunction.divisibility_8 == 0 + if x is None: + return True + return False + divisible_by_16 = {i for i, arg in enumerate( + args) if is_divisible_by_16(arg) and i not in self.do_not_specialize} + divisible_by_8 = {i for i, arg in enumerate( + args) if is_divisible_by_8(arg) and i not in self.do_not_specialize} + equal_to_1 = { + i for i, arg in enumerate(args) if isinstance( + arg, int) and not isinstance( + arg, bool) and arg == 1 and i not in self.do_not_specialize} + # folded equal_to_1 and None + # TODO: method to collect all folded args + none_args = {i for i, arg in enumerate(args) if arg is None and i not in self.do_not_specialize} + ids_of_folded_args = equal_to_1 | none_args + return namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"])( + tuple(divisible_by_16), tuple(equal_to_1), tuple(ids_of_folded_args), tuple(divisible_by_8)) + # return _triton.code_gen.instance_descriptor(divisible_by_16, + # equal_to_1) @staticmethod def _type_of(key): @@ -214,9 +246,10 @@ def _type_of(key): dtype_str = str(key).split(".")[-1] tys = { "bool": "i1", - "float8e4": "fp8e4", + "float8e4nv": "fp8e4nv", "float8e5": "fp8e5", "float8e4b15": "fp8e4b15", + "float8e4b15x4": "fp8e4b15x4", "float16": "fp16", "bfloat16": "bf16", "float32": "fp32", @@ -243,13 +276,13 @@ def _make_constants(self, constexpr_key): constants = dict(zip(self.constexprs, constexpr_key)) return constants - def _call_hook(self, key, signature, device, constants, num_warps, num_stages, extern_libs, configs): + def _call_hook(self, key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, enable_warp_specialization, extern_libs, configs): if JITFunction.cache_hook is None: return False name = self.fn.__name__ module = self.fn.__module__ arg_reprs = ', '.join([f'{name}: {ty}' for name, ty in zip(self.arg_names, key[1])]) - repr = f"{name}[num_warps={num_warps}, num_stages={num_stages}]({arg_reprs})" + repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, waves_per_eu={waves_per_eu}, enable_warp_specialization={enable_warp_specialization}]({arg_reprs})" key = str(key) class LegacyCompiler: @@ -259,21 +292,22 @@ def __init__(self, module, name): pass kwargs = dict(signature=signature, device=device, constants=constants, - num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs, + num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, configs=configs) - return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={"key": key, **kwargs}, is_manual_warmup=False, already_compiled=False) + return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={ + "key": key, **kwargs}, is_manual_warmup=False, already_compiled=False) def _get_arg_specialization_key(self, arg) -> str: arg_annotation = self.__annotations__.get(arg, '') if arg_annotation == '': return f'({arg}.data_ptr() % {JITFunction.divisibility} == 0) if hasattr({arg}, "data_ptr") \ - else ({arg} % {JITFunction.divisibility} == 0, {arg} == 1) if isinstance({arg}, int) \ + else ({arg} % {JITFunction.divisibility} == 0, {arg} % {JITFunction.divisibility_8} == 0, {arg} == 1) if isinstance({arg}, int) \ else (False,)' elif 'Tensor' in arg_annotation: return f'({arg}.data_ptr() % {JITFunction.divisibility} == 0)' elif arg_annotation == 'int': - return f'({arg} % {JITFunction.divisibility} == 0, {arg} == 1)' + return f'({arg} % {JITFunction.divisibility} == 0, {arg} % {JITFunction.divisibility_8} == 0, {arg} == 1)' else: return '(False,)' @@ -304,8 +338,11 @@ def _conclude_device_type(self, device_types: List[str], pinned_memory_flags: Li return device_types[0] if len(device_types) > 0 else 'cuda' def _make_launcher(self): - regular_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i not in self.constexprs] - constexpr_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i in self.constexprs] + regular_args = [f'{arg}' for i, arg in enumerate( + self.arg_names) if i not in self.constexprs] + constexpr_args = [ + f'{arg}' for i, arg in enumerate( + self.arg_names) if i in self.constexprs] args = ', '.join(regular_args) # cache key for regular argument type sig_keys = ', '.join([self._get_arg_sig_key(arg) for arg in regular_args]) @@ -322,19 +359,17 @@ def _make_launcher(self): spec_keys = ', '.join(specializations) grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names]) - args_signature = ', '.join(name if dflt == inspect._empty else f'{name} = {dflt}' for name, dflt in zip(self.arg_names, self.arg_defaults)) + args_signature = ', '.join(name if dflt == inspect._empty else f'{name} = triton.language.dtype(\'{dflt}\')' if dtype.is_dtype(f'{dflt}') else f'{name} = {dflt}' for name, dflt in zip(self.arg_names, self.arg_defaults)) + args_signature = args_signature + ', ' if len(args_signature) > 0 else '' src = f""" - -def {self.fn.__name__}({args_signature}, grid=None, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False, device=None, device_type=None): - from ..compiler import compile, CompiledKernel - sig_key = {sig_keys}, +import triton +def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, waves_per_eu=0, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None): + from ..compiler import compile, CompiledKernel, get_arch_default_num_warps, get_arch_default_num_stages + sig_key = {f'{sig_keys},' if len(sig_keys) > 0 else ()} constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()} spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else ()} - key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_stages, self.debug) - if not extern_libs is None: - key = (key, tuple(extern_libs.items())) - assert num_warps > 0 and (num_warps & (num_warps - 1)) == 0, "num_warps must be a power of 2" + assert num_ctas > 0 assert grid is not None if callable(grid): grid = grid({{{grid_args}}}) @@ -366,16 +401,29 @@ def {self.fn.__name__}({args_signature}, grid=None, num_warps=4, num_stages=3, e else: stream = device_backend.get_stream() + if num_warps is None: + num_warps = get_arch_default_num_warps(device_type) + if num_stages is None: + num_stages = get_arch_default_num_stages(device_type) + + key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, waves_per_eu, enable_warp_specialization, self.debug) + if not extern_libs is None: + key = (key, tuple(extern_libs.items())) + bin = cache[device].get(key, None) if bin is not None: + # build dict of constant values + args = [{args}] + # Create tensormaps and append to args + args = bin.assemble_tensormap_to_arg(args) if not warmup: - bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, {args}) + bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.num_ctas, bin.clusterDims[0], bin.clusterDims[1], bin.clusterDims[2], bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args) return bin # kernel not cached -- compile else: # build dict of constant values args = [{args}] - all_args = {', '.join([f'{arg}' for arg in self.arg_names])}, + all_args = {', '.join([f'{arg}' for arg in self.arg_names]) + ', ' if len(self.arg_names) > 0 else ()} configs = self._get_config(*all_args), constants = self._make_constants(constexpr_key) constants.update({{i: None for i, arg in enumerate(all_args) if arg is None}}) @@ -386,10 +434,12 @@ def {self.fn.__name__}({args_signature}, grid=None, num_warps=4, num_stages=3, e for i, arg in constants.items(): if callable(arg): raise TypeError(f"Callable constexpr at index {{i}} is not supported") - if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs): - bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type) + if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, enable_warp_specialization, extern_libs, configs): + bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type) + # Create tensormaps and append to args + args = bin.assemble_tensormap_to_arg(args) if not warmup: - bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args) + bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.num_ctas, bin.clusterDims[0], bin.clusterDims[1], bin.clusterDims[2], bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args) self.cache[device][key] = bin return bin return None @@ -418,9 +468,6 @@ def __init__(self, fn, version=None, do_not_specialize=None, debug=None, noinlin self.arg_names = [v.name for v in signature.parameters.values()] self.arg_defaults = [v.default for v in signature.parameters.values()] self.has_defaults = any(v != inspect._empty for v in self.arg_defaults) - # specialization hints - self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize - self.do_not_specialize = {self.arg_names.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize} # function source code (without decorators) self.src = textwrap.dedent(inspect.getsource(fn)) self.src = self.src[self.src.find("def"):] @@ -437,6 +484,12 @@ def __init__(self, fn, version=None, do_not_specialize=None, debug=None, noinlin self.__annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()} # index of constexprs self.constexprs = [self.arg_names.index(name) for name, ty in self.__annotations__.items() if 'constexpr' in ty] + # specialization hints + regular_args = [arg for i, arg in enumerate(self.arg_names) if i not in self.constexprs] + self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize + self.do_not_specialize = {regular_args.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize} + # tma info + self.tensormaps_info = TMAInfos() # launcher self.run = self._make_launcher() # re-use docs of wrapped function @@ -594,6 +647,9 @@ def stride(self, i): def __str__(self) -> str: return f'TensorWrapper[{self.dtype}]({self.base})' + def element_size(self): + return self.base.element_size() + def reinterpret(tensor, dtype): if isinstance(tensor, TensorWrapper): diff --git a/python/triton/testing.py b/python/triton/testing.py index d48264291f30..c4357bd243c0 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -3,7 +3,9 @@ import subprocess import sys from contextlib import contextmanager +from typing import Any, Dict, List +from . import language as tl from ._C.libtriton.triton import runtime @@ -200,37 +202,41 @@ class Benchmark: def __init__( self, - x_names, - x_vals, - line_arg, - line_vals, - line_names, - plot_name, - args, - xlabel='', - ylabel='', - x_log=False, - y_log=False, + x_names: List[str], + x_vals: List[Any], + line_arg: str, + line_vals: List[Any], + line_names: List[str], + plot_name: str, + args: Dict[str, Any], + xlabel: str = '', + ylabel: str = '', + x_log: bool = False, + y_log: bool = False, color=None, styles=None, ): """ - Constructor + Constructor. + x_vals can be a list of scalars or a list of tuples/lists. If x_vals is a list + of scalars and there are multiple x_names, all arguments will have the same value. + If x_vals is a list of tuples/lists, each element should have the same length as + x_names. - :param x_names: Name of the arguments that should appear on the x axis of the plot. If the list contains more than one element, all the arguments are assumed to have the same value. + :param x_names: Name of the arguments that should appear on the x axis of the plot. :type x_names: List[str] :param x_vals: List of values to use for the arguments in :code:`x_names`. :type x_vals: List[Any] :param line_arg: Argument name for which different values correspond to different lines in the plot. :type line_arg: str :param line_vals: List of values to use for the arguments in :code:`line_arg`. - :type line_vals: List[str] + :type line_vals: List[Any] :param line_names: Label names for the different lines. :type line_names: List[str] :param plot_name: Name of the plot. :type plot_name: str - :param args: List of arguments to remain fixed throughout the benchmark. - :type args: List[str] + :param args: Dictionary of keyword arguments to remain fixed throughout the benchmark. + :type args: Dict[str, Any] :param xlabel: Label for the x axis of the plot. :type xlabel: str, optional :param ylabel: Label for the y axis of the plot. @@ -260,7 +266,7 @@ def __init__(self, fn, benchmarks): self.fn = fn self.benchmarks = benchmarks - def _run(self, bench, save_path, show_plots, print_data): + def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool): import os import matplotlib.pyplot as plt @@ -268,9 +274,17 @@ def _run(self, bench, save_path, show_plots, print_data): y_mean = bench.line_names y_min = [f'{x}-min' for x in bench.line_names] y_max = [f'{x}-max' for x in bench.line_names] - df = pd.DataFrame(columns=[bench.x_names[0]] + y_mean + y_min + y_max) + x_names = list(bench.x_names) + df = pd.DataFrame(columns=x_names + y_mean + y_min + y_max) for x in bench.x_vals: - x_args = {x_name: x for x_name in bench.x_names} + # x can be a single value or a sequence of values. + if not isinstance(x, (list, tuple)): + x = [x for _ in x_names] + + if len(x) != len(x_names): + raise ValueError(f"Expected {len(x_names)} values, got {x}") + x_args = dict(zip(x_names, x)) + row_mean, row_min, row_max = [], [], [] for y in bench.line_vals: ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args) @@ -281,21 +295,24 @@ def _run(self, bench, save_path, show_plots, print_data): row_mean += [y_mean] row_min += [y_min] row_max += [y_max] - df.loc[len(df)] = [x] + row_mean + row_min + row_max + df.loc[len(df)] = list(x) + row_mean + row_min + row_max + if bench.plot_name: plt.figure() ax = plt.subplot() - x = bench.x_names[0] + # Plot first x value on x axis if there are multiple. + first_x = x_names[0] for i, y in enumerate(bench.line_names): y_min, y_max = df[y + '-min'], df[y + '-max'] col = bench.styles[i][0] if bench.styles else None sty = bench.styles[i][1] if bench.styles else None - ax.plot(df[x], df[y], label=y, color=col, ls=sty) - if y_min is not None and y_max is not None: - ax.fill_between(df[x], y_min, y_max, alpha=0.15, color=col) + ax.plot(df[first_x], df[y], label=y, color=col, ls=sty) + if not y_min.isnull().all() and not y_max.isnull().all(): + y_min = y_min.astype(float) + y_max = y_max.astype(float) + ax.fill_between(df[first_x], y_min, y_max, alpha=0.15, color=col) ax.legend() - xlabel = bench.xlabel if bench.xlabel else " = ".join(bench.x_names) - ax.set_xlabel(xlabel) + ax.set_xlabel(bench.xlabel or first_x) ax.set_ylabel(bench.ylabel) # ax.set_title(bench.plot_name) ax.set_xscale("log" if bench.x_log else "linear") @@ -304,7 +321,7 @@ def _run(self, bench, save_path, show_plots, print_data): plt.show() if save_path: plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png")) - df = df[[bench.x_names[0]] + bench.line_names] + df = df[x_names + bench.line_names] if print_data: print(bench.plot_name + ':') print(df) @@ -368,11 +385,11 @@ def get_max_tensorcore_tflops(dtype, backend=None, device=None, clock_rate=None) assert dtype == torch.float16 ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores else: - if dtype == torch.float32: + if dtype in [torch.float32, torch.int32]: ops_per_sub_core = 256 - elif dtype in [torch.float16, torch.bfloat16]: + elif dtype in [torch.float16, torch.bfloat16, torch.int16]: ops_per_sub_core = 512 - elif dtype == torch.int8: + elif dtype in [torch.int8, tl.float8e4nv, tl.float8e4b15, tl.float8e5]: ops_per_sub_core = 1024 else: raise RuntimeError("dtype not supported") diff --git a/python/triton/third_party/cuda/bin/ptxas b/python/triton/third_party/cuda/bin/ptxas deleted file mode 100755 index 8b47936ea212..000000000000 Binary files a/python/triton/third_party/cuda/bin/ptxas and /dev/null differ diff --git a/python/triton/third_party/cuda/include/cuda.h b/python/triton/third_party/cuda/include/cuda.h index c713bf316a16..ec111be805af 100755 --- a/python/triton/third_party/cuda/include/cuda.h +++ b/python/triton/third_party/cuda/include/cuda.h @@ -1,5 +1,5 @@ /* - * Copyright 1993-2018 NVIDIA Corporation. All rights reserved. + * Copyright 1993-2022 NVIDIA Corporation. All rights reserved. * * NOTICE TO LICENSEE: * @@ -144,7 +144,23 @@ typedef uint64_t cuuint64_t; #define cuDevicePrimaryCtxSetFlags cuDevicePrimaryCtxSetFlags_v2 #define cuDeviceGetUuid_v2 cuDeviceGetUuid_v2 #define cuIpcOpenMemHandle cuIpcOpenMemHandle_v2 -#define cuGraphInstantiate cuGraphInstantiate_v2 + +#define cuGraphInstantiate cuGraphInstantiateWithFlags + +#define cuGraphExecUpdate cuGraphExecUpdate_v2 +#define cuGetProcAddress cuGetProcAddress_v2 +#define cuGraphAddKernelNode cuGraphAddKernelNode_v2 +#define cuGraphKernelNodeGetParams cuGraphKernelNodeGetParams_v2 +#define cuGraphKernelNodeSetParams cuGraphKernelNodeSetParams_v2 +#define cuGraphExecKernelNodeSetParams cuGraphExecKernelNodeSetParams_v2 + +#define cuStreamWriteValue32 __CUDA_API_PTSZ(cuStreamWriteValue32_v2) +#define cuStreamWaitValue32 __CUDA_API_PTSZ(cuStreamWaitValue32_v2) +#define cuStreamWriteValue64 __CUDA_API_PTSZ(cuStreamWriteValue64_v2) +#define cuStreamWaitValue64 __CUDA_API_PTSZ(cuStreamWaitValue64_v2) +#define cuStreamBatchMemOp __CUDA_API_PTSZ(cuStreamBatchMemOp_v2) +#define cuStreamGetCaptureInfo __CUDA_API_PTSZ(cuStreamGetCaptureInfo_v2) +#define cuStreamGetCaptureInfo_v2 __CUDA_API_PTSZ(cuStreamGetCaptureInfo_v2) #if defined(__CUDA_API_PER_THREAD_DEFAULT_STREAM) #define cuMemcpy __CUDA_API_PTDS(cuMemcpy) @@ -163,13 +179,12 @@ typedef uint64_t cuuint64_t; #define cuMemsetD2D32Async __CUDA_API_PTSZ(cuMemsetD2D32Async) #define cuStreamGetPriority __CUDA_API_PTSZ(cuStreamGetPriority) + #define cuStreamGetId __CUDA_API_PTSZ(cuStreamGetId) #define cuStreamGetFlags __CUDA_API_PTSZ(cuStreamGetFlags) #define cuStreamGetCtx __CUDA_API_PTSZ(cuStreamGetCtx) #define cuStreamWaitEvent __CUDA_API_PTSZ(cuStreamWaitEvent) #define cuStreamEndCapture __CUDA_API_PTSZ(cuStreamEndCapture) #define cuStreamIsCapturing __CUDA_API_PTSZ(cuStreamIsCapturing) - #define cuStreamGetCaptureInfo __CUDA_API_PTSZ(cuStreamGetCaptureInfo) - #define cuStreamGetCaptureInfo_v2 __CUDA_API_PTSZ(cuStreamGetCaptureInfo_v2) #define cuStreamUpdateCaptureDependencies __CUDA_API_PTSZ(cuStreamUpdateCaptureDependencies) #define cuStreamAddCallback __CUDA_API_PTSZ(cuStreamAddCallback) #define cuStreamAttachMemAsync __CUDA_API_PTSZ(cuStreamAttachMemAsync) @@ -178,24 +193,17 @@ typedef uint64_t cuuint64_t; #define cuEventRecord __CUDA_API_PTSZ(cuEventRecord) #define cuEventRecordWithFlags __CUDA_API_PTSZ(cuEventRecordWithFlags) #define cuLaunchKernel __CUDA_API_PTSZ(cuLaunchKernel) - - - + #define cuLaunchKernelEx __CUDA_API_PTSZ(cuLaunchKernelEx) #define cuLaunchHostFunc __CUDA_API_PTSZ(cuLaunchHostFunc) #define cuGraphicsMapResources __CUDA_API_PTSZ(cuGraphicsMapResources) #define cuGraphicsUnmapResources __CUDA_API_PTSZ(cuGraphicsUnmapResources) - #define cuStreamWriteValue32 __CUDA_API_PTSZ(cuStreamWriteValue32) - #define cuStreamWaitValue32 __CUDA_API_PTSZ(cuStreamWaitValue32) - #define cuStreamWriteValue64 __CUDA_API_PTSZ(cuStreamWriteValue64) - #define cuStreamWaitValue64 __CUDA_API_PTSZ(cuStreamWaitValue64) - #define cuStreamBatchMemOp __CUDA_API_PTSZ(cuStreamBatchMemOp) - #define cuLaunchCooperativeKernel __CUDA_API_PTSZ(cuLaunchCooperativeKernel) #define cuSignalExternalSemaphoresAsync __CUDA_API_PTSZ(cuSignalExternalSemaphoresAsync) #define cuWaitExternalSemaphoresAsync __CUDA_API_PTSZ(cuWaitExternalSemaphoresAsync) + #define cuGraphInstantiateWithParams __CUDA_API_PTSZ(cuGraphInstantiateWithParams) #define cuGraphUpload __CUDA_API_PTSZ(cuGraphUpload) #define cuGraphLaunch __CUDA_API_PTSZ(cuGraphLaunch) #define cuStreamCopyAttributes __CUDA_API_PTSZ(cuStreamCopyAttributes) @@ -229,7 +237,7 @@ typedef uint64_t cuuint64_t; /** * CUDA API version number */ -#define CUDA_VERSION 11060 +#define CUDA_VERSION 12010 #ifdef __cplusplus extern "C" { @@ -251,6 +259,8 @@ typedef CUdevice_v1 CUdevice; /**< CUDA device */ typedef struct CUctx_st *CUcontext; /**< CUDA context */ typedef struct CUmod_st *CUmodule; /**< CUDA module */ typedef struct CUfunc_st *CUfunction; /**< CUDA function */ +typedef struct CUlib_st *CUlibrary; /**< CUDA library */ +typedef struct CUkern_st *CUkernel; /**< CUDA kernel */ typedef struct CUarray_st *CUarray; /**< CUDA array */ typedef struct CUmipmappedArray_st *CUmipmappedArray; /**< CUDA mipmapped array */ typedef struct CUtexref_st *CUtexref; /**< CUDA texture reference */ @@ -331,9 +341,43 @@ typedef enum CUctx_flags_enum { * and it no longer has any effect. All contexts * as of CUDA 3.2 behave as though the flag is enabled. */ CU_CTX_LMEM_RESIZE_TO_MAX = 0x10, /**< Keep local memory allocation after launch */ - CU_CTX_FLAGS_MASK = 0x1f + CU_CTX_COREDUMP_ENABLE = 0x20, /**< Trigger coredumps from exceptions in this context */ + CU_CTX_USER_COREDUMP_ENABLE= 0x40, /**< Enable user pipe to trigger coredumps in this context */ + CU_CTX_SYNC_MEMOPS = 0x80, /**< Force synchronous blocking on cudaMemcpy/cudaMemset */ + CU_CTX_FLAGS_MASK = 0xFF } CUctx_flags; +/** + * Event sched flags + */ +typedef enum CUevent_sched_flags_enum { + CU_EVENT_SCHED_AUTO = 0x00, /**< Automatic scheduling */ + CU_EVENT_SCHED_SPIN = 0x01, /**< Set spin as default scheduling */ + CU_EVENT_SCHED_YIELD = 0x02, /**< Set yield as default scheduling */ + CU_EVENT_SCHED_BLOCKING_SYNC = 0x04, /**< Set blocking synchronization as default scheduling */ +} CUevent_sched_flags; + +/** + * NVCL event scheduling flags + */ +typedef enum cl_event_flags_enum { + NVCL_EVENT_SCHED_AUTO = 0x00, /**< Automatic scheduling */ + NVCL_EVENT_SCHED_SPIN = 0x01, /**< Set spin as default scheduling */ + NVCL_EVENT_SCHED_YIELD = 0x02, /**< Set yield as default scheduling */ + NVCL_EVENT_SCHED_BLOCKING_SYNC = 0x04, /**< Set blocking synchronization as default scheduling */ +} cl_event_flags; + +/** + * NVCL context scheduling flags + */ +typedef enum cl_context_flags_enum { + NVCL_CTX_SCHED_AUTO = 0x00, /**< Automatic scheduling */ + NVCL_CTX_SCHED_SPIN = 0x01, /**< Set spin as default scheduling */ + NVCL_CTX_SCHED_YIELD = 0x02, /**< Set yield as default scheduling */ + NVCL_CTX_SCHED_BLOCKING_SYNC = 0x04, /**< Set blocking synchronization as default scheduling */ +} cl_context_flags; + + /** * Stream creation flags */ @@ -412,7 +456,7 @@ typedef enum CUstreamWaitValue_flags_enum { two remote writes arrive in a defined order, the wait is satisfied by the second write, and downstream work needs to observe the first write. Support for this operation is restricted to selected platforms and can be - queried with ::CU_DEVICE_ATTRIBUTE_CAN_USE_WAIT_VALUE_FLUSH.*/ + queried with ::CU_DEVICE_ATTRIBUTE_CAN_FLUSH_REMOTE_WRITES.*/ } CUstreamWaitValue_flags; /** @@ -425,7 +469,8 @@ typedef enum CUstreamWriteValue_flags_enum { ::cuStreamWriteValue32 will provide a memory fence before the write, which has similar semantics to __threadfence_system() but is scoped to the stream - rather than a CUDA thread. */ + rather than a CUDA thread. + This flag is not supported in the v2 API. */ } CUstreamWriteValue_flags; /** @@ -436,10 +481,19 @@ typedef enum CUstreamBatchMemOpType_enum { CU_STREAM_MEM_OP_WRITE_VALUE_32 = 2, /**< Represents a ::cuStreamWriteValue32 operation */ CU_STREAM_MEM_OP_WAIT_VALUE_64 = 4, /**< Represents a ::cuStreamWaitValue64 operation */ CU_STREAM_MEM_OP_WRITE_VALUE_64 = 5, /**< Represents a ::cuStreamWriteValue64 operation */ + CU_STREAM_MEM_OP_BARRIER = 6, /**< Insert a memory barrier of the specified type */ CU_STREAM_MEM_OP_FLUSH_REMOTE_WRITES = 3 /**< This has the same effect as ::CU_STREAM_WAIT_VALUE_FLUSH, but as a standalone operation. */ } CUstreamBatchMemOpType; +/** + * Flags for ::cuStreamMemoryBarrier + */ +typedef enum CUstreamMemoryBarrier_flags_enum { + CU_STREAM_MEMORY_BARRIER_TYPE_SYS = 0x0, /**< System-wide memory barrier. */ + CU_STREAM_MEMORY_BARRIER_TYPE_GPU = 0x1 /**< Limit memory barrier scope to the GPU. */ +} CUstreamMemoryBarrier_flags; + /** * Per-operation parameters for ::cuStreamBatchMemOp */ @@ -469,10 +523,21 @@ typedef union CUstreamBatchMemOpParams_union { CUstreamBatchMemOpType operation; unsigned int flags; } flushRemoteWrites; + struct CUstreamMemOpMemoryBarrierParams_st { /**< Only supported in the _v2 API */ + CUstreamBatchMemOpType operation; + unsigned int flags; + } memoryBarrier; cuuint64_t pad[6]; } CUstreamBatchMemOpParams_v1; typedef CUstreamBatchMemOpParams_v1 CUstreamBatchMemOpParams; +typedef struct CUDA_BATCH_MEM_OP_NODE_PARAMS_st { + CUcontext ctx; + unsigned int count; + CUstreamBatchMemOpParams *paramArray; + unsigned int flags; +} CUDA_BATCH_MEM_OP_NODE_PARAMS; + /** * Occupancy calculator flag */ @@ -648,9 +713,9 @@ typedef enum CUdevice_attribute_enum { CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS = 89, /**< Device can coherently access managed memory concurrently with the CPU */ CU_DEVICE_ATTRIBUTE_COMPUTE_PREEMPTION_SUPPORTED = 90, /**< Device supports compute preemption. */ CU_DEVICE_ATTRIBUTE_CAN_USE_HOST_POINTER_FOR_REGISTERED_MEM = 91, /**< Device can access host registered memory at the same virtual address as the CPU */ - CU_DEVICE_ATTRIBUTE_CAN_USE_STREAM_MEM_OPS = 92, /**< ::cuStreamBatchMemOp and related APIs are supported. */ - CU_DEVICE_ATTRIBUTE_CAN_USE_64_BIT_STREAM_MEM_OPS = 93, /**< 64-bit operations are supported in ::cuStreamBatchMemOp and related APIs. */ - CU_DEVICE_ATTRIBUTE_CAN_USE_STREAM_WAIT_VALUE_NOR = 94, /**< ::CU_STREAM_WAIT_VALUE_NOR is supported. */ + CU_DEVICE_ATTRIBUTE_CAN_USE_STREAM_MEM_OPS_V1 = 92, /**< Deprecated, along with v1 MemOps API, ::cuStreamBatchMemOp and related APIs are supported. */ + CU_DEVICE_ATTRIBUTE_CAN_USE_64_BIT_STREAM_MEM_OPS_V1 = 93, /**< Deprecated, along with v1 MemOps API, 64-bit operations are supported in ::cuStreamBatchMemOp and related APIs. */ + CU_DEVICE_ATTRIBUTE_CAN_USE_STREAM_WAIT_VALUE_NOR_V1 = 94, /**< Deprecated, along with v1 MemOps API, ::CU_STREAM_WAIT_VALUE_NOR is supported. */ CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH = 95, /**< Device supports launching cooperative kernels via ::cuLaunchCooperativeKernel */ CU_DEVICE_ATTRIBUTE_COOPERATIVE_MULTI_DEVICE_LAUNCH = 96, /**< Deprecated, ::cuLaunchCooperativeKernelMultiDevice is deprecated. */ CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN = 97, /**< Maximum optin shared memory per block */ @@ -677,12 +742,18 @@ typedef enum CUdevice_attribute_enum { CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_FLUSH_WRITES_OPTIONS = 117, /**< The returned attribute shall be interpreted as a bitmask, where the individual bits are described by the ::CUflushGPUDirectRDMAWritesOptions enum */ CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WRITES_ORDERING = 118, /**< GPUDirect RDMA writes to the device do not need to be flushed for consumers within the scope indicated by the returned attribute. See ::CUGPUDirectRDMAWritesOrdering for the numerical values returned here. */ CU_DEVICE_ATTRIBUTE_MEMPOOL_SUPPORTED_HANDLE_TYPES = 119, /**< Handle types supported with mempool based IPC */ - - - - + CU_DEVICE_ATTRIBUTE_CLUSTER_LAUNCH = 120, /**< Indicates device supports cluster launch */ CU_DEVICE_ATTRIBUTE_DEFERRED_MAPPING_CUDA_ARRAY_SUPPORTED = 121, /**< Device supports deferred mapping CUDA arrays and CUDA mipmapped arrays */ - + CU_DEVICE_ATTRIBUTE_CAN_USE_64_BIT_STREAM_MEM_OPS = 122, /**< 64-bit operations are supported in ::cuStreamBatchMemOp and related MemOp APIs. */ + CU_DEVICE_ATTRIBUTE_CAN_USE_STREAM_WAIT_VALUE_NOR = 123, /**< ::CU_STREAM_WAIT_VALUE_NOR is supported by MemOp APIs. */ + CU_DEVICE_ATTRIBUTE_DMA_BUF_SUPPORTED = 124, /**< Device supports buffer sharing with dma_buf mechanism. */ + CU_DEVICE_ATTRIBUTE_IPC_EVENT_SUPPORTED = 125, /**< Device supports IPC Events. */ + CU_DEVICE_ATTRIBUTE_MEM_SYNC_DOMAIN_COUNT = 126, /**< Number of memory domains the device supports. */ + CU_DEVICE_ATTRIBUTE_TENSOR_MAP_ACCESS_SUPPORTED = 127, /**< Device supports accessing memory using Tensor Map. */ + CU_DEVICE_ATTRIBUTE_UNIFIED_FUNCTION_POINTERS = 129, /**< Device supports unified function pointers. */ + CU_DEVICE_ATTRIBUTE_IS_NUMA_NODE = 130, + CU_DEVICE_ATTRIBUTE_NUMA_ID = 131, + CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED = 132, /**< Device supports switch multicast and reduction operations. */ CU_DEVICE_ATTRIBUTE_MAX } CUdevice_attribute; @@ -724,6 +795,10 @@ typedef enum CUpointer_attribute_enum { CU_POINTER_ATTRIBUTE_IS_GPU_DIRECT_RDMA_CAPABLE = 15, /**< 1 if the memory this pointer is referencing can be used with the GPUDirect RDMA API **/ CU_POINTER_ATTRIBUTE_ACCESS_FLAGS = 16, /**< Returns the access flags the device associated with the current context has on the corresponding memory referenced by the pointer given */ CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE = 17 /**< Returns the mempool handle for the allocation if it was allocated from a mempool. Otherwise returns NULL. **/ + , + CU_POINTER_ATTRIBUTE_MAPPING_SIZE = 18, /**< Size of the actual underlying mapping that the pointer belongs to **/ + CU_POINTER_ATTRIBUTE_MAPPING_BASE_ADDR = 19, /**< The start address of the mapping that the pointer belongs to **/ + CU_POINTER_ATTRIBUTE_MEMORY_BLOCK_ID = 20 /**< A process-wide unique id corresponding to the physical allocation the pointer belongs to **/ } CUpointer_attribute; /** @@ -788,7 +863,7 @@ typedef enum CUfunction_attribute_enum { * The maximum size in bytes of dynamically-allocated shared memory that can be used by * this function. If the user-specified dynamic shared memory size is larger than this * value, the launch will fail. - * See ::cuFuncSetAttribute + * See ::cuFuncSetAttribute, ::cuKernelSetAttribute */ CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES = 8, @@ -797,80 +872,78 @@ typedef enum CUfunction_attribute_enum { * this sets the shared memory carveout preference, in percent of the total shared memory. * Refer to ::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR. * This is only a hint, and the driver can choose a different ratio if required to execute the function. - * See ::cuFuncSetAttribute + * See ::cuFuncSetAttribute, ::cuKernelSetAttribute */ CU_FUNC_ATTRIBUTE_PREFERRED_SHARED_MEMORY_CARVEOUT = 9, + /** + * If this attribute is set, the kernel must launch with a valid cluster + * size specified. + * See ::cuFuncSetAttribute, ::cuKernelSetAttribute + */ + CU_FUNC_ATTRIBUTE_CLUSTER_SIZE_MUST_BE_SET = 10, + /** + * The required cluster width in blocks. The values must either all be 0 or + * all be positive. The validity of the cluster dimensions is otherwise + * checked at launch time. + * + * If the value is set during compile time, it cannot be set at runtime. + * Setting it at runtime will return CUDA_ERROR_NOT_PERMITTED. + * See ::cuFuncSetAttribute, ::cuKernelSetAttribute + */ + CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_WIDTH = 11, + /** + * The required cluster height in blocks. The values must either all be 0 or + * all be positive. The validity of the cluster dimensions is otherwise + * checked at launch time. + * + * If the value is set during compile time, it cannot be set at runtime. + * Setting it at runtime should return CUDA_ERROR_NOT_PERMITTED. + * See ::cuFuncSetAttribute, ::cuKernelSetAttribute + */ + CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_HEIGHT = 12, + /** + * The required cluster depth in blocks. The values must either all be 0 or + * all be positive. The validity of the cluster dimensions is otherwise + * checked at launch time. + * + * If the value is set during compile time, it cannot be set at runtime. + * Setting it at runtime should return CUDA_ERROR_NOT_PERMITTED. + * See ::cuFuncSetAttribute, ::cuKernelSetAttribute + */ + CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_DEPTH = 13, + /** + * Whether the function can be launched with non-portable cluster size. 1 is + * allowed, 0 is disallowed. A non-portable cluster size may only function + * on the specific SKUs the program is tested on. The launch might fail if + * the program is run on a different hardware platform. + * + * CUDA API provides cudaOccupancyMaxActiveClusters to assist with checking + * whether the desired size can be launched on the current device. + * + * Portable Cluster Size + * + * A portable cluster size is guaranteed to be functional on all compute + * capabilities higher than the target compute capability. The portable + * cluster size for sm_90 is 8 blocks per cluster. This value may increase + * for future compute capabilities. + * + * The specific hardware unit may support higher cluster sizes that’s not + * guaranteed to be portable. + * See ::cuFuncSetAttribute, ::cuKernelSetAttribute + */ + CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED = 14, - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + /** + * The block scheduling policy of a function. The value type is + * CUclusterSchedulingPolicy / cudaClusterSchedulingPolicy. + * See ::cuFuncSetAttribute, ::cuKernelSetAttribute + */ + CU_FUNC_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE = 15, CU_FUNC_ATTRIBUTE_MAX } CUfunction_attribute; @@ -895,7 +968,7 @@ typedef enum CUsharedconfig_enum { } CUsharedconfig; /** - * Shared memory carveout configurations. These may be passed to ::cuFuncSetAttribute + * Shared memory carveout configurations. These may be passed to ::cuFuncSetAttribute or ::cuKernelSetAttribute */ typedef enum CUshared_carveout_enum { CU_SHAREDMEM_CARVEOUT_DEFAULT = -1, /**< No preference for shared memory or L1 (default) */ @@ -926,7 +999,7 @@ typedef enum CUcomputemode_enum { * Memory advise values */ typedef enum CUmem_advise_enum { - CU_MEM_ADVISE_SET_READ_MOSTLY = 1, /**< Data will mostly be read and only occassionally be written to */ + CU_MEM_ADVISE_SET_READ_MOSTLY = 1, /**< Data will mostly be read and only occasionally be written to */ CU_MEM_ADVISE_UNSET_READ_MOSTLY = 2, /**< Undo the effect of ::CU_MEM_ADVISE_SET_READ_MOSTLY */ CU_MEM_ADVISE_SET_PREFERRED_LOCATION = 3, /**< Set the preferred location for the data as the specified device */ CU_MEM_ADVISE_UNSET_PREFERRED_LOCATION = 4, /**< Clear the preferred location for the data */ @@ -935,7 +1008,7 @@ typedef enum CUmem_advise_enum { } CUmem_advise; typedef enum CUmem_range_attribute_enum { - CU_MEM_RANGE_ATTRIBUTE_READ_MOSTLY = 1, /**< Whether the range will mostly be read and only occassionally be written to */ + CU_MEM_RANGE_ATTRIBUTE_READ_MOSTLY = 1, /**< Whether the range will mostly be read and only occasionally be written to */ CU_MEM_RANGE_ATTRIBUTE_PREFERRED_LOCATION = 2, /**< The preferred location of the range */ CU_MEM_RANGE_ATTRIBUTE_ACCESSED_BY = 3, /**< Memory range has ::CU_MEM_ADVISE_SET_ACCESSED_BY set for specified device */ CU_MEM_RANGE_ATTRIBUTE_LAST_PREFETCH_LOCATION = 4 /**< The last location to which the range was prefetched */ @@ -957,7 +1030,7 @@ typedef enum CUjit_option_enum * IN: Specifies minimum number of threads per block to target compilation * for\n * OUT: Returns the number of threads the compiler actually targeted. - * This restricts the resource utilization fo the compiler (e.g. max + * This restricts the resource utilization of the compiler (e.g. max * registers) such that a block with the given number of threads should be * able to launch based on register limitations. Note, this option does not * currently take into account any other resource limitations, such as @@ -966,7 +1039,7 @@ typedef enum CUjit_option_enum * Option type: unsigned int\n * Applies to: compiler only */ - CU_JIT_THREADS_PER_BLOCK, + CU_JIT_THREADS_PER_BLOCK = 1, /** * Overwrites the option value with the total wall clock time, in @@ -974,7 +1047,7 @@ typedef enum CUjit_option_enum * Option type: float\n * Applies to: compiler and linker */ - CU_JIT_WALL_TIME, + CU_JIT_WALL_TIME = 2, /** * Pointer to a buffer in which to print any log messages @@ -983,7 +1056,7 @@ typedef enum CUjit_option_enum * Option type: char *\n * Applies to: compiler and linker */ - CU_JIT_INFO_LOG_BUFFER, + CU_JIT_INFO_LOG_BUFFER = 3, /** * IN: Log buffer size in bytes. Log messages will be capped at this size @@ -992,7 +1065,7 @@ typedef enum CUjit_option_enum * Option type: unsigned int\n * Applies to: compiler and linker */ - CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, + CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES = 4, /** * Pointer to a buffer in which to print any log messages that @@ -1001,7 +1074,7 @@ typedef enum CUjit_option_enum * Option type: char *\n * Applies to: compiler and linker */ - CU_JIT_ERROR_LOG_BUFFER, + CU_JIT_ERROR_LOG_BUFFER = 5, /** * IN: Log buffer size in bytes. Log messages will be capped at this size @@ -1010,7 +1083,7 @@ typedef enum CUjit_option_enum * Option type: unsigned int\n * Applies to: compiler and linker */ - CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, + CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES = 6, /** * Level of optimizations to apply to generated code (0 - 4), with 4 @@ -1018,7 +1091,7 @@ typedef enum CUjit_option_enum * Option type: unsigned int\n * Applies to: compiler only */ - CU_JIT_OPTIMIZATION_LEVEL, + CU_JIT_OPTIMIZATION_LEVEL = 7, /** * No option value required. Determines the target based on the current @@ -1026,7 +1099,7 @@ typedef enum CUjit_option_enum * Option type: No option value needed\n * Applies to: compiler and linker */ - CU_JIT_TARGET_FROM_CUCONTEXT, + CU_JIT_TARGET_FROM_CUCONTEXT = 8, /** * Target is chosen based on supplied ::CUjit_target. Cannot be @@ -1034,7 +1107,7 @@ typedef enum CUjit_option_enum * Option type: unsigned int for enumerated type ::CUjit_target\n * Applies to: compiler and linker */ - CU_JIT_TARGET, + CU_JIT_TARGET = 9, /** * Specifies choice of fallback strategy if matching cubin is not found. @@ -1043,7 +1116,7 @@ typedef enum CUjit_option_enum * Option type: unsigned int for enumerated type ::CUjit_fallback\n * Applies to: compiler only */ - CU_JIT_FALLBACK_STRATEGY, + CU_JIT_FALLBACK_STRATEGY = 10, /** * Specifies whether to create debug information in output (-g) @@ -1051,21 +1124,21 @@ typedef enum CUjit_option_enum * Option type: int\n * Applies to: compiler and linker */ - CU_JIT_GENERATE_DEBUG_INFO, + CU_JIT_GENERATE_DEBUG_INFO = 11, /** * Generate verbose log messages (0: false, default)\n * Option type: int\n * Applies to: compiler and linker */ - CU_JIT_LOG_VERBOSE, + CU_JIT_LOG_VERBOSE = 12, /** * Generate line number information (-lineinfo) (0: false, default)\n * Option type: int\n * Applies to: compiler only */ - CU_JIT_GENERATE_LINE_INFO, + CU_JIT_GENERATE_LINE_INFO = 13, /** * Specifies whether to enable caching explicitly (-dlcm) \n @@ -1073,19 +1146,24 @@ typedef enum CUjit_option_enum * Option type: unsigned int for enumerated type ::CUjit_cacheMode_enum\n * Applies to: compiler only */ - CU_JIT_CACHE_MODE, + CU_JIT_CACHE_MODE = 14, + + /** + * \deprecated + * This jit option is deprecated and should not be used. + */ + CU_JIT_NEW_SM3X_OPT = 15, /** - * The below jit options are used for internal purposes only, in this version of CUDA + * This jit option is used for internal purpose only. */ - CU_JIT_NEW_SM3X_OPT, - CU_JIT_FAST_COMPILE, + CU_JIT_FAST_COMPILE = 16, /** - * Array of device symbol names that will be relocated to the corresponing + * Array of device symbol names that will be relocated to the corresponding * host addresses stored in ::CU_JIT_GLOBAL_SYMBOL_ADDRESSES.\n * Must contain ::CU_JIT_GLOBAL_SYMBOL_COUNT entries.\n - * When loding a device module, driver will relocate all encountered + * When loading a device module, driver will relocate all encountered * unresolved symbols to the host addresses.\n * It is only allowed to register symbols that correspond to unresolved * global variables.\n @@ -1093,7 +1171,7 @@ typedef enum CUjit_option_enum * Option type: const char **\n * Applies to: dynamic linker only */ - CU_JIT_GLOBAL_SYMBOL_NAMES, + CU_JIT_GLOBAL_SYMBOL_NAMES = 17, /** * Array of host addresses that will be used to relocate corresponding @@ -1102,7 +1180,7 @@ typedef enum CUjit_option_enum * Option type: void **\n * Applies to: dynamic linker only */ - CU_JIT_GLOBAL_SYMBOL_ADDRESSES, + CU_JIT_GLOBAL_SYMBOL_ADDRESSES = 18, /** * Number of entries in ::CU_JIT_GLOBAL_SYMBOL_NAMES and @@ -1110,92 +1188,186 @@ typedef enum CUjit_option_enum * Option type: unsigned int\n * Applies to: dynamic linker only */ - CU_JIT_GLOBAL_SYMBOL_COUNT, + CU_JIT_GLOBAL_SYMBOL_COUNT = 19, /** - * Enable link-time optimization (-dlto) for device code (0: false, default).\n + * \deprecated + * Enable link-time optimization (-dlto) for device code (Disabled by default).\n * This option is not supported on 32-bit platforms.\n * Option type: int\n * Applies to: compiler and linker + * + * Only valid with LTO-IR compiled with toolkits prior to CUDA 12.0 */ - CU_JIT_LTO, + CU_JIT_LTO = 20, /** + * \deprecated * Control single-precision denormals (-ftz) support (0: false, default). * 1 : flushes denormal values to zero * 0 : preserves denormal values * Option type: int\n * Applies to: link-time optimization specified with CU_JIT_LTO + * + * Only valid with LTO-IR compiled with toolkits prior to CUDA 12.0 */ - CU_JIT_FTZ, + CU_JIT_FTZ = 21, /** + * \deprecated * Control single-precision floating-point division and reciprocals * (-prec-div) support (1: true, default). * 1 : Enables the IEEE round-to-nearest mode * 0 : Enables the fast approximation mode * Option type: int\n * Applies to: link-time optimization specified with CU_JIT_LTO + * + * Only valid with LTO-IR compiled with toolkits prior to CUDA 12.0 */ - CU_JIT_PREC_DIV, + CU_JIT_PREC_DIV = 22, /** + * \deprecated * Control single-precision floating-point square root * (-prec-sqrt) support (1: true, default). * 1 : Enables the IEEE round-to-nearest mode * 0 : Enables the fast approximation mode * Option type: int\n * Applies to: link-time optimization specified with CU_JIT_LTO + * + * Only valid with LTO-IR compiled with toolkits prior to CUDA 12.0 */ - CU_JIT_PREC_SQRT, + CU_JIT_PREC_SQRT = 23, /** + * \deprecated * Enable/Disable the contraction of floating-point multiplies * and adds/subtracts into floating-point multiply-add (-fma) * operations (1: Enable, default; 0: Disable). * Option type: int\n * Applies to: link-time optimization specified with CU_JIT_LTO + * + * Only valid with LTO-IR compiled with toolkits prior to CUDA 12.0 + */ + CU_JIT_FMA = 24, + + /** + * \deprecated + * Array of kernel names that should be preserved at link time while others + * can be removed.\n + * Must contain ::CU_JIT_REFERENCED_KERNEL_COUNT entries.\n + * Note that kernel names can be mangled by the compiler in which case the + * mangled name needs to be specified.\n + * Wildcard "*" can be used to represent zero or more characters instead of + * specifying the full or mangled name.\n + * It is important to note that the wildcard "*" is also added implicitly. + * For example, specifying "foo" will match "foobaz", "barfoo", "barfoobaz" and + * thus preserve all kernels with those names. This can be avoided by providing + * a more specific name like "barfoobaz".\n + * Option type: const char **\n + * Applies to: dynamic linker only + * + * Only valid with LTO-IR compiled with toolkits prior to CUDA 12.0 + */ + CU_JIT_REFERENCED_KERNEL_NAMES = 25, + + /** + * \deprecated + * Number of entries in ::CU_JIT_REFERENCED_KERNEL_NAMES array.\n + * Option type: unsigned int\n + * Applies to: dynamic linker only + * + * Only valid with LTO-IR compiled with toolkits prior to CUDA 12.0 + */ + CU_JIT_REFERENCED_KERNEL_COUNT = 26, + + /** + * \deprecated + * Array of variable names (__device__ and/or __constant__) that should be + * preserved at link time while others can be removed.\n + * Must contain ::CU_JIT_REFERENCED_VARIABLE_COUNT entries.\n + * Note that variable names can be mangled by the compiler in which case the + * mangled name needs to be specified.\n + * Wildcard "*" can be used to represent zero or more characters instead of + * specifying the full or mangled name.\n + * It is important to note that the wildcard "*" is also added implicitly. + * For example, specifying "foo" will match "foobaz", "barfoo", "barfoobaz" and + * thus preserve all variables with those names. This can be avoided by providing + * a more specific name like "barfoobaz".\n + * Option type: const char **\n + * Applies to: link-time optimization specified with CU_JIT_LTO + * + * Only valid with LTO-IR compiled with toolkits prior to CUDA 12.0 + */ + CU_JIT_REFERENCED_VARIABLE_NAMES = 27, + + /** + * \deprecated + * Number of entries in ::CU_JIT_REFERENCED_VARIABLE_NAMES array.\n + * Option type: unsigned int\n + * Applies to: link-time optimization specified with CU_JIT_LTO + * + * Only valid with LTO-IR compiled with toolkits prior to CUDA 12.0 + */ + CU_JIT_REFERENCED_VARIABLE_COUNT = 28, + + /** + * \deprecated + * This option serves as a hint to enable the JIT compiler/linker + * to remove constant (__constant__) and device (__device__) variables + * unreferenced in device code (Disabled by default).\n + * Note that host references to constant and device variables using APIs like + * ::cuModuleGetGlobal() with this option specified may result in undefined behavior unless + * the variables are explicitly specified using ::CU_JIT_REFERENCED_VARIABLE_NAMES.\n + * Option type: int\n + * Applies to: link-time optimization specified with CU_JIT_LTO + * + * Only valid with LTO-IR compiled with toolkits prior to CUDA 12.0 + */ + CU_JIT_OPTIMIZE_UNUSED_DEVICE_VARIABLES = 29, + + /** + * Generate position independent code (0: false)\n + * Option type: int\n + * Applies to: compiler only */ - CU_JIT_FMA, + CU_JIT_POSITION_INDEPENDENT_CODE = 30, CU_JIT_NUM_OPTIONS } CUjit_option; +/* + * Indicates that compute device class supports accelerated features. + */ +#define CU_COMPUTE_ACCELERATED_TARGET_BASE 0x10000 + /** * Online compilation targets */ typedef enum CUjit_target_enum { - - CU_TARGET_COMPUTE_20 = 20, /**< Compute device class 2.0 */ - CU_TARGET_COMPUTE_21 = 21, /**< Compute device class 2.1 */ - - CU_TARGET_COMPUTE_30 = 30, /**< Compute device class 3.0 */ CU_TARGET_COMPUTE_32 = 32, /**< Compute device class 3.2 */ CU_TARGET_COMPUTE_35 = 35, /**< Compute device class 3.5 */ CU_TARGET_COMPUTE_37 = 37, /**< Compute device class 3.7 */ - - CU_TARGET_COMPUTE_50 = 50, /**< Compute device class 5.0 */ CU_TARGET_COMPUTE_52 = 52, /**< Compute device class 5.2 */ CU_TARGET_COMPUTE_53 = 53, /**< Compute device class 5.3 */ - - CU_TARGET_COMPUTE_60 = 60, /**< Compute device class 6.0.*/ CU_TARGET_COMPUTE_61 = 61, /**< Compute device class 6.1.*/ CU_TARGET_COMPUTE_62 = 62, /**< Compute device class 6.2.*/ - - CU_TARGET_COMPUTE_70 = 70, /**< Compute device class 7.0.*/ CU_TARGET_COMPUTE_72 = 72, /**< Compute device class 7.2.*/ - CU_TARGET_COMPUTE_75 = 75, /**< Compute device class 7.5.*/ - CU_TARGET_COMPUTE_80 = 80, /**< Compute device class 8.0.*/ - CU_TARGET_COMPUTE_86 = 86 /**< Compute device class 8.6.*/ + CU_TARGET_COMPUTE_86 = 86, /**< Compute device class 8.6.*/ + CU_TARGET_COMPUTE_87 = 87, /**< Compute device class 8.7.*/ + CU_TARGET_COMPUTE_89 = 89, /**< Compute device class 8.9.*/ + CU_TARGET_COMPUTE_90 = 90, /**< Compute device class 9.0.*/ + /**< Compute device class 9.0. with accelerated features.*/ + CU_TARGET_COMPUTE_90A = CU_COMPUTE_ACCELERATED_TARGET_BASE + CU_TARGET_COMPUTE_90 } CUjit_target; /** @@ -1234,33 +1406,36 @@ typedef enum CUjitInputType_enum * PTX source code\n * Applicable options: PTX compiler options */ - CU_JIT_INPUT_PTX, + CU_JIT_INPUT_PTX = 1, /** * Bundle of multiple cubins and/or PTX of some device code\n * Applicable options: PTX compiler options, ::CU_JIT_FALLBACK_STRATEGY */ - CU_JIT_INPUT_FATBINARY, + CU_JIT_INPUT_FATBINARY = 2, /** * Host object with embedded device code\n * Applicable options: PTX compiler options, ::CU_JIT_FALLBACK_STRATEGY */ - CU_JIT_INPUT_OBJECT, + CU_JIT_INPUT_OBJECT = 3, /** * Archive of host objects with embedded device code\n * Applicable options: PTX compiler options, ::CU_JIT_FALLBACK_STRATEGY */ - CU_JIT_INPUT_LIBRARY, + CU_JIT_INPUT_LIBRARY = 4, /** + * \deprecated * High-level intermediate code for link-time optimization\n * Applicable options: NVVM compiler options, PTX compiler options + * + * Only valid with LTO-IR compiled with toolkits prior to CUDA 12.0 */ - CU_JIT_INPUT_NVVM, + CU_JIT_INPUT_NVVM = 5, - CU_JIT_NUM_INPUT_TYPES + CU_JIT_NUM_INPUT_TYPES = 6 } CUjitInputType; typedef struct CUlinkState_st *CUlinkState; @@ -1315,7 +1490,7 @@ typedef enum CUlimit_enum { * Resource types */ typedef enum CUresourcetype_enum { - CU_RESOURCE_TYPE_ARRAY = 0x00, /**< Array resoure */ + CU_RESOURCE_TYPE_ARRAY = 0x00, /**< Array resource */ CU_RESOURCE_TYPE_MIPMAPPED_ARRAY = 0x01, /**< Mipmapped array resource */ CU_RESOURCE_TYPE_LINEAR = 0x02, /**< Linear resource */ CU_RESOURCE_TYPE_PITCH2D = 0x03 /**< Pitch 2D resource */ @@ -1361,6 +1536,9 @@ typedef struct CUaccessPolicyWindow_st { CUaccessProperty hitProp; /**< ::CUaccessProperty set for hit. */ CUaccessProperty missProp; /**< ::CUaccessProperty set for miss. Must be either NORMAL or STREAMING */ } CUaccessPolicyWindow_v1; +/** + * Access policy window + */ typedef CUaccessPolicyWindow_v1 CUaccessPolicyWindow; /** @@ -1378,7 +1556,24 @@ typedef struct CUDA_KERNEL_NODE_PARAMS_st { void **kernelParams; /**< Array of pointers to kernel parameters */ void **extra; /**< Extra options */ } CUDA_KERNEL_NODE_PARAMS_v1; -typedef CUDA_KERNEL_NODE_PARAMS_v1 CUDA_KERNEL_NODE_PARAMS; + +/** + * GPU kernel node parameters + */typedef struct CUDA_KERNEL_NODE_PARAMS_v2_st { + CUfunction func; /**< Kernel to launch */ + unsigned int gridDimX; /**< Width of grid in blocks */ + unsigned int gridDimY; /**< Height of grid in blocks */ + unsigned int gridDimZ; /**< Depth of grid in blocks */ + unsigned int blockDimX; /**< X dimension of each thread block */ + unsigned int blockDimY; /**< Y dimension of each thread block */ + unsigned int blockDimZ; /**< Z dimension of each thread block */ + unsigned int sharedMemBytes; /**< Dynamic shared-memory size per thread block in bytes */ + void **kernelParams; /**< Array of pointers to kernel parameters */ + void **extra; /**< Extra options */ + CUkernel kern; /**< Kernel to launch, will only be referenced if func is NULL */ + CUcontext ctx; /**< Context for the kernel task to run in. The value NULL will indicate the current context should be used by the api. This field is ignored if func is set. */ +} CUDA_KERNEL_NODE_PARAMS_v2; +typedef CUDA_KERNEL_NODE_PARAMS_v2 CUDA_KERNEL_NODE_PARAMS; /** * Memset node parameters @@ -1417,9 +1612,33 @@ typedef enum CUgraphNodeType_enum { CU_GRAPH_NODE_TYPE_EXT_SEMAS_SIGNAL = 8, /**< External semaphore signal node */ CU_GRAPH_NODE_TYPE_EXT_SEMAS_WAIT = 9, /**< External semaphore wait node */ CU_GRAPH_NODE_TYPE_MEM_ALLOC = 10,/**< Memory Allocation Node */ - CU_GRAPH_NODE_TYPE_MEM_FREE = 11 /**< Memory Free Node */ + CU_GRAPH_NODE_TYPE_MEM_FREE = 11,/**< Memory Free Node */ + CU_GRAPH_NODE_TYPE_BATCH_MEM_OP = 12 /**< Batch MemOp Node */ } CUgraphNodeType; +/** + * Graph instantiation results +*/ +typedef enum CUgraphInstantiateResult_enum +{ + CUDA_GRAPH_INSTANTIATE_SUCCESS = 0, /**< Instantiation succeeded */ + CUDA_GRAPH_INSTANTIATE_ERROR = 1, /**< Instantiation failed for an unexpected reason which is described in the return value of the function */ + CUDA_GRAPH_INSTANTIATE_INVALID_STRUCTURE = 2, /**< Instantiation failed due to invalid structure, such as cycles */ + CUDA_GRAPH_INSTANTIATE_NODE_OPERATION_NOT_SUPPORTED = 3, /**< Instantiation for device launch failed because the graph contained an unsupported operation */ + CUDA_GRAPH_INSTANTIATE_MULTIPLE_CTXS_NOT_SUPPORTED = 4 /**< Instantiation for device launch failed due to the nodes belonging to different contexts */ +} CUgraphInstantiateResult; + +/** + * Graph instantiation parameters + */ +typedef struct CUDA_GRAPH_INSTANTIATE_PARAMS_st +{ + cuuint64_t flags; /**< Instantiation flags */ + CUstream hUploadStream; /**< Upload stream */ + CUgraphNode hErrNode_out; /**< The node which caused instantiation to fail, if any */ + CUgraphInstantiateResult result_out; /**< Whether instantiation was successful. If it failed, the reason why */ +} CUDA_GRAPH_INSTANTIATE_PARAMS; + typedef enum CUsynchronizationPolicy_enum { CU_SYNC_POLICY_AUTO = 1, CU_SYNC_POLICY_SPIN = 2, @@ -1428,20 +1647,124 @@ typedef enum CUsynchronizationPolicy_enum { } CUsynchronizationPolicy; /** - * Graph kernel node Attributes - */ -typedef enum CUkernelNodeAttrID_enum { - CU_KERNEL_NODE_ATTRIBUTE_ACCESS_POLICY_WINDOW = 1, /**< Identifier for ::CUkernelNodeAttrValue::accessPolicyWindow. */ - CU_KERNEL_NODE_ATTRIBUTE_COOPERATIVE = 2 /**< Allows a kernel node to be cooperative (see ::cuLaunchCooperativeKernel). */ -} CUkernelNodeAttrID; + * Cluster scheduling policies. These may be passed to ::cuFuncSetAttribute or ::cuKernelSetAttribute + */ +typedef enum CUclusterSchedulingPolicy_enum { + CU_CLUSTER_SCHEDULING_POLICY_DEFAULT = 0, /**< the default policy */ + CU_CLUSTER_SCHEDULING_POLICY_SPREAD = 1, /**< spread the blocks within a cluster to the SMs */ + CU_CLUSTER_SCHEDULING_POLICY_LOAD_BALANCING = 2 /**< allow the hardware to load-balance the blocks in a cluster to the SMs */ +} CUclusterSchedulingPolicy; + +typedef enum CUlaunchMemSyncDomain_enum { + CU_LAUNCH_MEM_SYNC_DOMAIN_DEFAULT = 0, + CU_LAUNCH_MEM_SYNC_DOMAIN_REMOTE = 1 +} CUlaunchMemSyncDomain; + +typedef struct CUlaunchMemSyncDomainMap_st { + unsigned char default_; + unsigned char remote; +} CUlaunchMemSyncDomainMap; + +typedef enum CUlaunchAttributeID_enum { + CU_LAUNCH_ATTRIBUTE_IGNORE = 0 /**< Ignored entry, for convenient composition */ + , CU_LAUNCH_ATTRIBUTE_ACCESS_POLICY_WINDOW = 1 /**< Valid for streams, graph nodes, launches. */ + , CU_LAUNCH_ATTRIBUTE_COOPERATIVE = 2 /**< Valid for graph nodes, launches. */ + , CU_LAUNCH_ATTRIBUTE_SYNCHRONIZATION_POLICY = 3 /**< Valid for streams. */ + , CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION = 4 /**< Valid for graph nodes, launches. */ + , CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE = 5 /**< Valid for graph nodes, launches. */ + , CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION = 6 /**< Valid for launches. Setting + programmaticStreamSerializationAllowed to non-0 + signals that the kernel will use programmatic + means to resolve its stream dependency, so that + the CUDA runtime should opportunistically allow + the grid's execution to overlap with the previous + kernel in the stream, if that kernel requests the + overlap. The dependent launches can choose to wait + on the dependency using the programmatic sync + (cudaGridDependencySynchronize() or equivalent PTX + instructions). */ + , CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_EVENT = 7 /**< Valid for launches. Event recorded through this + launch attribute is guaranteed to only trigger + after all block in the associated kernel trigger + the event. A block can trigger the event through + PTX launchdep.release or CUDA builtin function + cudaTriggerProgrammaticLaunchCompletion(). A + trigger can also be inserted at the beginning of + each block's execution if triggerAtBlockStart is + set to non-0. The dependent launches can choose to + wait on the dependency using the programmatic sync + (cudaGridDependencySynchronize() or equivalent PTX + instructions). Note that dependents (including the + CPU thread calling cuEventSynchronize()) are not + guaranteed to observe the release precisely when + it is released. For example, cuEventSynchronize() + may only observe the event trigger long after the + associated kernel has completed. This recording + type is primarily meant for establishing + programmatic dependency between device tasks. The + event supplied must not be an interprocess or + interop event. The event must disable timing (i.e. + created with ::CU_EVENT_DISABLE_TIMING flag set). + */ + , CU_LAUNCH_ATTRIBUTE_PRIORITY = 8 /**< Valid for streams, graph nodes, launches. */ + , CU_LAUNCH_ATTRIBUTE_MEM_SYNC_DOMAIN_MAP = 9 + , CU_LAUNCH_ATTRIBUTE_MEM_SYNC_DOMAIN = 10 +#ifdef __CUDA_API_VERSION_INTERNAL + , CU_LAUNCH_ATTRIBUTE_MAX +#endif +} CUlaunchAttributeID; -/** - * Graph kernel node attributes union, used with ::cuKernelNodeSetAttribute/::cuKernelNodeGetAttribute - */ -typedef union CUkernelNodeAttrValue_union { - CUaccessPolicyWindow accessPolicyWindow; /**< Attribute ::CUaccessPolicyWindow. */ - int cooperative; /**< Nonzero indicates a cooperative kernel (see ::cuLaunchCooperativeKernel). */ -} CUkernelNodeAttrValue_v1; +typedef union CUlaunchAttributeValue_union { + char pad[64]; /**< Pad to 64 bytes */ + CUaccessPolicyWindow accessPolicyWindow; /**< Attribute ::CUaccessPolicyWindow. */ + int cooperative; /**< Nonzero indicates a cooperative kernel (see ::cuLaunchCooperativeKernel). */ + CUsynchronizationPolicy syncPolicy; /**< ::CUsynchronizationPolicy for work queued up in this stream */ + struct { + unsigned int x; + unsigned int y; + unsigned int z; + } clusterDim; /**< Cluster dimensions for the kernel node. */ + CUclusterSchedulingPolicy clusterSchedulingPolicyPreference; /**< Cluster scheduling policy preference for the kernel node. */ + int programmaticStreamSerializationAllowed; + struct { + CUevent event; + int flags; /* Does not accept ::CU_EVENT_RECORD_EXTERNAL */ + int triggerAtBlockStart; + } programmaticEvent; + int priority; /**< Execution priority of the kernel. */ + CUlaunchMemSyncDomainMap memSyncDomainMap; + CUlaunchMemSyncDomain memSyncDomain; +} CUlaunchAttributeValue; + +typedef struct CUlaunchAttribute_st { + CUlaunchAttributeID id; + char pad[8 - sizeof(CUlaunchAttributeID)]; + CUlaunchAttributeValue value; +} CUlaunchAttribute; + +typedef struct CUlaunchConfig_st { + unsigned int gridDimX; /**< Width of grid in blocks */ + unsigned int gridDimY; /**< Height of grid in blocks */ + unsigned int gridDimZ; /**< Depth of grid in blocks */ + unsigned int blockDimX; /**< X dimension of each thread block */ + unsigned int blockDimY; /**< Y dimension of each thread block */ + unsigned int blockDimZ; /**< Z dimension of each thread block */ + unsigned int sharedMemBytes; /**< Dynamic shared-memory size per thread block in bytes */ + CUstream hStream; /**< Stream identifier */ + CUlaunchAttribute *attrs; /**< nullable if numAttrs == 0 */ + unsigned int numAttrs; /**< number of attributes populated in attrs */ +} CUlaunchConfig; + +typedef CUlaunchAttributeID CUkernelNodeAttrID; +#define CU_KERNEL_NODE_ATTRIBUTE_ACCESS_POLICY_WINDOW CU_LAUNCH_ATTRIBUTE_ACCESS_POLICY_WINDOW +#define CU_KERNEL_NODE_ATTRIBUTE_COOPERATIVE CU_LAUNCH_ATTRIBUTE_COOPERATIVE +#define CU_KERNEL_NODE_ATTRIBUTE_CLUSTER_DIMENSION CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION +#define CU_KERNEL_NODE_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE +#define CU_KERNEL_NODE_ATTRIBUTE_PRIORITY CU_LAUNCH_ATTRIBUTE_PRIORITY +#define CU_KERNEL_NODE_ATTRIBUTE_MEM_SYNC_DOMAIN_MAP CU_LAUNCH_ATTRIBUTE_MEM_SYNC_DOMAIN_MAP +#define CU_KERNEL_NODE_ATTRIBUTE_MEM_SYNC_DOMAIN CU_LAUNCH_ATTRIBUTE_MEM_SYNC_DOMAIN + +typedef CUlaunchAttributeValue CUkernelNodeAttrValue_v1; typedef CUkernelNodeAttrValue_v1 CUkernelNodeAttrValue; /** @@ -1464,21 +1787,14 @@ typedef enum CUstreamCaptureMode_enum { CU_STREAM_CAPTURE_MODE_RELAXED = 2 } CUstreamCaptureMode; -/** - * Stream Attributes - */ -typedef enum CUstreamAttrID_enum { - CU_STREAM_ATTRIBUTE_ACCESS_POLICY_WINDOW = 1, /**< Identifier for ::CUstreamAttrValue::accessPolicyWindow. */ - CU_STREAM_ATTRIBUTE_SYNCHRONIZATION_POLICY = 3 /**< ::CUsynchronizationPolicy for work queued up in this stream */ -} CUstreamAttrID; +typedef CUlaunchAttributeID CUstreamAttrID; +#define CU_STREAM_ATTRIBUTE_ACCESS_POLICY_WINDOW CU_LAUNCH_ATTRIBUTE_ACCESS_POLICY_WINDOW +#define CU_STREAM_ATTRIBUTE_SYNCHRONIZATION_POLICY CU_LAUNCH_ATTRIBUTE_SYNCHRONIZATION_POLICY +#define CU_STREAM_ATTRIBUTE_PRIORITY CU_LAUNCH_ATTRIBUTE_PRIORITY +#define CU_STREAM_ATTRIBUTE_MEM_SYNC_DOMAIN_MAP CU_LAUNCH_ATTRIBUTE_MEM_SYNC_DOMAIN_MAP +#define CU_STREAM_ATTRIBUTE_MEM_SYNC_DOMAIN CU_LAUNCH_ATTRIBUTE_MEM_SYNC_DOMAIN -/** - * Stream attributes union, used with ::cuStreamSetAttribute/::cuStreamGetAttribute - */ -typedef union CUstreamAttrValue_union { - CUaccessPolicyWindow accessPolicyWindow; /**< Attribute ::CUaccessPolicyWindow. */ - CUsynchronizationPolicy syncPolicy; /**< Value for ::CU_STREAM_ATTRIBUTE_SYNCHRONIZATION_POLICY. */ -} CUstreamAttrValue_v1; +typedef CUlaunchAttributeValue CUstreamAttrValue_v1; typedef CUstreamAttrValue_v1 CUstreamAttrValue; /** @@ -1490,6 +1806,15 @@ typedef enum CUdriverProcAddress_flags_enum { CU_GET_PROC_ADDRESS_PER_THREAD_DEFAULT_STREAM = 1 << 1 /**< Search for per-thread versions of driver symbols. */ } CUdriverProcAddress_flags; +/** + * Flags to indicate search status. For more details see ::cuGetProcAddress + */ +typedef enum CUdriverProcAddressQueryResult_enum { + CU_GET_PROC_ADDRESS_SUCCESS = 0, /**< Symbol was succesfully found */ + CU_GET_PROC_ADDRESS_SYMBOL_NOT_FOUND = 1, /**< Symbol was not found in search */ + CU_GET_PROC_ADDRESS_VERSION_NOT_SUFFICIENT = 2 /**< Symbol was found but version supplied was not sufficient */ +} CUdriverProcAddressQueryResult; + /** * Execution Affinity Types */ @@ -1515,8 +1840,40 @@ typedef struct CUexecAffinityParam_st { CUexecAffinitySmCount smCount; /** Value for ::CU_EXEC_AFFINITY_TYPE_SM_COUNT */ } param; } CUexecAffinityParam_v1; +/** + * Execution Affinity Parameters + */ typedef CUexecAffinityParam_v1 CUexecAffinityParam; +/** + * Library options to be specified with ::cuLibraryLoadData() or ::cuLibraryLoadFromFile() + */ +typedef enum CUlibraryOption_enum +{ + CU_LIBRARY_HOST_UNIVERSAL_FUNCTION_AND_DATA_TABLE = 0, + + /** + * Specifes that the argument \p code passed to ::cuLibraryLoadData() will be preserved. + * Specifying this option will let the driver know that \p code can be accessed at any point + * until ::cuLibraryUnload(). The default behavior is for the driver to allocate and + * maintain its own copy of \p code. Note that this is only a memory usage optimization + * hint and the driver can choose to ignore it if required. + * Specifying this option with ::cuLibraryLoadFromFile() is invalid and + * will return ::CUDA_ERROR_INVALID_VALUE. + */ + CU_LIBRARY_BINARY_IS_PRESERVED = 1, + + CU_LIBRARY_NUM_OPTIONS +} CUlibraryOption; + +typedef struct CUlibraryHostUniversalFunctionAndDataTable_st +{ + void *functionTable; + size_t functionWindowSize; + void *dataTable; + size_t dataWindowSize; +} CUlibraryHostUniversalFunctionAndDataTable; + /** * Error codes */ @@ -1587,6 +1944,13 @@ typedef enum cudaError_enum { */ CUDA_ERROR_STUB_LIBRARY = 34, + /** + * This indicates that requested CUDA device is unavailable at the current + * time. Devices are often unavailable due to use of + * ::CU_COMPUTEMODE_EXCLUSIVE_PROCESS or ::CU_COMPUTEMODE_PROHIBITED. + */ + CUDA_ERROR_DEVICE_UNAVAILABLE = 46, + /** * This indicates that no CUDA-capable devices were detected by the installed * CUDA driver. @@ -1744,6 +2108,12 @@ typedef enum cudaError_enum { */ CUDA_ERROR_UNSUPPORTED_EXEC_AFFINITY = 224, + /** + * This indicates that the code to be compiled by the PTX JIT contains + * unsupported call to cudaDeviceSynchronize. + */ + CUDA_ERROR_UNSUPPORTED_DEVSIDE_SYNC = 225, + /** * This indicates that the device kernel source is invalid. This includes * compilation/linker errors encountered in device code or user error. @@ -2014,6 +2384,21 @@ typedef enum cudaError_enum { */ CUDA_ERROR_MPS_MAX_CONNECTIONS_REACHED = 809, + /** + * This error indicates that the MPS client has been terminated by the server. To continue using CUDA, the process must be terminated and relaunched. + */ + CUDA_ERROR_MPS_CLIENT_TERMINATED = 810, + + /** + * This error indicates that the module is using CUDA Dynamic Parallelism, but the current configuration, like MPS, does not support it. + */ + CUDA_ERROR_CDP_NOT_SUPPORTED = 811, + + /** + * This error indicates that a module contains an unsupported interaction between different versions of CUDA Dynamic Parallelism. + */ + CUDA_ERROR_CDP_VERSION_MISMATCH = 812, + /** * This error indicates that the operation is not permitted when * the stream is capturing. @@ -2090,12 +2475,10 @@ typedef enum cudaError_enum { */ CUDA_ERROR_EXTERNAL_DEVICE = 911, - - - - - - + /** + * Indicates a kernel launch error due to cluster misconfiguration. + */ + CUDA_ERROR_INVALID_CLUSTER_SIZE = 912, /** * This indicates that an unknown internal error has occurred. @@ -2114,17 +2497,6 @@ typedef enum CUdevice_P2PAttribute_enum { CU_DEVICE_P2P_ATTRIBUTE_CUDA_ARRAY_ACCESS_SUPPORTED = 0x04 /**< Accessing CUDA arrays over the link supported */ } CUdevice_P2PAttribute; - - - - - - - - - - - /** * CUDA stream callback * \param hStream The stream the callback was added to, as passed to ::cuStreamAddCallback. May be NULL. @@ -2355,7 +2727,6 @@ typedef struct CUDA_ARRAY_SPARSE_PROPERTIES_st { } CUDA_ARRAY_SPARSE_PROPERTIES_v1; typedef CUDA_ARRAY_SPARSE_PROPERTIES_v1 CUDA_ARRAY_SPARSE_PROPERTIES; - /** * CUDA array memory requirements */ @@ -2366,7 +2737,6 @@ typedef struct CUDA_ARRAY_MEMORY_REQUIREMENTS_st { } CUDA_ARRAY_MEMORY_REQUIREMENTS_v1; typedef CUDA_ARRAY_MEMORY_REQUIREMENTS_v1 CUDA_ARRAY_MEMORY_REQUIREMENTS; - /** * CUDA Resource descriptor */ @@ -2480,6 +2850,79 @@ typedef struct CUDA_RESOURCE_VIEW_DESC_st } CUDA_RESOURCE_VIEW_DESC_v1; typedef CUDA_RESOURCE_VIEW_DESC_v1 CUDA_RESOURCE_VIEW_DESC; +/** + * Size of tensor map descriptor + */ +#define CU_TENSOR_MAP_NUM_QWORDS 16 + +/** + * Tensor map descriptor. Requires compiler support for aligning to 64 bytes. + */ +typedef struct CUtensorMap_st { +#if __cplusplus >= 201103L + alignas(64) +#elif __STDC_VERSION__ >= 201112L + _Alignas(64) +#endif + cuuint64_t opaque[CU_TENSOR_MAP_NUM_QWORDS]; +} CUtensorMap; + +/** + * Tensor map data type + */ +typedef enum CUtensorMapDataType_enum { + CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0, + CU_TENSOR_MAP_DATA_TYPE_UINT16, + CU_TENSOR_MAP_DATA_TYPE_UINT32, + CU_TENSOR_MAP_DATA_TYPE_INT32, + CU_TENSOR_MAP_DATA_TYPE_UINT64, + CU_TENSOR_MAP_DATA_TYPE_INT64, + CU_TENSOR_MAP_DATA_TYPE_FLOAT16, + CU_TENSOR_MAP_DATA_TYPE_FLOAT32, + CU_TENSOR_MAP_DATA_TYPE_FLOAT64, + CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, + CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ, + CU_TENSOR_MAP_DATA_TYPE_TFLOAT32, + CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ +} CUtensorMapDataType; + +/** + * Tensor map interleave layout type + */ +typedef enum CUtensorMapInterleave_enum { + CU_TENSOR_MAP_INTERLEAVE_NONE = 0, + CU_TENSOR_MAP_INTERLEAVE_16B, + CU_TENSOR_MAP_INTERLEAVE_32B +} CUtensorMapInterleave; + +/** + * Tensor map swizzling mode of shared memory banks + */ +typedef enum CUtensorMapSwizzle_enum { + CU_TENSOR_MAP_SWIZZLE_NONE = 0, + CU_TENSOR_MAP_SWIZZLE_32B, + CU_TENSOR_MAP_SWIZZLE_64B, + CU_TENSOR_MAP_SWIZZLE_128B +} CUtensorMapSwizzle; + +/** + * Tensor map L2 promotion type + */ +typedef enum CUtensorMapL2promotion_enum { + CU_TENSOR_MAP_L2_PROMOTION_NONE = 0, + CU_TENSOR_MAP_L2_PROMOTION_L2_64B, + CU_TENSOR_MAP_L2_PROMOTION_L2_128B, + CU_TENSOR_MAP_L2_PROMOTION_L2_256B +} CUtensorMapL2promotion; + +/** + * Tensor map out-of-bounds fill type + */ +typedef enum CUtensorMapFloatOOBfill_enum { + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE = 0, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA +} CUtensorMapFloatOOBfill; + /** * GPU Direct v3 tokens */ @@ -2964,6 +3407,15 @@ typedef enum CUmemAllocationGranularity_flags_enum { CU_MEM_ALLOC_GRANULARITY_RECOMMENDED = 0x1 /**< Recommended granularity for allocation for best performance */ } CUmemAllocationGranularity_flags; +/** +* Specifies the handle type for address range +*/ +typedef enum CUmemRangeHandleType_enum +{ + CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD = 0x1, + CU_MEM_RANGE_HANDLE_TYPE_MAX = 0x7FFFFFFF +} CUmemRangeHandleType; + /** * Sparse subresource types */ @@ -3066,9 +3518,9 @@ typedef struct CUmemAllocationProp_st { CUmemLocation location; /** * Windows-specific POBJECT_ATTRIBUTES required when - * ::CU_MEM_HANDLE_TYPE_WIN32 is specified. This object atributes structure + * ::CU_MEM_HANDLE_TYPE_WIN32 is specified. This object attributes structure * includes security attributes that define - * the scope of which exported allocations may be tranferred to other + * the scope of which exported allocations may be transferred to other * processes. In all other cases, this field is required to be zero. */ void *win32HandleMetaData; @@ -3092,6 +3544,40 @@ typedef struct CUmemAllocationProp_st { } CUmemAllocationProp_v1; typedef CUmemAllocationProp_v1 CUmemAllocationProp; +/** +* Flags for querying different granularities for a multicast object +*/ +typedef enum CUmulticastGranularity_flags_enum { + CU_MULTICAST_GRANULARITY_MINIMUM = 0x0, /**< Minimum required granularity */ + CU_MULTICAST_GRANULARITY_RECOMMENDED = 0x1 /**< Recommended granularity for best performance */ +} CUmulticastGranularity_flags; + +/** +* Specifies the properties for a multicast object. +*/ +typedef struct CUmulticastObjectProp_st { + /** + * The number of devices in the multicast team that will bind memory to this + * object + */ + unsigned int numDevices; + /** + * The maximum amount of memory that can be bound to this multicast object + * per device + */ + size_t size; + /** + * Bitmask of exportable handle types (see ::CUmemAllocationHandleType) for + * this object + */ + unsigned long long handleTypes; + /** + * Flags for future use, must be zero now + */ + unsigned long long flags; +} CUmulticastObjectProp_v1; +typedef CUmulticastObjectProp_v1 CUmulticastObjectProp; + /** * Memory access descriptor */ @@ -3101,6 +3587,9 @@ typedef struct CUmemAccessDesc_st { } CUmemAccessDesc_v1; typedef CUmemAccessDesc_v1 CUmemAccessDesc; +/** + * CUDA Graph Update error types + */ typedef enum CUgraphExecUpdateResult_enum { CU_GRAPH_EXEC_UPDATE_SUCCESS = 0x0, /**< The update succeeded */ CU_GRAPH_EXEC_UPDATE_ERROR = 0x1, /**< The update failed for an unexpected reason which is described in the return value of the function */ @@ -3113,6 +3602,29 @@ typedef enum CUgraphExecUpdateResult_enum { CU_GRAPH_EXEC_UPDATE_ERROR_ATTRIBUTES_CHANGED = 0x8 /**< The update failed because the node attributes changed in a way that is not supported */ } CUgraphExecUpdateResult; +/** + * Result information returned by cuGraphExecUpdate + */ +typedef struct CUgraphExecUpdateResultInfo_st { + /** + * Gives more specific detail when a cuda graph update fails. + */ + CUgraphExecUpdateResult result; + + /** + * The "to node" of the error edge when the topologies do not match. + * The error node when the error is associated with a specific node. + * NULL when the error is generic. + */ + CUgraphNode errorNode; + + /** + * The from node of error edge when the topologies do not match. Otherwise NULL. + */ + CUgraphNode errorFromNode; +} CUgraphExecUpdateResultInfo_v1; +typedef CUgraphExecUpdateResultInfo_v1 CUgraphExecUpdateResultInfo; + /** * CUDA memory pool attributes */ @@ -3189,7 +3701,7 @@ typedef struct CUmemPoolProps_st { /** * Windows-specific LPSECURITYATTRIBUTES required when * ::CU_MEM_HANDLE_TYPE_WIN32 is specified. This security attribute defines - * the scope of which exported allocations may be tranferred to other + * the scope of which exported allocations may be transferred to other * processes. In all other cases, this field is required to be zero. */ void *win32SecurityAttributes; @@ -3313,14 +3825,12 @@ typedef enum CUgraphMem_attribute_enum { */ #define CUDA_ARRAY3D_SPARSE 0x40 - /** * This flag if set indicates that the CUDA array or CUDA mipmapped array * will allow deferred memory mapping */ #define CUDA_ARRAY3D_DEFERRED_MAPPING 0x80 - /** * Override the texref format with a format inferred from the array. * Flag for ::cuTexRefSetArray() @@ -3358,11 +3868,21 @@ typedef enum CUgraphMem_attribute_enum { */ #define CU_TRSF_SEAMLESS_CUBEMAP 0x40 +/** + * C++ compile time constant for CU_LAUNCH_PARAM_END + */ +#define CU_LAUNCH_PARAM_END_AS_INT 0x00 + /** * End of array terminator for the \p extra parameter to * ::cuLaunchKernel */ -#define CU_LAUNCH_PARAM_END ((void*)0x00) +#define CU_LAUNCH_PARAM_END ((void*)CU_LAUNCH_PARAM_END_AS_INT) + +/** + * C++ compile time constant for CU_LAUNCH_PARAM_BUFFER_POINTER + */ +#define CU_LAUNCH_PARAM_BUFFER_POINTER_AS_INT 0x01 /** * Indicator that the next value in the \p extra parameter to @@ -3373,7 +3893,12 @@ typedef enum CUgraphMem_attribute_enum { * \p extra array, then ::CU_LAUNCH_PARAM_BUFFER_POINTER will have no * effect. */ -#define CU_LAUNCH_PARAM_BUFFER_POINTER ((void*)0x01) +#define CU_LAUNCH_PARAM_BUFFER_POINTER ((void*)CU_LAUNCH_PARAM_BUFFER_POINTER_AS_INT) + +/** + * C++ compile time constant for CU_LAUNCH_PARAM_BUFFER_SIZE + */ +#define CU_LAUNCH_PARAM_BUFFER_SIZE_AS_INT 0x02 /** * Indicator that the next value in the \p extra parameter to @@ -3383,7 +3908,7 @@ typedef enum CUgraphMem_attribute_enum { * in the \p extra array if the value associated with * ::CU_LAUNCH_PARAM_BUFFER_SIZE is not zero. */ -#define CU_LAUNCH_PARAM_BUFFER_SIZE ((void*)0x02) +#define CU_LAUNCH_PARAM_BUFFER_SIZE ((void*)CU_LAUNCH_PARAM_BUFFER_SIZE_AS_INT) /** * For texture references loaded into the module, use default texunit from @@ -3437,19 +3962,21 @@ typedef enum CUflushGPUDirectRDMAWritesTarget_enum { * The additional write options for ::cuGraphDebugDotPrint */ typedef enum CUgraphDebugDot_flags_enum { - CU_GRAPH_DEBUG_DOT_FLAGS_VERBOSE = 1<<0, /** Output all debug data as if every debug flag is enabled */ - CU_GRAPH_DEBUG_DOT_FLAGS_RUNTIME_TYPES = 1<<1, /** Use CUDA Runtime structures for output */ - CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_PARAMS = 1<<2, /** Adds CUDA_KERNEL_NODE_PARAMS values to output */ - CU_GRAPH_DEBUG_DOT_FLAGS_MEMCPY_NODE_PARAMS = 1<<3, /** Adds CUDA_MEMCPY3D values to output */ - CU_GRAPH_DEBUG_DOT_FLAGS_MEMSET_NODE_PARAMS = 1<<4, /** Adds CUDA_MEMSET_NODE_PARAMS values to output */ - CU_GRAPH_DEBUG_DOT_FLAGS_HOST_NODE_PARAMS = 1<<5, /** Adds CUDA_HOST_NODE_PARAMS values to output */ - CU_GRAPH_DEBUG_DOT_FLAGS_EVENT_NODE_PARAMS = 1<<6, /** Adds CUevent handle from record and wait nodes to output */ - CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_SIGNAL_NODE_PARAMS = 1<<7, /** Adds CUDA_EXT_SEM_SIGNAL_NODE_PARAMS values to output */ - CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_WAIT_NODE_PARAMS = 1<<8, /** Adds CUDA_EXT_SEM_WAIT_NODE_PARAMS values to output */ - CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_ATTRIBUTES = 1<<9, /** Adds CUkernelNodeAttrValue values to output */ - CU_GRAPH_DEBUG_DOT_FLAGS_HANDLES = 1<<10, /** Adds node handles and every kernel function handle to output */ - CU_GRAPH_DEBUG_DOT_FLAGS_MEM_ALLOC_NODE_PARAMS = 1<<11, /** Adds memory alloc node parameters to output */ - CU_GRAPH_DEBUG_DOT_FLAGS_MEM_FREE_NODE_PARAMS = 1<<12 /** Adds memory free node parameters to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_VERBOSE = 1<<0, /**< Output all debug data as if every debug flag is enabled */ + CU_GRAPH_DEBUG_DOT_FLAGS_RUNTIME_TYPES = 1<<1, /**< Use CUDA Runtime structures for output */ + CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_PARAMS = 1<<2, /**< Adds CUDA_KERNEL_NODE_PARAMS values to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_MEMCPY_NODE_PARAMS = 1<<3, /**< Adds CUDA_MEMCPY3D values to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_MEMSET_NODE_PARAMS = 1<<4, /**< Adds CUDA_MEMSET_NODE_PARAMS values to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_HOST_NODE_PARAMS = 1<<5, /**< Adds CUDA_HOST_NODE_PARAMS values to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_EVENT_NODE_PARAMS = 1<<6, /**< Adds CUevent handle from record and wait nodes to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_SIGNAL_NODE_PARAMS = 1<<7, /**< Adds CUDA_EXT_SEM_SIGNAL_NODE_PARAMS values to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_WAIT_NODE_PARAMS = 1<<8, /**< Adds CUDA_EXT_SEM_WAIT_NODE_PARAMS values to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_ATTRIBUTES = 1<<9, /**< Adds CUkernelNodeAttrValue values to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_HANDLES = 1<<10, /**< Adds node handles and every kernel function handle to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_MEM_ALLOC_NODE_PARAMS = 1<<11, /**< Adds memory alloc node parameters to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_MEM_FREE_NODE_PARAMS = 1<<12, /**< Adds memory free node parameters to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_BATCH_MEM_OP_NODE_PARAMS = 1<<13 /**< Adds batch mem op node parameters to output */ + , CU_GRAPH_DEBUG_DOT_FLAGS_EXTRA_TOPO_INFO = 1<<14 /**< Adds edge numbering information */ } CUgraphDebugDot_flags; /** @@ -3470,54 +3997,17 @@ typedef enum CUuserObjectRetain_flags_enum { * Flags for instantiating a graph */ typedef enum CUgraphInstantiate_flags_enum { - CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH = 1 /**< Automatically free memory allocated in a graph before relaunching. */ + CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH = 1 /**< Automatically free memory allocated in a graph before relaunching. */ + , CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD = 2 /**< Automatically upload the graph after instantiaton. */ + , CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH = 4 /**< Instantiate the graph to be launchable from the device. */ + , CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY = 8 /**< Run the graph using the per-node priority attributes rather than the + priority of the stream it is launched into. */ } CUgraphInstantiate_flags; - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +typedef enum cuDeviceIsNumaNode_enum { + CU_DEVICE_NOT_NUMA_NODE = 0x01, + CU_DEVICE_IS_NUMA_NODE = 0x02, +} cuDeviceIsNumaNode; /** @} */ /* END CUDA_TYPES */ @@ -3603,9 +4093,8 @@ CUresult CUDAAPI cuGetErrorName(CUresult error, const char **pStr); /** * \brief Initialize the CUDA driver API - * * Initializes the driver API and must be called before any other function from - * the driver API. Currently, the \p Flags parameter must be 0. If ::cuInit() + * the driver API in the current process. Currently, the \p Flags parameter must be 0. If ::cuInit() * has not been called, any function from the driver API will return * ::CUDA_ERROR_NOT_INITIALIZED. * @@ -3732,7 +4221,7 @@ CUresult CUDAAPI cuDeviceGet(CUdevice *device, int ordinal); CUresult CUDAAPI cuDeviceGetCount(int *count); /** - * \brief Returns an identifer string for the device + * \brief Returns an identifier string for the device * * Returns an ASCII string identifying the device \p dev in the NULL-terminated * string pointed to by \p name. \p len specifies the maximum length of the @@ -3769,7 +4258,7 @@ CUresult CUDAAPI cuDeviceGetName(char *name, int len, CUdevice dev); * Note there is a later version of this API, ::cuDeviceGetUuid_v2. It will * supplant this version in 12.0, which is retained for minor version compatibility. * - * Returns 16-octets identifing the device \p dev in the structure + * Returns 16-octets identifying the device \p dev in the structure * pointed by the \p uuid. * * \param uuid - Returned UUID @@ -3799,7 +4288,7 @@ CUresult CUDAAPI cuDeviceGetUuid(CUuuid *uuid, CUdevice dev); /** * \brief Return an UUID for the device (11.4+) * - * Returns 16-octets identifing the device \p dev in the structure + * Returns 16-octets identifying the device \p dev in the structure * pointed by the \p uuid. If the device is in MIG mode, returns its * MIG UUID which uniquely identifies the subscribed MIG compute instance. * @@ -4080,7 +4569,7 @@ CUresult CUDAAPI cuDeviceGetTexture1DLinearMaxWidth(size_t *maxWidthInElements, * supports native atomic operations. * - ::CU_DEVICE_ATTRIBUTE_SINGLE_TO_DOUBLE_PRECISION_PERF_RATIO: Ratio of single precision performance * (in floating-point operations per second) to double precision performance. - * - ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS: Device suppports coherently accessing + * - ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS: Device supports coherently accessing * pageable memory without calling cudaHostRegister on it. * - ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS: Device can coherently access managed memory * concurrently with the CPU. @@ -4088,7 +4577,7 @@ CUresult CUDAAPI cuDeviceGetTexture1DLinearMaxWidth(size_t *maxWidthInElements, * - ::CU_DEVICE_ATTRIBUTE_CAN_USE_HOST_POINTER_FOR_REGISTERED_MEM: Device can access host registered * memory at the same virtual address as the CPU. * - ::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN: The maximum per block shared memory size - * suported on this device. This is the maximum value that can be opted into when using the cuFuncSetAttribute() call. + * supported on this device. This is the maximum value that can be opted into when using the cuFuncSetAttribute() or cuKernelSetAttribute() call. * For more details see ::CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES * - ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES: Device accesses pageable memory via the host's * page tables. @@ -4110,9 +4599,7 @@ CUresult CUDAAPI cuDeviceGetTexture1DLinearMaxWidth(size_t *maxWidthInElements, * - ::CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_FLUSH_WRITES_OPTIONS: The returned attribute shall be interpreted as a bitmask, where the individual bits are described by the ::CUflushGPUDirectRDMAWritesOptions enum * - ::CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WRITES_ORDERING: GPUDirect RDMA writes to the device do not need to be flushed for consumers within the scope indicated by the returned attribute. See ::CUGPUDirectRDMAWritesOrdering for the numerical values returned here. * - ::CU_DEVICE_ATTRIBUTE_MEMPOOL_SUPPORTED_HANDLE_TYPES: Bitmask of handle types supported with mempool based IPC - * - ::CU_DEVICE_ATTRIBUTE_DEFERRED_MAPPING_CUDA_ARRAY_SUPPORTED: Device supports deferred mapping CUDA arrays and CUDA mipmapped arrays. - * * \param pi - Returned device attribute value * \param attrib - Device attribute to query @@ -4165,6 +4652,20 @@ CUresult CUDAAPI cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, CUdevi * to one another: a developer may set both these flags that allows to * set both wait and signal specific attributes in the same \p nvSciSyncAttrList. * + * Note that this API updates the input \p nvSciSyncAttrList with values equivalent + * to the following public attribute key-values: + * NvSciSyncAttrKey_RequiredPerm is set to + * - NvSciSyncAccessPerm_SignalOnly if ::CUDA_NVSCISYNC_ATTR_SIGNAL is set in \p flags. + * - NvSciSyncAccessPerm_WaitOnly if ::CUDA_NVSCISYNC_ATTR_WAIT is set in \p flags. + * - NvSciSyncAccessPerm_WaitSignal if both ::CUDA_NVSCISYNC_ATTR_WAIT and + * ::CUDA_NVSCISYNC_ATTR_SIGNAL are set in \p flags. + * NvSciSyncAttrKey_PrimitiveInfo is set to + * - NvSciSyncAttrValPrimitiveType_SysmemSemaphore on any valid \p device. + * - NvSciSyncAttrValPrimitiveType_Syncpoint if \p device is a Tegra device. + * - NvSciSyncAttrValPrimitiveType_SysmemSemaphorePayload64b if \p device is GA10X+. + * NvSciSyncAttrKey_GpuId is set to the same UUID that is returned for this + * \p device from ::cuDeviceGetUuid. + * * \param nvSciSyncAttrList - Return NvSciSync attributes supported. * \param dev - Valid Cuda Device to get NvSciSync attributes for. * \param flags - flags describing NvSciSync usage. @@ -4240,6 +4741,37 @@ CUresult CUDAAPI cuDeviceGetMemPool(CUmemoryPool *pool, CUdevice dev); */ CUresult CUDAAPI cuDeviceGetDefaultMemPool(CUmemoryPool *pool_out, CUdevice dev); +/** + * \brief Returns information about the execution affinity support of the device. + * + * Returns in \p *pi whether execution affinity type \p type is supported by device \p dev. + * The supported types are: + * - ::CU_EXEC_AFFINITY_TYPE_SM_COUNT: 1 if context with limited SMs is supported by the device, + * or 0 if not; + * + * \param pi - 1 if the execution affinity type \p type is supported by the device, or 0 if not + * \param type - Execution affinity type to query + * \param dev - Device handle + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa + * ::cuDeviceGetAttribute, + * ::cuDeviceGetCount, + * ::cuDeviceGetName, + * ::cuDeviceGetUuid, + * ::cuDeviceGet, + * ::cuDeviceTotalMem + */ +CUresult CUDAAPI cuDeviceGetExecAffinitySupport(int *pi, CUexecAffinityType type, CUdevice dev); + /** * \brief Blocks until remote writes are visible to the specified scope * @@ -4354,7 +4886,7 @@ __CUDA_DEPRECATED CUresult CUDAAPI cuDeviceGetProperties(CUdevprop *prop, CUdevi * * \deprecated * - * This function was deprecated as of CUDA 5.0 and its functionality superceded + * This function was deprecated as of CUDA 5.0 and its functionality superseded * by ::cuDeviceGetAttribute(). * * Returns in \p *major and \p *minor the major and minor revision numbers that @@ -4537,6 +5069,31 @@ CUresult CUDAAPI cuDevicePrimaryCtxRelease(CUdevice dev); * Deprecated: This flag is deprecated and the behavior enabled * by this flag is now the default and cannot be disabled. * + * - ::CU_CTX_COREDUMP_ENABLE: If GPU coredumps have not been enabled globally + * with ::cuCoredumpSetAttributeGlobal or environment variables, this flag can + * be set during context creation to instruct CUDA to create a coredump if + * this context raises an exception during execution. These environment variables + * are described in the CUDA-GDB user guide under the "GPU core dump support" + * section. + * The initial settings will be taken from the global settings at the time of + * context creation. The other settings that control coredump output can be + * modified by calling ::cuCoredumpSetAttribute from the created context after + * it becomes current. + * + * - ::CU_CTX_USER_COREDUMP_ENABLE: If user-triggered GPU coredumps have not + * been enabled globally with ::cuCoredumpSetAttributeGlobal or environment + * variables, this flag can be set during context creation to instruct CUDA to + * create a coredump if data is written to a certain pipe that is present in the + * OS space. These environment variables are described in the CUDA-GDB user + * guide under the "GPU core dump support" section. + * It is important to note that the pipe name *must* be set with + * ::cuCoredumpSetAttributeGlobal before creating the context if this flag is + * used. Setting this flag implies that ::CU_CTX_COREDUMP_ENABLE is set. + * The initial settings will be taken from the global settings at the time of + * context creation. The other settings that control coredump output can be + * modified by calling ::cuCoredumpSetAttribute from the created context after + * it becomes current. + * * \param dev - Device for which the primary context flags are set * \param flags - New flags for the device * @@ -4552,6 +5109,7 @@ CUresult CUDAAPI cuDevicePrimaryCtxRelease(CUdevice dev); * ::cuDevicePrimaryCtxGetState, * ::cuCtxCreate, * ::cuCtxGetFlags, + * ::cuCtxSetFlags, * ::cudaSetDeviceFlags */ CUresult CUDAAPI cuDevicePrimaryCtxSetFlags(CUdevice dev, unsigned int flags); @@ -4578,6 +5136,7 @@ CUresult CUDAAPI cuDevicePrimaryCtxSetFlags(CUdevice dev, unsigned int flags); * \sa * ::cuDevicePrimaryCtxSetFlags, * ::cuCtxGetFlags, + * ::cuCtxSetFlags, * ::cudaGetDeviceFlags */ CUresult CUDAAPI cuDevicePrimaryCtxGetState(CUdevice dev, unsigned int *flags, int *active); @@ -4624,37 +5183,6 @@ CUresult CUDAAPI cuDevicePrimaryCtxReset(CUdevice dev); /** @} */ /* END CUDA_PRIMARY_CTX */ -/** - * \brief Returns information about the execution affinity support of the device. - * - * Returns in \p *pi whether execution affinity type \p type is supported by device \p dev. - * The supported types are: - * - ::CU_EXEC_AFFINITY_TYPE_SM_COUNT: 1 if context with limited SMs is supported by the device, - * or 0 if not; - * - * \param pi - 1 if the execution affinity type \p type is supported by the device, or 0 if not - * \param type - Execution affinity type to query - * \param dev - Device handle - * - * \return - * ::CUDA_SUCCESS, - * ::CUDA_ERROR_DEINITIALIZED, - * ::CUDA_ERROR_NOT_INITIALIZED, - * ::CUDA_ERROR_INVALID_CONTEXT, - * ::CUDA_ERROR_INVALID_VALUE, - * ::CUDA_ERROR_INVALID_DEVICE - * \notefnerr - * - * \sa - * ::cuDeviceGetAttribute, - * ::cuDeviceGetCount, - * ::cuDeviceGetName, - * ::cuDeviceGetUuid, - * ::cuDeviceGet, - * ::cuDeviceTotalMem - */ -CUresult CUDAAPI cuDeviceGetExecAffinitySupport(int *pi, CUexecAffinityType type, CUdevice dev); - /** * \defgroup CUDA_CTX Context Management * @@ -4677,7 +5205,7 @@ CUresult CUDAAPI cuDeviceGetExecAffinitySupport(int *pi, CUexecAffinityType type * * Creates a new CUDA context and associates it with the calling thread. The * \p flags parameter is described below. The context is created with a usage - * count of 1 and the caller of ::cuCtxCreate() must call ::cuCtxDestroy() or + * count of 1 and the caller of ::cuCtxCreate() must call ::cuCtxDestroy() * when done using the context. If a context is already current to the thread, * it is supplanted by the newly created context and may be restored by a subsequent * call to ::cuCtxPopCurrent(). @@ -4727,6 +5255,33 @@ CUresult CUDAAPI cuDeviceGetExecAffinitySupport(int *pi, CUexecAffinityType type * by this flag is now the default and cannot be disabled. * Instead, the per-thread stack size can be controlled with ::cuCtxSetLimit(). * + * - ::CU_CTX_COREDUMP_ENABLE: If GPU coredumps have not been enabled globally + * with ::cuCoredumpSetAttributeGlobal or environment variables, this flag can + * be set during context creation to instruct CUDA to create a coredump if + * this context raises an exception during execution. These environment variables + * are described in the CUDA-GDB user guide under the "GPU core dump support" + * section. + * The initial attributes will be taken from the global attributes at the time of + * context creation. The other attributes that control coredump output can be + * modified by calling ::cuCoredumpSetAttribute from the created context after + * it becomes current. + * + * - ::CU_CTX_USER_COREDUMP_ENABLE: If user-triggered GPU coredumps have not + * been enabled globally with ::cuCoredumpSetAttributeGlobal or environment + * variables, this flag can be set during context creation to instruct CUDA to + * create a coredump if data is written to a certain pipe that is present in the + * OS space. These environment variables are described in the CUDA-GDB user + * guide under the "GPU core dump support" section. + * It is important to note that the pipe name *must* be set with + * ::cuCoredumpSetAttributeGlobal before creating the context if this flag is + * used. Setting this flag implies that ::CU_CTX_COREDUMP_ENABLE is set. + * The initial attributes will be taken from the global attributes at the time of + * context creation. The other attributes that control coredump output can be + * modified by calling ::cuCoredumpSetAttribute from the created context after + * it becomes current. + * Setting this flag on any context creation is equivalent to setting the + * ::CU_COREDUMP_ENABLE_USER_TRIGGER attribute to \p true globally. + * * Context creation will fail with ::CUDA_ERROR_UNKNOWN if the compute mode of * the device is ::CU_COMPUTEMODE_PROHIBITED. The function ::cuDeviceGetAttribute() * can be used with ::CU_DEVICE_ATTRIBUTE_COMPUTE_MODE to determine the @@ -4760,6 +5315,8 @@ CUresult CUDAAPI cuDeviceGetExecAffinitySupport(int *pi, CUexecAffinityType type * ::cuCtxPushCurrent, * ::cuCtxSetCacheConfig, * ::cuCtxSetLimit, + * ::cuCoredumpSetAttributeGlobal, + * ::cuCoredumpSetAttribute, * ::cuCtxSynchronize */ CUresult CUDAAPI cuCtxCreate(CUcontext *pctx, unsigned int flags, CUdevice dev); @@ -4770,7 +5327,7 @@ CUresult CUDAAPI cuCtxCreate(CUcontext *pctx, unsigned int flags, CUdevice dev); * Creates a new CUDA context with execution affinity and associates it with * the calling thread. The \p paramsArray and \p flags parameter are described below. * The context is created with a usage count of 1 and the caller of ::cuCtxCreate() must - * call ::cuCtxDestroy() or when done using the context. If a context is already + * call ::cuCtxDestroy() when done using the context. If a context is already * current to the thread, it is supplanted by the newly created context and may * be restored by a subsequent call to ::cuCtxPopCurrent(). * @@ -4830,6 +5387,33 @@ CUresult CUDAAPI cuCtxCreate(CUcontext *pctx, unsigned int flags, CUdevice dev); * by this flag is now the default and cannot be disabled. * Instead, the per-thread stack size can be controlled with ::cuCtxSetLimit(). * + * - ::CU_CTX_COREDUMP_ENABLE: If GPU coredumps have not been enabled globally + * with ::cuCoredumpSetAttributeGlobal or environment variables, this flag can + * be set during context creation to instruct CUDA to create a coredump if + * this context raises an exception during execution. These environment variables + * are described in the CUDA-GDB user guide under the "GPU core dump support" + * section. + * The initial attributes will be taken from the global attributes at the time of + * context creation. The other attributes that control coredump output can be + * modified by calling ::cuCoredumpSetAttribute from the created context after + * it becomes current. + * + * - ::CU_CTX_USER_COREDUMP_ENABLE: If user-triggered GPU coredumps have not + * been enabled globally with ::cuCoredumpSetAttributeGlobal or environment + * variables, this flag can be set during context creation to instruct CUDA to + * create a coredump if data is written to a certain pipe that is present in the + * OS space. These environment variables are described in the CUDA-GDB user + * guide under the "GPU core dump support" section. + * It is important to note that the pipe name *must* be set with + * ::cuCoredumpSetAttributeGlobal before creating the context if this flag is + * used. Setting this flag implies that ::CU_CTX_COREDUMP_ENABLE is set. + * The initial attributes will be taken from the global attributes at the time of + * context creation. The other attributes that control coredump output can be + * modified by calling ::cuCoredumpSetAttribute from the created context after + * it becomes current. + * Setting this flag on any context creation is equivalent to setting the + * ::CU_COREDUMP_ENABLE_USER_TRIGGER attribute to \p true globally. + * * Context creation will fail with ::CUDA_ERROR_UNKNOWN if the compute mode of * the device is ::CU_COMPUTEMODE_PROHIBITED. The function ::cuDeviceGetAttribute() * can be used with ::CU_DEVICE_ATTRIBUTE_COMPUTE_MODE to determine the @@ -4867,6 +5451,8 @@ CUresult CUDAAPI cuCtxCreate(CUcontext *pctx, unsigned int flags, CUdevice dev); * ::cuCtxSetCacheConfig, * ::cuCtxSetLimit, * ::cuCtxSynchronize, + * ::cuCoredumpSetAttributeGlobal, + * ::cuCoredumpSetAttribute, * ::CUexecAffinityParam */ CUresult CUDAAPI cuCtxCreate_v3(CUcontext *pctx, CUexecAffinityParam *paramsArray, int numParams, unsigned int flags, CUdevice dev); @@ -5091,10 +5677,69 @@ CUresult CUDAAPI cuCtxGetDevice(CUdevice *device); * ::cuCtxGetLimit, * ::cuCtxGetSharedMemConfig, * ::cuCtxGetStreamPriorityRange, + * ::cuCtxSetFlags, * ::cudaGetDeviceFlags */ CUresult CUDAAPI cuCtxGetFlags(unsigned int *flags); +/** + * \brief Sets the flags for the current context + * + * \param flags - Flags to set on the current context + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * \notefnerr + * + * \sa ::cuCtxCreate, + * ::cuCtxGetApiVersion, + * ::cuCtxGetCacheConfig, + * ::cuCtxGetCurrent, + * ::cuCtxGetDevice, + * ::cuCtxGetLimit, + * ::cuCtxGetSharedMemConfig, + * ::cuCtxGetStreamPriorityRange, + * ::cuCtxGetFlags, + * ::cudaGetDeviceFlags, + * ::cuDevicePrimaryCtxSetFlags, + */ +CUresult CUDAAPI cuCtxSetFlags(unsigned int flags); + +/** + * \brief Returns the unique Id associated with the context supplied + * + * Returns in \p ctxId the unique Id which is associated with a given context. + * The Id is unique for the life of the program for this instance of CUDA. + * If context is supplied as NULL and there is one current, the Id of the + * current context is returned. + * + * \param ctx - Context for which to obtain the Id + * \param ctxId - Pointer to store the Id of the context + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_CONTEXT_IS_DESTROYED, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa ::cuCtxCreate, + * ::cuCtxDestroy, + * ::cuCtxGetApiVersion, + * ::cuCtxGetCacheConfig, + * ::cuCtxGetDevice, + * ::cuCtxGetFlags, + * ::cuCtxGetLimit, + * ::cuCtxPushCurrent + */ +CUresult CUDAAPI cuCtxGetId(CUcontext ctx, unsigned long long *ctxId); + /** * \brief Block for a context's tasks to complete * @@ -5168,9 +5813,9 @@ CUresult CUDAAPI cuCtxSynchronize(void); * memory which can no longer be used for user allocations. If these * reservations of device memory fail, ::cuCtxSetLimit() will return * ::CUDA_ERROR_OUT_OF_MEMORY, and the limit can be reset to a lower value. - * This limit is only applicable to devices of compute capability 3.5 and - * higher. Attempting to set this limit on devices of compute capability less - * than 3.5 will result in the error ::CUDA_ERROR_UNSUPPORTED_LIMIT being + * This limit is only applicable to devices of compute capability < 9.0. + * Attempting to set this limit on devices of other compute capability + * versions will result in the error ::CUDA_ERROR_UNSUPPORTED_LIMIT being * returned. * * - ::CU_LIMIT_DEV_RUNTIME_PENDING_LAUNCH_COUNT controls the maximum number of @@ -5191,10 +5836,10 @@ CUresult CUDAAPI cuCtxSynchronize(void); * returned. * * - ::CU_LIMIT_MAX_L2_FETCH_GRANULARITY controls the L2 cache fetch granularity. - * Values can range from 0B to 128B. This is purely a performence hint and + * Values can range from 0B to 128B. This is purely a performance hint and * it can be ignored or clamped depending on the platform. * - * - ::CU_LIMIT_PERSISTING_L2_CACHE_SIZE controls size in bytes availabe for + * - ::CU_LIMIT_PERSISTING_L2_CACHE_SIZE controls size in bytes available for * persisting L2 cache. This is purely a performance hint and it can be * ignored or clamped depending on the platform. * @@ -5318,7 +5963,7 @@ CUresult CUDAAPI cuCtxGetCacheConfig(CUfunc_cache *pconfig); * the current context. This is only a preference. The driver will use * the requested configuration if possible, but it is free to choose a different * configuration if required to execute the function. Any function preference - * set via ::cuFuncSetCacheConfig() will be preferred over this context-wide + * set via ::cuFuncSetCacheConfig() or ::cuKernelSetCacheConfig() will be preferred over this context-wide * setting. Setting the context-wide cache configuration to * ::CU_FUNC_CACHE_PREFER_NONE will cause subsequent kernel launches to prefer * to not change the cache configuration unless required to launch the kernel. @@ -5357,7 +6002,8 @@ CUresult CUDAAPI cuCtxGetCacheConfig(CUfunc_cache *pconfig); * ::cuCtxSetLimit, * ::cuCtxSynchronize, * ::cuFuncSetCacheConfig, - * ::cudaDeviceSetCacheConfig + * ::cudaDeviceSetCacheConfig, + * ::cuKernelSetCacheConfig */ CUresult CUDAAPI cuCtxSetCacheConfig(CUfunc_cache config); @@ -5873,6 +6519,32 @@ CUresult CUDAAPI cuModuleLoadFatBinary(CUmodule *module, const void *fatCubin); */ CUresult CUDAAPI cuModuleUnload(CUmodule hmod); +/** + * CUDA Lazy Loading status + */ +typedef enum CUmoduleLoadingMode_enum { + CU_MODULE_EAGER_LOADING = 0x1, /**< Lazy Kernel Loading is not enabled */ + CU_MODULE_LAZY_LOADING = 0x2, /**< Lazy Kernel Loading is enabled */ +} CUmoduleLoadingMode; + +/** + * \brief Query lazy loading mode + * + * Returns lazy loading mode + * Module loading mode is controlled by CUDA_MODULE_LOADING env variable + * + * \param mode - Returns the lazy loading mode + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \notefnerr + * + * \sa + * ::cuModuleLoad, + */ +CUresult CUDAAPI cuModuleGetLoadingMode(CUmoduleLoadingMode *mode); + /** * \brief Returns a function handle * @@ -5908,9 +6580,9 @@ CUresult CUDAAPI cuModuleGetFunction(CUfunction *hfunc, CUmodule hmod, const cha * * Returns in \p *dptr and \p *bytes the base pointer and size of the * global of name \p name located in module \p hmod. If no variable of that name - * exists, ::cuModuleGetGlobal() returns ::CUDA_ERROR_NOT_FOUND. Both - * parameters \p dptr and \p bytes are optional. If one of them is - * NULL, it is ignored. + * exists, ::cuModuleGetGlobal() returns ::CUDA_ERROR_NOT_FOUND. + * One of the parameters \p dptr or \p bytes (not both) can be NULL in which + * case it is ignored. * * \param dptr - Returned global device pointer * \param bytes - Returned global size in bytes @@ -5939,77 +6611,11 @@ CUresult CUDAAPI cuModuleGetFunction(CUfunction *hfunc, CUmodule hmod, const cha CUresult CUDAAPI cuModuleGetGlobal(CUdeviceptr *dptr, size_t *bytes, CUmodule hmod, const char *name); /** - * \brief Returns a handle to a texture reference - * - * Returns in \p *pTexRef the handle of the texture reference of name \p name - * in the module \p hmod. If no texture reference of that name exists, - * ::cuModuleGetTexRef() returns ::CUDA_ERROR_NOT_FOUND. This texture reference - * handle should not be destroyed, since it will be destroyed when the module - * is unloaded. + * \brief Creates a pending JIT linker invocation. * - * \param pTexRef - Returned texture reference - * \param hmod - Module to retrieve texture reference from - * \param name - Name of texture reference to retrieve - * - * \return - * ::CUDA_SUCCESS, - * ::CUDA_ERROR_DEINITIALIZED, - * ::CUDA_ERROR_NOT_INITIALIZED, - * ::CUDA_ERROR_INVALID_CONTEXT, - * ::CUDA_ERROR_INVALID_VALUE, - * ::CUDA_ERROR_NOT_FOUND - * \notefnerr - * - * \sa ::cuModuleGetFunction, - * ::cuModuleGetGlobal, - * ::cuModuleGetSurfRef, - * ::cuModuleLoad, - * ::cuModuleLoadData, - * ::cuModuleLoadDataEx, - * ::cuModuleLoadFatBinary, - * ::cuModuleUnload, - * ::cudaGetTextureReference - */ -CUresult CUDAAPI cuModuleGetTexRef(CUtexref *pTexRef, CUmodule hmod, const char *name); - -/** - * \brief Returns a handle to a surface reference - * - * Returns in \p *pSurfRef the handle of the surface reference of name \p name - * in the module \p hmod. If no surface reference of that name exists, - * ::cuModuleGetSurfRef() returns ::CUDA_ERROR_NOT_FOUND. - * - * \param pSurfRef - Returned surface reference - * \param hmod - Module to retrieve surface reference from - * \param name - Name of surface reference to retrieve - * - * \return - * ::CUDA_SUCCESS, - * ::CUDA_ERROR_DEINITIALIZED, - * ::CUDA_ERROR_NOT_INITIALIZED, - * ::CUDA_ERROR_INVALID_CONTEXT, - * ::CUDA_ERROR_INVALID_VALUE, - * ::CUDA_ERROR_NOT_FOUND - * \notefnerr - * - * \sa ::cuModuleGetFunction, - * ::cuModuleGetGlobal, - * ::cuModuleGetTexRef, - * ::cuModuleLoad, - * ::cuModuleLoadData, - * ::cuModuleLoadDataEx, - * ::cuModuleLoadFatBinary, - * ::cuModuleUnload, - * ::cudaGetSurfaceReference - */ -CUresult CUDAAPI cuModuleGetSurfRef(CUsurfref *pSurfRef, CUmodule hmod, const char *name); - -/** - * \brief Creates a pending JIT linker invocation. - * - * If the call is successful, the caller owns the returned CUlinkState, which - * should eventually be destroyed with ::cuLinkDestroy. The - * device code machine size (32 or 64 bit) will match the calling application. + * If the call is successful, the caller owns the returned CUlinkState, which + * should eventually be destroyed with ::cuLinkDestroy. The + * device code machine size (32 or 64 bit) will match the calling application. * * Both linker and compiler options may be specified. Compiler options will * be applied to inputs to this linker action which must be compiled from PTX. @@ -6021,6 +6627,8 @@ CUresult CUDAAPI cuModuleGetSurfRef(CUsurfref *pSurfRef, CUmodule hmod, const ch * options are used. No other references to inputs are maintained after this * call returns. * + * \note For LTO-IR input, only LTO-IR compiled with toolkits prior to CUDA 12.0 will be accepted + * * \param numOptions Size of options arrays * \param options Array of linker and compiler options * \param optionValues Array of option values, each cast to void * @@ -6056,6 +6664,8 @@ cuLinkCreate(unsigned int numOptions, CUjit_option *options, void **optionValues * ::CU_JIT_WALL_TIME, ::CU_JIT_INFO_LOG_BUFFER, ::CU_JIT_ERROR_LOG_BUFFER, * ::CU_JIT_TARGET_FROM_CUCONTEXT, or ::CU_JIT_TARGET. * + * \note For LTO-IR input, only LTO-IR compiled with toolkits prior to CUDA 12.0 will be accepted + * * \param state A pending linker action. * \param type The type of the input data. * \param data The input data. PTX must be NULL-terminated. @@ -6097,6 +6707,8 @@ cuLinkAddData(CUlinkState state, CUjitInputType type, void *data, size_t size, c * This method is equivalent to invoking ::cuLinkAddData on the contents * of the file. * + * \note For LTO-IR input, only LTO-IR compiled with toolkits prior to CUDA 12.0 will be accepted + * * \param state A pending linker action * \param type The type of the input data * \param path Path to the input file @@ -6166,6 +6778,619 @@ cuLinkDestroy(CUlinkState state); /** @} */ /* END CUDA_MODULE */ +/** + * \defgroup CUDA_MODULE_DEPRECATED Module Management [DEPRECATED] + * + * ___MANBRIEF___ deprecated module management functions of the low-level CUDA + * driver API (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the deprecated module management functions of the low-level + * CUDA driver application programming interface. + * + * @{ + */ + +/** + * \brief Returns a handle to a texture reference + * + * \deprecated + * + * Returns in \p *pTexRef the handle of the texture reference of name \p name + * in the module \p hmod. If no texture reference of that name exists, + * ::cuModuleGetTexRef() returns ::CUDA_ERROR_NOT_FOUND. This texture reference + * handle should not be destroyed, since it will be destroyed when the module + * is unloaded. + * + * \param pTexRef - Returned texture reference + * \param hmod - Module to retrieve texture reference from + * \param name - Name of texture reference to retrieve + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_FOUND + * \notefnerr + * + * \sa + * ::cuModuleGetFunction, + * ::cuModuleGetGlobal, + * ::cuModuleGetSurfRef, + * ::cuModuleLoad, + * ::cuModuleLoadData, + * ::cuModuleLoadDataEx, + * ::cuModuleLoadFatBinary, + * ::cuModuleUnload + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuModuleGetTexRef(CUtexref *pTexRef, CUmodule hmod, const char *name); + +/** + * \brief Returns a handle to a surface reference + * + * \deprecated + * + * Returns in \p *pSurfRef the handle of the surface reference of name \p name + * in the module \p hmod. If no surface reference of that name exists, + * ::cuModuleGetSurfRef() returns ::CUDA_ERROR_NOT_FOUND. + * + * \param pSurfRef - Returned surface reference + * \param hmod - Module to retrieve surface reference from + * \param name - Name of surface reference to retrieve + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_FOUND + * \notefnerr + * + * \sa + * ::cuModuleGetFunction, + * ::cuModuleGetGlobal, + * ::cuModuleGetTexRef, + * ::cuModuleLoad, + * ::cuModuleLoadData, + * ::cuModuleLoadDataEx, + * ::cuModuleLoadFatBinary, + * ::cuModuleUnload + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuModuleGetSurfRef(CUsurfref *pSurfRef, CUmodule hmod, const char *name); + +/** @} */ /* END CUDA_MODULE_DEPRECATED */ + +/** + * \defgroup CUDA_LIBRARY Library Management + * + * ___MANBRIEF___ library management functions of the low-level CUDA driver API + * (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the library management functions of the low-level CUDA + * driver application programming interface. + * + * @{ + */ + +/** + * \brief Load a library with specified code and options + * + * Takes a pointer \p code and loads the corresponding library \p library into + * all contexts existent at the time of the call and future contexts at the time + * of creation until the library is unloaded with ::cuLibraryUnload(). + * + * The pointer may be obtained by mapping a \e cubin or \e PTX or \e fatbin file, + * passing a \e cubin or \e PTX or \e fatbin file as a NULL-terminated text string, or + * incorporating a \e cubin or \e fatbin object into the executable resources and + * using operating system calls such as Windows \c FindResource() to obtain the pointer. + * + * Options are passed as an array via \p jitOptions and any corresponding parameters are passed in + * \p jitOptionsValues. The number of total JIT options is supplied via \p numJitOptions. + * Any outputs will be returned via \p jitOptionsValues. + * + * Library load options are passed as an array via \p libraryOptions and any corresponding parameters are passed in + * \p libraryOptionValues. The number of total library load options is supplied via \p numLibraryOptions. + * + * \param library - Returned library + * \param code - Code to load + * \param jitOptions - Options for JIT + * \param jitOptionsValues - Option values for JIT + * \param numJitOptions - Number of options + * \param libraryOptions - Options for loading + * \param libraryOptionValues - Option values for loading + * \param numLibraryOptions - Number of options for loading + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_PTX, + * ::CUDA_ERROR_UNSUPPORTED_PTX_VERSION, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_NO_BINARY_FOR_GPU, + * ::CUDA_ERROR_SHARED_OBJECT_SYMBOL_NOT_FOUND, + * ::CUDA_ERROR_SHARED_OBJECT_INIT_FAILED, + * ::CUDA_ERROR_JIT_COMPILER_NOT_FOUND + * + * \sa ::cuLibraryLoadFromFile, + * ::cuLibraryUnload, + * ::cuModuleLoad, + * ::cuModuleLoadData, + * ::cuModuleLoadDataEx + */ +CUresult CUDAAPI cuLibraryLoadData(CUlibrary *library, const void *code, + CUjit_option *jitOptions, void **jitOptionsValues, unsigned int numJitOptions, + CUlibraryOption *libraryOptions, void** libraryOptionValues, unsigned int numLibraryOptions); + +/** + * \brief Load a library with specified file and options + * + * Takes a filename \p fileName and loads the corresponding library \p library into + * all contexts existent at the time of the call and future contexts at the time of + * creation until the library is unloaded with ::cuLibraryUnload(). + * + * The file should be a \e cubin file as output by \b nvcc, or a \e PTX file either + * as output by \b nvcc or handwritten, or a \e fatbin file as output by \b nvcc + * from toolchain 4.0 or later. + * + * Options are passed as an array via \p jitOptions and any corresponding parameters are + * passed in \p jitOptionsValues. The number of total options is supplied via \p numJitOptions. + * Any outputs will be returned via \p jitOptionsValues. + * + * Library load options are passed as an array via \p libraryOptions and any corresponding parameters are passed in + * \p libraryOptionValues. The number of total library load options is supplied via \p numLibraryOptions. + * + * \param library - Returned library + * \param fileName - File to load from + * \param jitOptions - Options for JIT + * \param jitOptionsValues - Option values for JIT + * \param numJitOptions - Number of options + * \param libraryOptions - Options for loading + * \param libraryOptionValues - Option values for loading + * \param numLibraryOptions - Number of options for loading + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_PTX, + * ::CUDA_ERROR_UNSUPPORTED_PTX_VERSION, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_NO_BINARY_FOR_GPU, + * ::CUDA_ERROR_SHARED_OBJECT_SYMBOL_NOT_FOUND, + * ::CUDA_ERROR_SHARED_OBJECT_INIT_FAILED, + * ::CUDA_ERROR_JIT_COMPILER_NOT_FOUND + * + * \sa ::cuLibraryLoadData, + * ::cuLibraryUnload, + * ::cuModuleLoad, + * ::cuModuleLoadData, + * ::cuModuleLoadDataEx + */ +CUresult CUDAAPI cuLibraryLoadFromFile(CUlibrary *library, const char *fileName, + CUjit_option *jitOptions, void **jitOptionsValues, unsigned int numJitOptions, + CUlibraryOption *libraryOptions, void **libraryOptionValues, unsigned int numLibraryOptions); + +/** + * \brief Unloads a library + * + * Unloads the library specified with \p library + * + * \param library - Library to unload + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuLibraryLoadData, + * ::cuLibraryLoadFromFile, + * ::cuModuleUnload + */ +CUresult CUDAAPI cuLibraryUnload(CUlibrary library); + +/** + * \brief Returns a kernel handle + * + * Returns in \p pKernel the handle of the kernel with name \p name located in library \p library. + * If kernel handle is not found, the call returns ::CUDA_ERROR_NOT_FOUND. + * + * \param pKernel - Returned kernel handle + * \param library - Library to retrieve kernel from + * \param name - Name of kernel to retrieve + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_NOT_FOUND, + * + * \sa ::cuLibraryLoadData, + * ::cuLibraryLoadFromFile, + * ::cuLibraryUnload, + * ::cuKernelGetFunction, + * ::cuLibraryGetModule, + * ::cuModuleGetFunction + */ +CUresult CUDAAPI cuLibraryGetKernel(CUkernel *pKernel, CUlibrary library, const char *name); + +/** + * \brief Returns a module handle + * + * Returns in \p pMod the module handle associated with the current context located in + * library \p library. If module handle is not found, the call returns ::CUDA_ERROR_NOT_FOUND. + * + * \param pMod - Returned module handle + * \param library - Library to retrieve module from + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_NOT_FOUND, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_CONTEXT_IS_DESTROYED + * + * \sa ::cuLibraryLoadData, + * ::cuLibraryLoadFromFile, + * ::cuLibraryUnload, + * ::cuModuleGetFunction + */ +CUresult CUDAAPI cuLibraryGetModule(CUmodule *pMod, CUlibrary library); + +/** + * \brief Returns a function handle + * + * Returns in \p pFunc the handle of the function for the requested kernel \p kernel and + * the current context. If function handle is not found, the call returns ::CUDA_ERROR_NOT_FOUND. + * + * \param pFunc - Returned function handle + * \param kernel - Kernel to retrieve function for the requested context + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_NOT_FOUND, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_CONTEXT_IS_DESTROYED + * + * \sa ::cuLibraryLoadData, + * ::cuLibraryLoadFromFile, + * ::cuLibraryUnload, + * ::cuLibraryGetKernel, + * ::cuLibraryGetModule, + * ::cuModuleGetFunction + */ +CUresult CUDAAPI cuKernelGetFunction(CUfunction *pFunc, CUkernel kernel); + +/** + * \brief Returns a global device pointer + * + * Returns in \p *dptr and \p *bytes the base pointer and size of the global with + * name \p name for the requested library \p library and the current context. + * If no global for the requested name \p name exists, the call returns ::CUDA_ERROR_NOT_FOUND. + * One of the parameters \p dptr or \p bytes (not both) can be NULL in which + * case it is ignored. + * + * \param dptr - Returned global device pointer for the requested context + * \param bytes - Returned global size in bytes + * \param library - Library to retrieve global from + * \param name - Name of global to retrieve + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_NOT_FOUND, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_CONTEXT_IS_DESTROYED + * + * \sa ::cuLibraryLoadData, + * ::cuLibraryLoadFromFile, + * ::cuLibraryUnload, + * ::cuLibraryGetModule, + * cuModuleGetGlobal + */ +CUresult CUDAAPI cuLibraryGetGlobal(CUdeviceptr *dptr, size_t *bytes, CUlibrary library, const char *name); + +/** + * \brief Returns a pointer to managed memory + * + * Returns in \p *dptr and \p *bytes the base pointer and size of the managed memory with + * name \p name for the requested library \p library. If no managed memory with the + * requested name \p name exists, the call returns ::CUDA_ERROR_NOT_FOUND. One of the parameters + * \p dptr or \p bytes (not both) can be NULL in which case it is ignored. + * Note that managed memory for library \p library is shared across devices and is registered + * when the library is loaded into atleast one context. + * + * \note The API requires a CUDA context to be present and initialized on at least one device. + * If no context is present, the call returns ::CUDA_ERROR_NOT_FOUND. + * + * \param dptr - Returned pointer to the managed memory + * \param bytes - Returned memory size in bytes + * \param library - Library to retrieve managed memory from + * \param name - Name of managed memory to retrieve + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_NOT_FOUND, + * + * \sa ::cuLibraryLoadData, + * ::cuLibraryLoadFromFile, + * ::cuLibraryUnload, + */ +CUresult CUDAAPI cuLibraryGetManaged(CUdeviceptr *dptr, size_t *bytes, CUlibrary library, const char *name); + +/** + * \brief Returns a pointer to a unified function + * + * Returns in \p *fptr the function pointer to a unified function denoted by \p symbol. + * If no unified function with name \p symbol exists, the call returns ::CUDA_ERROR_NOT_FOUND. + * If there is no device with attribute ::CU_DEVICE_ATTRIBUTE_UNIFIED_FUNCTION_POINTERS present in the system, + * the call may return ::CUDA_ERROR_NOT_FOUND. + * + * \param fptr - Returned pointer to a unified function + * \param library - Library to retrieve function pointer memory from + * \param symbol - Name of function pointer to retrieve + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_NOT_FOUND, + * + * \sa ::cuLibraryLoadData, + * ::cuLibraryLoadFromFile, + * ::cuLibraryUnload, + */ +CUresult CUDAAPI cuLibraryGetUnifiedFunction(void **fptr, CUlibrary library, const char *symbol); + +/** + * \brief Returns information about a kernel + * + * Returns in \p *pi the integer value of the attribute \p attrib for the kernel + * \p kernel for the requested device \p dev. The supported attributes are: + * - ::CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK: The maximum number of threads + * per block, beyond which a launch of the kernel would fail. This number + * depends on both the kernel and the requested device. + * - ::CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES: The size in bytes of + * statically-allocated shared memory per block required by this kernel. + * This does not include dynamically-allocated shared memory requested by + * the user at runtime. + * - ::CU_FUNC_ATTRIBUTE_CONST_SIZE_BYTES: The size in bytes of user-allocated + * constant memory required by this kernel. + * - ::CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES: The size in bytes of local memory + * used by each thread of this kernel. + * - ::CU_FUNC_ATTRIBUTE_NUM_REGS: The number of registers used by each thread + * of this kernel. + * - ::CU_FUNC_ATTRIBUTE_PTX_VERSION: The PTX virtual architecture version for + * which the kernel was compiled. This value is the major PTX version * 10 + * + the minor PTX version, so a PTX version 1.3 function would return the + * value 13. Note that this may return the undefined value of 0 for cubins + * compiled prior to CUDA 3.0. + * - ::CU_FUNC_ATTRIBUTE_BINARY_VERSION: The binary architecture version for + * which the kernel was compiled. This value is the major binary + * version * 10 + the minor binary version, so a binary version 1.3 function + * would return the value 13. Note that this will return a value of 10 for + * legacy cubins that do not have a properly-encoded binary architecture + * version. + * - ::CU_FUNC_CACHE_MODE_CA: The attribute to indicate whether the kernel has + * been compiled with user specified option "-Xptxas --dlcm=ca" set. + * - ::CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES: The maximum size in bytes of + * dynamically-allocated shared memory. + * - ::CU_FUNC_ATTRIBUTE_PREFERRED_SHARED_MEMORY_CARVEOUT: Preferred shared memory-L1 + * cache split ratio in percent of total shared memory. + * - ::CU_FUNC_ATTRIBUTE_CLUSTER_SIZE_MUST_BE_SET: If this attribute is set, the + * kernel must launch with a valid cluster size specified. + * - ::CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_WIDTH: The required cluster width in + * blocks. + * - ::CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_HEIGHT: The required cluster height in + * blocks. + * - ::CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_DEPTH: The required cluster depth in + * blocks. + * - ::CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED: Indicates whether + * the function can be launched with non-portable cluster size. 1 is allowed, + * 0 is disallowed. A non-portable cluster size may only function on the + * specific SKUs the program is tested on. The launch might fail if the + * program is run on a different hardware platform. CUDA API provides + * cudaOccupancyMaxActiveClusters to assist with checking whether the desired + * size can be launched on the current device. A portable cluster size is + * guaranteed to be functional on all compute capabilities higher than the + * target compute capability. The portable cluster size for sm_90 is 8 blocks + * per cluster. This value may increase for future compute capabilities. The + * specific hardware unit may support higher cluster sizes that’s not + * guaranteed to be portable. + * - ::CU_FUNC_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE: The block + * scheduling policy of a function. The value type is CUclusterSchedulingPolicy. + * + * \note If another thread is trying to set the same attribute on the same device using + * ::cuKernelSetAttribute() simultaneously, the attribute query will give the old or new + * value depending on the interleavings chosen by the OS scheduler and memory consistency. + * + * \param pi - Returned attribute value + * \param attrib - Attribute requested + * \param kernel - Kernel to query attribute of + * \param dev - Device to query attribute of + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * + * \sa ::cuLibraryLoadData, + * ::cuLibraryLoadFromFile, + * ::cuLibraryUnload, + * ::cuKernelSetAttribute, + * ::cuLibraryGetKernel, + * ::cuLaunchKernel, + * ::cuKernelGetFunction, + * ::cuLibraryGetModule, + * ::cuModuleGetFunction, + * ::cuFuncGetAttribute + */ +CUresult CUDAAPI cuKernelGetAttribute(int *pi, CUfunction_attribute attrib, CUkernel kernel, CUdevice dev); + +/** + * \brief Sets information about a kernel + * + * This call sets the value of a specified attribute \p attrib on the kernel \p kernel + * for the requested device \p dev to an integer value specified by \p val. + * This function returns CUDA_SUCCESS if the new value of the attribute could be + * successfully set. If the set fails, this call will return an error. + * Not all attributes can have values set. Attempting to set a value on a read-only + * attribute will result in an error (CUDA_ERROR_INVALID_VALUE) + * + * Note that attributes set using ::cuFuncSetAttribute() will override the attribute + * set by this API irrespective of whether the call to ::cuFuncSetAttribute() is made + * before or after this API call. However, ::cuKernelGetAttribute() will always + * return the attribute value set by this API. + * + * Supported attributes are: + * - ::CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES: This is the maximum size in bytes of + * dynamically-allocated shared memory. The value should contain the requested + * maximum size of dynamically-allocated shared memory. The sum of this value and + * the function attribute ::CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES cannot exceed the + * device attribute ::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN. + * The maximal size of requestable dynamic shared memory may differ by GPU + * architecture. + * - ::CU_FUNC_ATTRIBUTE_PREFERRED_SHARED_MEMORY_CARVEOUT: On devices where the L1 + * cache and shared memory use the same hardware resources, this sets the shared memory + * carveout preference, in percent of the total shared memory. + * See ::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR + * This is only a hint, and the driver can choose a different ratio if required to execute the function. + * - ::CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_WIDTH: The required cluster width in + * blocks. The width, height, and depth values must either all be 0 or all be + * positive. The validity of the cluster dimensions is checked at launch time. + * If the value is set during compile time, it cannot be set at runtime. + * Setting it at runtime will return CUDA_ERROR_NOT_PERMITTED. + * - ::CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_HEIGHT: The required cluster height in + * blocks. The width, height, and depth values must either all be 0 or all be + * positive. The validity of the cluster dimensions is checked at launch time. + * If the value is set during compile time, it cannot be set at runtime. + * Setting it at runtime will return CUDA_ERROR_NOT_PERMITTED. + * - ::CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_DEPTH: The required cluster depth in + * blocks. The width, height, and depth values must either all be 0 or all be + * positive. The validity of the cluster dimensions is checked at launch time. + * If the value is set during compile time, it cannot be set at runtime. + * Setting it at runtime will return CUDA_ERROR_NOT_PERMITTED. + * - ::CU_FUNC_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE: The block + * scheduling policy of a function. The value type is CUclusterSchedulingPolicy. + * + * \note The API has stricter locking requirements in comparison to its legacy counterpart + * ::cuFuncSetAttribute() due to device-wide semantics. If multiple threads are trying to + * set the same attribute on the same device simultaneously, the attribute setting will depend + * on the interleavings chosen by the OS scheduler and memory consistency. + * + * \param attrib - Attribute requested + * \param val - Value to set + * \param kernel - Kernel to set attribute of + * \param dev - Device to set attribute of + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * + * \sa ::cuLibraryLoadData, + * ::cuLibraryLoadFromFile, + * ::cuLibraryUnload, + * ::cuKernelGetAttribute, + * ::cuLibraryGetKernel, + * ::cuLaunchKernel, + * ::cuKernelGetFunction, + * ::cuLibraryGetModule, + * ::cuModuleGetFunction, + * ::cuFuncSetAttribute + */ +CUresult CUDAAPI cuKernelSetAttribute(CUfunction_attribute attrib, int val, CUkernel kernel, CUdevice dev); + +/** + * \brief Sets the preferred cache configuration for a device kernel. + * + * On devices where the L1 cache and shared memory use the same hardware + * resources, this sets through \p config the preferred cache configuration for + * the device kernel \p kernel on the requested device \p dev. This is only a preference. + * The driver will use the requested configuration if possible, but it is free to choose a different + * configuration if required to execute \p kernel. Any context-wide preference + * set via ::cuCtxSetCacheConfig() will be overridden by this per-kernel + * setting. + * + * Note that attributes set using ::cuFuncSetCacheConfig() will override the attribute + * set by this API irrespective of whether the call to ::cuFuncSetCacheConfig() is made + * before or after this API call. + * + * This setting does nothing on devices where the size of the L1 cache and + * shared memory are fixed. + * + * Launching a kernel with a different preference than the most recent + * preference setting may insert a device-side synchronization point. + * + * + * The supported cache configurations are: + * - ::CU_FUNC_CACHE_PREFER_NONE: no preference for shared memory or L1 (default) + * - ::CU_FUNC_CACHE_PREFER_SHARED: prefer larger shared memory and smaller L1 cache + * - ::CU_FUNC_CACHE_PREFER_L1: prefer larger L1 cache and smaller shared memory + * - ::CU_FUNC_CACHE_PREFER_EQUAL: prefer equal sized L1 cache and shared memory + * + * \note The API has stricter locking requirements in comparison to its legacy counterpart + * ::cuFuncSetCacheConfig() due to device-wide semantics. If multiple threads are trying to + * set a config on the same device simultaneously, the cache config setting will depend + * on the interleavings chosen by the OS scheduler and memory consistency. + * + * \param kernel - Kernel to configure cache for + * \param config - Requested cache configuration + * \param dev - Device to set attribute of + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * + * \sa ::cuLibraryLoadData, + * ::cuLibraryLoadFromFile, + * ::cuLibraryUnload, + * ::cuLibraryGetKernel, + * ::cuKernelGetFunction, + * ::cuLibraryGetModule, + * ::cuModuleGetFunction, + * ::cuFuncSetCacheConfig, + * ::cuCtxSetCacheConfig, + * ::cuLaunchKernel + */ +CUresult CUDAAPI cuKernelSetCacheConfig(CUkernel kernel, CUfunc_cache config, CUdevice dev); + +/** @} */ /* END CUDA_LIBRARY */ /** * \defgroup CUDA_MEM Memory Management @@ -6322,8 +7547,14 @@ CUresult CUDAAPI cuMemAllocPitch(CUdeviceptr *dptr, size_t *pPitch, size_t Width * \brief Frees device memory * * Frees the memory space pointed to by \p dptr, which must have been returned - * by a previous call to ::cuMemAlloc() or ::cuMemAllocPitch(). + * by a previous call to one of the following memory allocation APIs - ::cuMemAlloc(), + * ::cuMemAllocPitch(), ::cuMemAllocManaged(), ::cuMemAllocAsync(), ::cuMemAllocFromPoolAsync() * + * Note - This API will not perform any implict synchronization when the pointer was allocated with + * ::cuMemAllocAsync or ::cuMemAllocFromPoolAsync. Callers must ensure that all accesses to the + * pointer have completed before invoking ::cuMemFree. For best performance and memory reuse, users + * should use ::cuMemFreeAsync to free memory allocated via the stream ordered memory allocator. + * * \param dptr - Pointer to memory to free * * \return @@ -6336,12 +7567,12 @@ CUresult CUDAAPI cuMemAllocPitch(CUdeviceptr *dptr, size_t *pPitch, size_t Width * * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, - * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, - * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, - * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, - * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, - * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFreeHost, - * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemAllocPitch, ::cuMemAllocManaged, ::cuMemAllocAsync, ::cuMemAllocFromPoolAsync, + * ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, ::cuMemcpy3D, ::cuMemcpy3DAsync, + * ::cuMemcpyAtoA, ::cuMemcpyAtoD, ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, + * ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, + * ::cuMemcpyHtoAAsync, ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, ::cuMemFreeAsync, * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, * ::cudaFree @@ -6640,7 +7871,7 @@ CUresult CUDAAPI cuMemHostGetFlags(unsigned int *pFlags, void *p); * ::cuStreamAttachMemAsync will be required to enable access on such devices. * * If the association is later changed via ::cuStreamAttachMemAsync to - * a single stream, the default association as specifed during ::cuMemAllocManaged + * a single stream, the default association as specified during ::cuMemAllocManaged * is restored when that stream is destroyed. For __managed__ variables, the * default association is always ::CU_MEM_ATTACH_GLOBAL. Note that destroying a * stream is an asynchronous operation, and as a result, the change to default @@ -6807,6 +8038,8 @@ CUresult CUDAAPI cuDeviceGetPCIBusId(char *pciBusId, int len, CUdevice dev); * IPC functionality is restricted to devices with support for unified * addressing on Linux and Windows operating systems. * IPC functionality on Windows is restricted to GPUs in TCC mode + * Users can test their device for IPC functionality by calling + * ::cuapiDeviceGetAttribute with ::CU_DEVICE_ATTRIBUTE_IPC_EVENT_SUPPORTED * * \param pHandle - Pointer to a user allocated CUipcEventHandle * in which to return the opaque event handle @@ -6848,6 +8081,8 @@ CUresult CUDAAPI cuIpcGetEventHandle(CUipcEventHandle *pHandle, CUevent event); * IPC functionality is restricted to devices with support for unified * addressing on Linux and Windows operating systems. * IPC functionality on Windows is restricted to GPUs in TCC mode + * Users can test their device for IPC functionality by calling + * ::cuapiDeviceGetAttribute with ::CU_DEVICE_ATTRIBUTE_IPC_EVENT_SUPPORTED * * \param phEvent - Returns the imported event * \param handle - Interprocess handle to open @@ -6891,6 +8126,8 @@ CUresult CUDAAPI cuIpcOpenEventHandle(CUevent *phEvent, CUipcEventHandle handle) * IPC functionality is restricted to devices with support for unified * addressing on Linux and Windows operating systems. * IPC functionality on Windows is restricted to GPUs in TCC mode + * Users can test their device for IPC functionality by calling + * ::cuapiDeviceGetAttribute with ::CU_DEVICE_ATTRIBUTE_IPC_EVENT_SUPPORTED * * \param pHandle - Pointer to user allocated ::CUipcMemHandle to return * the handle in. @@ -6944,6 +8181,8 @@ CUresult CUDAAPI cuIpcGetMemHandle(CUipcMemHandle *pHandle, CUdeviceptr dptr); * IPC functionality is restricted to devices with support for unified * addressing on Linux and Windows operating systems. * IPC functionality on Windows is restricted to GPUs in TCC mode + * Users can test their device for IPC functionality by calling + * ::cuapiDeviceGetAttribute with ::CU_DEVICE_ATTRIBUTE_IPC_EVENT_SUPPORTED * * \param pdptr - Returned device pointer * \param handle - ::CUipcMemHandle to open @@ -6987,6 +8226,8 @@ CUresult CUDAAPI cuIpcOpenMemHandle(CUdeviceptr *pdptr, CUipcMemHandle handle, u * IPC functionality is restricted to devices with support for unified * addressing on Linux and Windows operating systems. * IPC functionality on Windows is restricted to GPUs in TCC mode + * Users can test their device for IPC functionality by calling + * ::cuapiDeviceGetAttribute with ::CU_DEVICE_ATTRIBUTE_IPC_EVENT_SUPPORTED * * \param dptr - Device pointer returned by ::cuIpcOpenMemHandle * @@ -7021,7 +8262,9 @@ CUresult CUDAAPI cuIpcCloseMemHandle(CUdeviceptr dptr); * best used sparingly to register staging areas for data exchange between * host and device. * - * This function has limited support on Mac OS X. OS 10.7 or higher is required. + * On systems where ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES + * is true, ::cuMemHostRegister will not page-lock the memory range specified + * by \p ptr but only populate unpopulated pages. * * The \p Flags parameter enables different options to be specified that * affect the allocation, as follows. @@ -9313,7 +10556,6 @@ CUresult CUDAAPI cuArrayGetSparseProperties(CUDA_ARRAY_SPARSE_PROPERTIES *sparse */ CUresult CUDAAPI cuMipmappedArrayGetSparseProperties(CUDA_ARRAY_SPARSE_PROPERTIES *sparseProperties, CUmipmappedArray mipmap); - /** * \brief Returns the memory requirements of a CUDA array * @@ -9361,7 +10603,6 @@ CUresult CUDAAPI cuArrayGetMemoryRequirements(CUDA_ARRAY_MEMORY_REQUIREMENTS *me */ CUresult CUDAAPI cuMipmappedArrayGetMemoryRequirements(CUDA_ARRAY_MEMORY_REQUIREMENTS *memoryRequirements, CUmipmappedArray mipmap, CUdevice device); - /** * \brief Gets a CUDA array plane from a CUDA array * @@ -9391,7 +10632,7 @@ CUresult CUDAAPI cuMipmappedArrayGetMemoryRequirements(CUDA_ARRAY_MEMORY_REQUIRE * * \sa * ::cuArrayCreate, - * ::cudaGetArrayPlane + * ::cudaArrayGetPlane */ CUresult CUDAAPI cuArrayGetPlane(CUarray *pPlaneArray, CUarray hArray, unsigned int planeIdx); @@ -9841,6 +11082,39 @@ CUresult CUDAAPI cuMipmappedArrayGetLevel(CUarray *pLevelArray, CUmipmappedArray */ CUresult CUDAAPI cuMipmappedArrayDestroy(CUmipmappedArray hMipmappedArray); +/** +* \brief Retrieve handle for an address range +* +* Get a handle of the specified type to an address range. The address range +* must have been obtained by a prior call to either ::cuMemAlloc or ::cuMemAddressReserve. +* If the address range was obtained via ::cuMemAddressReserve, it must also be fully mapped via ::cuMemMap. +* +* Users must ensure the \p dptr and \p size are aligned to the host page size. +* +* When requesting CUmemRangeHandleType::CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD, +* users are expected to query for dma_buf support for the platform +* by using ::CU_DEVICE_ATTRIBUTE_DMA_BUF_SUPPORTED device attribute before calling +* this API. The \p handle will be interpreted as a pointer to an integer to store the dma_buf file descriptor. +* Users must ensure the entire address range is backed and mapped when +* the address range is allocated by ::cuMemAddressReserve. All the physical +* allocations backing the address range must be resident on the same device and +* have identical allocation properties. Users are also expected to retrieve a +* new handle every time the underlying physical allocation(s) corresponding +* to a previously queried VA range are changed. +* +* \param[out] handle - Pointer to the location where the returned handle will be stored. +* \param[in] dptr - Pointer to a valid CUDA device allocation. Must be aligned to host page size. +* \param[in] size - Length of the address range. Must be aligned to host page size. +* \param[in] handleType - Type of handle requested (defines type and size of the \p handle output parameter) +* \param[in] flags - Reserved, must be zero +* +* \return +* CUDA_SUCCESS +* CUDA_ERROR_INVALID_VALUE +* CUDA_ERROR_NOT_SUPPORTED +*/ +CUresult CUDAAPI cuMemGetHandleForAddressRange(void *handle, CUdeviceptr dptr, size_t size, CUmemRangeHandleType handleType, unsigned long long flags); + /** @} */ /* END CUDA_MEM */ /** @@ -9907,13 +11181,13 @@ CUresult CUDAAPI cuMemAddressFree(CUdeviceptr ptr, size_t size); * \brief Create a CUDA memory handle representing a memory allocation of a given size described by the given properties * * This creates a memory allocation on the target device specified through the -* \p prop strcuture. The created allocation will not have any device or host +* \p prop structure. The created allocation will not have any device or host * mappings. The generic memory \p handle for the allocation can be * mapped to the address space of calling process via ::cuMemMap. This handle * cannot be transmitted directly to other processes (see * ::cuMemExportToShareableHandle). On Windows, the caller must also pass * an LPSECURITYATTRIBUTE in \p prop to be associated with this handle which -* limits or allows access to this handle for a recepient process (see +* limits or allows access to this handle for a recipient process (see * ::CUmemAllocationProp::win32HandleMetaData for more). The \p size of this * allocation must be a multiple of the the value given via * ::cuMemGetAllocationGranularity with the ::CU_MEM_ALLOC_GRANULARITY_MINIMUM @@ -9951,7 +11225,7 @@ CUresult CUDAAPI cuMemCreate(CUmemGenericAllocationHandle *handle, size_t size, * are unmapped and when all outstanding references to the handle (including it's * shareable counterparts) are also released. The generic memory handle can be * freed when there are still outstanding mappings made with this handle. Each -* time a recepient process imports a shareable handle, it needs to pair it with +* time a recipient process imports a shareable handle, it needs to pair it with * ::cuMemRelease for the handle to be freed. If \p handle is not a valid handle * the behavior is undefined. * @@ -9978,6 +11252,12 @@ CUresult CUDAAPI cuMemRelease(CUmemGenericAllocationHandle handle); * \p offset + \p size must be less than the size of the memory allocation. * Both \p ptr, \p size, and \p offset must be a multiple of the value given via * ::cuMemGetAllocationGranularity with the ::CU_MEM_ALLOC_GRANULARITY_MINIMUM flag. +* If \p handle represents a multicast object, \p ptr, \p size and \p offset must +* be aligned to the value returned by ::cuMulticastGetGranularity with the flag +* ::CU_MULTICAST_MINIMUM_GRANULARITY. For best performance however, it is +* recommended that \p ptr, \p size and \p offset be aligned to the value +* returned by ::cuMulticastGetGranularity with the flag +* ::CU_MULTICAST_RECOMMENDED_GRANULARITY. * * Please note calling ::cuMemMap does not make the address accessible, * the caller needs to update accessibility of a contiguous mapped VA @@ -10065,17 +11345,13 @@ CUresult CUDAAPI cuMemMap(CUdeviceptr ptr, size_t size, size_t offset, CUmemGene * ::CUarrayMapInfo::resource::array must be set to a valid sparse CUDA array handle. * The CUDA array must be either a 2D, 2D layered or 3D CUDA array and must have been allocated using * ::cuArrayCreate or ::cuArray3DCreate with the flag ::CUDA_ARRAY3D_SPARSE - * or ::CUDA_ARRAY3D_DEFERRED_MAPPING. - * For CUDA arrays obtained using ::cuMipmappedArrayGetLevel, ::CUDA_ERROR_INVALID_VALUE will be returned. * If ::CUarrayMapInfo::resourceType is set to ::CUresourcetype::CU_RESOURCE_TYPE_MIPMAPPED_ARRAY * then ::CUarrayMapInfo::resource::mipmap must be set to a valid sparse CUDA mipmapped array handle. * The CUDA mipmapped array must be either a 2D, 2D layered or 3D CUDA mipmapped array and must have been * allocated using ::cuMipmappedArrayCreate with the flag ::CUDA_ARRAY3D_SPARSE - * or ::CUDA_ARRAY3D_DEFERRED_MAPPING. - * * ::CUarrayMapInfo::subresourceType specifies the type of subresource within the resource. * ::CUarraySparseSubresourceType_enum is defined as: @@ -10114,11 +11390,9 @@ CUresult CUDAAPI cuMemMap(CUdeviceptr ptr, size_t size, size_t offset, CUmemGene * as returned by ::cuMipmappedArrayGetSparseProperties, ::CUarrayMapInfo::subresource::miptail::layer must specify a valid layer index. * Otherwise, must be zero. * - * If ::CUarrayMapInfo::resource::array or ::CUarrayMapInfo::resource::mipmap was created with ::CUDA_ARRAY3D_DEFERRED_MAPPING * flag set the ::CUarrayMapInfo::subresourceType and the contents of ::CUarrayMapInfo::subresource will be ignored. * - * ::CUarrayMapInfo::memOperationType specifies the type of operation. ::CUmemOperationType is defined as: \code typedef enum CUmemOperationType_enum { @@ -10190,6 +11464,12 @@ CUresult CUDAAPI cuMemUnmap(CUdeviceptr ptr, size_t size); * in the array given by \p desc and \p count, set the access flags for the * target locations. The range must be a fully mapped address range * containing all allocations created by ::cuMemMap / ::cuMemCreate. +* When setting the access flags for a virtual address range mapping a multicast +* object, \p ptr and \p size must be aligned to the value returned by +* ::cuMulticastGetGranularity with the flag ::CU_MULTICAST_MINIMUM_GRANULARITY. +* For best performance however, it is recommended that \p ptr and \p size be +* aligned to the value returned by ::cuMulticastGetGranularity with the flag +* ::CU_MULTICAST_RECOMMENDED_GRANULARITY. * * \param[in] ptr - Starting address for the virtual address range * \param[in] size - Length of the virtual address range @@ -10778,6 +12058,264 @@ CUresult CUDAAPI cuMemPoolImportPointer(CUdeviceptr *ptr_out, CUmemoryPool pool, /** @} */ /* END CUDA_MALLOC_ASYNC */ +/** + * \defgroup CUDA_MULTICAST Multicast Object Management + * + * ___MANBRIEF___ Functions for creating multicast objects, adding devices to them and binding/unbinding memory + * (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the CUDA multicast object operations exposed by the + * low-level CUDA driver application programming interface. + * + * @{ + * + * \section CUDA_MULTICAST_overview overview + * + * A multicast object created via ::cuMulticastCreate enables certain memory + * operations to be broadcasted to a team of devices. Devices can be added to a + * multicast object via ::cuMulticastAddDevice. Memory can be bound on each + * participating device via either ::cuMulticastBindMem or ::cuMulticastBindAddr. + * Multicast objects can be mapped into a device's virtual address space using + * the virtual memmory management APIs (see ::cuMemMap and ::cuMemSetAccess). + * + * \section CUDA_MULTICAST_support Supported Platforms + * + * Support for multicast on a specific device can be queried using the device + * attribute ::CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED + */ + +/** + * \brief Create a generic allocation handle representing a multicast object described by the given properties. + * + * This creates a multicast object as described by \p prop. The number of + * participating devices is specified by ::CUmulticastObjectProp::numDevices. + * Devices can be added to the multicast object via ::cuMulticastAddDevice. + * All participating devices must be added to the multicast object before memory + * can be bound to it. Memory is bound to the multicast object via either + * ::cuMulticastBindMem or ::cuMulticastBindAddr, and can be unbound via + * ::cuMulticastUnbind. The total amount of memory that can be bound per device + * is specified by :CUmulticastObjectProp::size. This size must be a multiple of + * the value returned by ::cuMulticastGetGranularity with the flag + * ::CU_MULTICAST_GRANULARITY_MINIMUM. For best performance however, the size + * should be aligned to the value returned by ::cuMulticastGetGranularity with + * the flag ::CU_MULTICAST_GRANULARITY_RECOMMENDED. + * + * After all participating devices have been added, multicast objects can also + * be mapped to a device's virtual address space using the virtual memory + * management APIs (see ::cuMemMap and ::cuMemSetAccess). Multicast objects can + * also be shared with other processes by requesting a shareable handle via + * ::cuMemExportToShareableHandle. Note that the desired types of shareable + * handles must be specified in the bitmask ::CUmulticastObjectProp::handleTypes. + * Multicast objects can be released using the virtual memory management API + * ::cuMemRelease. + * + * \param[out] mcHandle Value of handle returned. + * \param[in] prop Properties of the multicast object to create. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_PERMITTED, + * ::CUDA_ERROR_NOT_SUPPORTED + * + * \sa ::cuMulticastAddDevice, ::cuMulticastBindMem, ::cuMulticastBindAddr, ::cuMulticastUnbind + * \sa ::cuMemCreate, ::cuMemRelease, ::cuMemExportToShareableHandle, ::cuMemImportFromShareableHandle + */ +CUresult CUDAAPI cuMulticastCreate(CUmemGenericAllocationHandle *mcHandle, const CUmulticastObjectProp *prop); + +/** + * \brief Associate a device to a multicast object. + * + * Associates a device to a multicast object. The added device will be a part of + * the multicast team of size specified by CUmulticastObjectProp::numDevices + * during ::cuMulticastCreate. + * The association of the device to the multicast object is permanent during + * the life time of the multicast object. + * All devices must be added to the multicast team before any memory can be + * bound to any device in the team. Any calls to ::cuMulticastBindMem or + * ::cuMulticastBindAddr will block until all devices have been added. + * Similarly all devices must be added to the multicast team before a virtual + * address range can be mapped to the multicast object. A call to ::cuMemMap + * will block until all devices have been added. + * + * \param[in] mcHandle Handle representing a multicast object. + * \param[in] dev Device that will be associated to the multicast + * object. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_PERMITTED, + * ::CUDA_ERROR_NOT_SUPPORTED + * + * \sa ::cuMulticastCreate, ::cuMulticastBindMem, ::cuMulticastBindAddr + */ +CUresult CUDAAPI cuMulticastAddDevice(CUmemGenericAllocationHandle mcHandle, CUdevice dev); + +/** + * \brief Bind a memory allocation represented by a handle to a multicast object. + * + * Binds a memory allocation specified by \p memHandle and created via + * ::cuMemCreate to a multicast object represented by \p mcHandle and created + * via ::cuMulticastCreate. The intended \p size of the bind, the offset in the + * multicast range \p mcOffset as well as the offset in the memory \p memOffset + * must be a multiple of the value returned by ::cuMulticastGetGranularity with + * the flag ::CU_MULTICAST_GRANULARITY_MINIMUM. For best performance however, + * \p size, \p mcOffset and \p memOffset should be aligned to the granularity of + * the memory allocation(see ::cuMemGetAllocationGranularity) or to the value + * returned by ::cuMulticastGetGranularity with the flag + * ::CU_MULTICAST_GRANULARITY_RECOMMENDED. + * + * The \p size + \p memOffset must be smaller than the size of the allocated + * memory. Similarly the \p size + \p mcOffset must be smaller than the size + * of the multicast object. + * The memory allocation must have beeen created on one of the devices + * that was added to the multicast team via ::cuMulticastAddDevice. + * Externally shareable as well as imported multicast objects can be bound only + * to externally shareable memory. + * Note that this call will return CUDA_ERROR_OUT_OF_MEMORY if there are + * insufficient resources required to perform the bind. This call may also + * return CUDA_ERROR_SYSTEM_NOT_READY if the necessary system software is not + * initialized or running. + * + * \param[in] mcHandle Handle representing a multicast object. + * \param[in] mcOffset Offset into the multicast object for attachment. + * \param[in] memHandle Handle representing a memory allocation. + * \param[in] memOffset Offset into the memory for attachment. + * \param[in] size Size of the memory that will be bound to the + * multicast object. + * \param[in] flags Flags for future use, must be zero for now. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_PERMITTED, + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_SYSTEM_NOT_READY + * + * \sa ::cuMulticastCreate, ::cuMulticastAddDevice, ::cuMemCreate + */ +CUresult CUDAAPI cuMulticastBindMem(CUmemGenericAllocationHandle mcHandle, size_t mcOffset, CUmemGenericAllocationHandle memHandle, size_t memOffset, size_t size, unsigned long long flags); + +/** + * \brief Bind a memory allocation represented by a virtual address to a multicast object. + * + * Binds a memory allocation specified by its mapped address \p memptr to a + * multicast object represented by \p mcHandle. + * The memory must have been allocated via ::cuMemCreate or ::cudaMallocAsync. + * The intended \p size of the bind, the offset in the multicast range + * \p mcOffset and \p memptr must be a multiple of the value returned by + * ::cuMulticastGetGranularity with the flag ::CU_MULTICAST_GRANULARITY_MINIMUM. + * For best performance however, \p size, \p mcOffset and \p memptr should be + * aligned to the value returned by ::cuMulticastGetGranularity with the flag + * ::CU_MULTICAST_GRANULARITY_RECOMMENDED. + * + * The \p size must be smaller than the size of the allocated memory. + * Similarly the \p size + \p mcOffset must be smaller than the total size + * of the multicast object. + * The memory allocation must have beeen created on one of the devices + * that was added to the multicast team via ::cuMulticastAddDevice. + * Externally shareable as well as imported multicast objects can be bound only + * to externally shareable memory. + * Note that this call will return CUDA_ERROR_OUT_OF_MEMORY if there are + * insufficient resources required to perform the bind. This call may also + * return CUDA_ERROR_SYSTEM_NOT_READY if the necessary system software is not + * initialized or running. + * + * \param[in] mcHandle Handle representing a multicast object. + * \param[in] mcOffset Offset into multicast va range for attachment. + * \param[in] memptr Virtual address of the memory allocation. + * \param[in] size Size of memory that will be bound to the + * multicast object. + * \param[in] flags Flags for future use, must be zero now. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_PERMITTED, + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_SYSTEM_NOT_READY + * + * \sa ::cuMulticastCreate, ::cuMulticastAddDevice, ::cuMemCreate + */ +CUresult CUDAAPI cuMulticastBindAddr(CUmemGenericAllocationHandle mcHandle, size_t mcOffset, CUdeviceptr memptr, size_t size, unsigned long long flags); + +/** + * \brief Unbind any memory allocations bound to a multicast object at a given offset and upto a given size. + * + * Unbinds any memory allocations hosted on \p dev and bound to a multicast + * object at \p mcOffset and upto a given \p size. + * The intended \p size of the unbind and the offset in the multicast range + * ( \p mcOffset ) must be a multiple of the value returned by + * ::cuMulticastGetGranularity flag ::CU_MULTICAST_GRANULARITY_MINIMUM. + * The \p size + \p mcOffset must be smaller than the total size of the + * multicast object. + * + * \note + * Warning: + * The \p mcOffset and the \p size must match the corresponding values specified + * during the bind call. Any other values may result in undefined behavior. + * + * \param[in] mcHandle Handle representing a multicast object. + * \param[in] dev Device that hosts the memory allocation. + * \param[in] mcOffset Offset into the multicast object. + * \param[in] size Desired size to unbind. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_PERMITTED, + * ::CUDA_ERROR_NOT_SUPPORTED + * + * \sa ::cuMulticastBindMem, ::cuMulticastBindAddr + */ +CUresult CUDAAPI cuMulticastUnbind(CUmemGenericAllocationHandle mcHandle, CUdevice dev, size_t mcOffset, size_t size); + +/** +* \brief Calculates either the minimal or recommended granularity for multicast object +* +* Calculates either the minimal or recommended granularity for a given set of +* multicast object properties and returns it in granularity. This granularity +* can be used as a multiple for size, bind offsets and address mappings of the +* multicast object. +* +* \param[out] granularity Returned granularity. +* \param[in] prop Properties of the multicast object. +* \param[in] option Determines which granularity to return. +* +* \returns +* ::CUDA_SUCCESS, +* ::CUDA_ERROR_INVALID_VALUE, +* ::CUDA_ERROR_NOT_INITIALIZED, +* ::CUDA_ERROR_DEINITIALIZED, +* ::CUDA_ERROR_NOT_PERMITTED, +* ::CUDA_ERROR_NOT_SUPPORTED +* +* \sa ::cuMulticastCreate, ::cuMulticastBindMem, ::cuMulticastBindAddr, ::cuMulticastUnbind +*/ +CUresult CUDAAPI cuMulticastGetGranularity(size_t *granularity, const CUmulticastObjectProp *prop, CUmulticastGranularity_flags option); + +/** @} */ /* END CUDA_MULTICAST */ + /** * \defgroup CUDA_UNIFIED Unified Addressing * @@ -11277,7 +12815,7 @@ CUresult CUDAAPI cuMemAdvise(CUdeviceptr devPtr, size_t count, CUmem_advise advi * a GPU id or CU_DEVICE_CPU depending on whether the last location for prefetch was a GPU or the CPU * respectively. If any page in the memory range was never explicitly prefetched or if all pages were not * prefetched to the same location, CU_DEVICE_INVALID will be returned. Note that this simply returns the - * last location that the applicaton requested to prefetch the memory range to. It gives no indication as to + * last location that the application requested to prefetch the memory range to. It gives no indication as to * whether the prefetch operation to that location has completed or even begun. * * \param data - A pointers to a memory location where the result @@ -11591,6 +13129,39 @@ CUresult CUDAAPI cuStreamGetPriority(CUstream hStream, int *priority); */ CUresult CUDAAPI cuStreamGetFlags(CUstream hStream, unsigned int *flags); +/** + * \brief Returns the unique Id associated with the stream handle supplied + * + * Returns in \p streamId the unique Id which is associated with the given stream handle. + * The Id is unique for the life of the program. + * + * The stream handle \p hStream can refer to any of the following: + *
    + *
  • a stream created via any of the CUDA driver APIs such as ::cuStreamCreate + * and ::cuStreamCreateWithPriority, or their runtime API equivalents such as + * ::cudaStreamCreate, ::cudaStreamCreateWithFlags and ::cudaStreamCreateWithPriority. + * Passing an invalid handle will result in undefined behavior.
  • + *
  • any of the special streams such as the NULL stream, ::CU_STREAM_LEGACY and + * ::CU_STREAM_PER_THREAD. The runtime API equivalents of these are also accepted, + * which are NULL, ::cudaStreamLegacy and ::cudaStreamPerThread respectively.
  • + *
+ * + * \param hStream - Handle to the stream to be queried + * \param streamId - Pointer to store the Id of the stream + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * + * \sa ::cuStreamDestroy, + * ::cuStreamCreate, + * ::cuStreamGetPriority, + * ::cudaStreamGetId + */ +CUresult CUDAAPI cuStreamGetId(CUstream hStream, unsigned long long *streamId); + /** * \brief Query the context associated with a stream * @@ -11742,7 +13313,7 @@ CUresult CUDAAPI cuStreamWaitEvent(CUstream hStream, CUevent hEvent, unsigned in * ::cuStreamDestroy, * ::cuMemAllocManaged, * ::cuStreamAttachMemAsync, - * ::cuStreamLaunchHostFunc, + * ::cuLaunchHostFunc, * ::cudaStreamAddCallback */ CUresult CUDAAPI cuStreamAddCallback(CUstream hStream, CUstreamCallback callback, void *userData, unsigned int flags); @@ -11909,35 +13480,7 @@ CUresult CUDAAPI cuStreamEndCapture(CUstream hStream, CUgraph *phGraph); CUresult CUDAAPI cuStreamIsCapturing(CUstream hStream, CUstreamCaptureStatus *captureStatus); /** - * \brief Query capture status of a stream - * - * Note there is a later version of this API, ::cuStreamGetCaptureInfo_v2. It will - * supplant this version in 12.0, which is retained for minor version compatibility. - * - * Query the capture status of a stream and and get an id for - * the capture sequence, which is unique over the lifetime of the process. - * - * If called on ::CU_STREAM_LEGACY (the "null stream") while a stream not created - * with ::CU_STREAM_NON_BLOCKING is capturing, returns ::CUDA_ERROR_STREAM_CAPTURE_IMPLICIT. - * - * A valid id is returned only if both of the following are true: - * - the call returns CUDA_SUCCESS - * - captureStatus is set to ::CU_STREAM_CAPTURE_STATUS_ACTIVE - * - * \return - * ::CUDA_SUCCESS, - * ::CUDA_ERROR_STREAM_CAPTURE_IMPLICIT - * \notefnerr - * - * \sa - * ::cuStreamGetCaptureInfo_v2, - * ::cuStreamBeginCapture, - * ::cuStreamIsCapturing - */ -CUresult CUDAAPI cuStreamGetCaptureInfo(CUstream hStream, CUstreamCaptureStatus *captureStatus_out, cuuint64_t *id_out); - -/** - * \brief Query a stream's capture state (11.3+) + * \brief Query a stream's capture state * * Query stream state related to stream capture. * @@ -11948,11 +13491,6 @@ CUresult CUDAAPI cuStreamGetCaptureInfo(CUstream hStream, CUstreamCaptureStatus * - the call returns CUDA_SUCCESS * - the returned capture status is ::CU_STREAM_CAPTURE_STATUS_ACTIVE * - * This version of cuStreamGetCaptureInfo is introduced in CUDA 11.3 and will supplant the - * previous version in 12.0. Developers requiring compatibility across minor versions to - * CUDA 11.0 (driver version 445) should use ::cuStreamGetCaptureInfo or include a fallback - * path. - * * \param hStream - The stream to query * \param captureStatus_out - Location to return the capture status of the stream; required * \param id_out - Optional location to return an id for the capture sequence, which is @@ -11983,12 +13521,11 @@ CUresult CUDAAPI cuStreamGetCaptureInfo(CUstream hStream, CUstreamCaptureStatus * \notefnerr * * \sa - * ::cuStreamGetCaptureInfo, * ::cuStreamBeginCapture, * ::cuStreamIsCapturing, * ::cuStreamUpdateCaptureDependencies */ -CUresult CUDAAPI cuStreamGetCaptureInfo_v2(CUstream hStream, CUstreamCaptureStatus *captureStatus_out, +CUresult CUDAAPI cuStreamGetCaptureInfo(CUstream hStream, CUstreamCaptureStatus *captureStatus_out, cuuint64_t *id_out, CUgraph *graph_out, const CUgraphNode **dependencies_out, size_t *numDependencies_out); /** @@ -12019,7 +13556,6 @@ CUresult CUDAAPI cuStreamGetCaptureInfo_v2(CUstream hStream, CUstreamCaptureStat * \sa * ::cuStreamBeginCapture, * ::cuStreamGetCaptureInfo, - * ::cuStreamGetCaptureInfo_v2 */ CUresult CUDAAPI cuStreamUpdateCaptureDependencies(CUstream hStream, CUgraphNode *dependencies, size_t numDependencies, unsigned int flags); @@ -12536,7 +14072,8 @@ CUresult CUDAAPI cuEventDestroy(CUevent hEvent); * ::CUDA_ERROR_NOT_INITIALIZED, * ::CUDA_ERROR_INVALID_CONTEXT, * ::CUDA_ERROR_INVALID_HANDLE, - * ::CUDA_ERROR_NOT_READY + * ::CUDA_ERROR_NOT_READY, + * ::CUDA_ERROR_UNKNOWN * \notefnerr * * \sa ::cuEventCreate, @@ -12711,7 +14248,8 @@ CUresult CUDAAPI cuEventElapsedTime(float *pMilliseconds, CUevent hStart, CUeven * ::CUDA_SUCCESS, * ::CUDA_ERROR_NOT_INITIALIZED, * ::CUDA_ERROR_INVALID_VALUE, - * ::CUDA_ERROR_INVALID_HANDLE + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_OPERATING_SYSTEM * \notefnerr * * \note If the Vulkan memory imported into CUDA is mapped on the CPU then the @@ -13003,7 +14541,8 @@ CUresult CUDAAPI cuDestroyExternalMemory(CUexternalMemory extMem); * ::CUDA_SUCCESS, * ::CUDA_ERROR_NOT_INITIALIZED, * ::CUDA_ERROR_NOT_SUPPORTED, - * ::CUDA_ERROR_INVALID_HANDLE + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_OPERATING_SYSTEM * \notefnerr * * \sa ::cuDestroyExternalSemaphore, @@ -13054,6 +14593,21 @@ CUresult CUDAAPI cuImportExternalSemaphore(CUexternalSemaphore *extSem_out, cons * if the NvSciSyncAttrList used to create the NvSciSyncObj had not set the flags in * ::cuDeviceGetNvSciSyncAttributes to CUDA_NVSCISYNC_ATTR_SIGNAL, this API will return * CUDA_ERROR_NOT_SUPPORTED. + * NvSciSyncFence associated with semaphore object of the type + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_NVSCISYNC can be deterministic. For this the + * NvSciSyncAttrList used to create the semaphore object must have value of + * NvSciSyncAttrKey_RequireDeterministicFences key set to true. Deterministic fences + * allow users to enqueue a wait over the semaphore object even before corresponding + * signal is enqueued. For such a semaphore object, CUDA guarantees that each signal + * operation will increment the fence value by '1'. Users are expected to track count + * of signals enqueued on the semaphore object and insert waits accordingly. When such + * a semaphore object is signaled from multiple streams, due to concurrent stream + * execution, it is possible that the order in which the semaphore gets signaled is + * indeterministic. This could lead to waiters of the semaphore getting unblocked + * incorrectly. Users are expected to handle such situations, either by not using the + * same semaphore object with deterministic fence support enabled in different streams + * or by adding explicit dependency amongst such streams so that the semaphore is + * signaled in order. * * If the semaphore object is any one of the following types: * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX, @@ -13179,7 +14733,7 @@ CUresult CUDAAPI cuDestroyExternalSemaphore(CUexternalSemaphore extSem); /** @} */ /* END CUDA_EXTRES_INTEROP */ /** - * \defgroup CUDA_MEMOP Stream memory operations + * \defgroup CUDA_MEMOP Stream Memory Operations * * ___MANBRIEF___ Stream memory operations of the low-level CUDA driver API * (___CURRENT_FILE___) ___ENDMANBRIEF___ @@ -13187,19 +14741,8 @@ CUresult CUDAAPI cuDestroyExternalSemaphore(CUexternalSemaphore extSem); * This section describes the stream memory operations of the low-level CUDA * driver application programming interface. * - * The whole set of operations is disabled by default. Users are required - * to explicitly enable them, e.g. on Linux by passing the kernel module - * parameter shown below: - * modprobe nvidia NVreg_EnableStreamMemOPs=1 - * There is currently no way to enable these operations on other operating - * systems. - * - * Users can programmatically query whether the device supports these - * operations with ::cuDeviceGetAttribute() and - * ::CU_DEVICE_ATTRIBUTE_CAN_USE_STREAM_MEM_OPS. - * * Support for the ::CU_STREAM_WAIT_VALUE_NOR flag can be queried with - * ::CU_DEVICE_ATTRIBUTE_CAN_USE_STREAM_WAIT_VALUE_NOR. + * ::CU_DEVICE_ATTRIBUTE_CAN_USE_STREAM_WAIT_VALUE_NOR_V2. * * Support for the ::cuStreamWriteValue64() and ::cuStreamWaitValue64() * functions, as well as for the ::CU_STREAM_MEM_OP_WAIT_VALUE_64 and @@ -13218,6 +14761,14 @@ CUresult CUDAAPI cuDestroyExternalSemaphore(CUexternalSemaphore extSem); * None of the operations accepts pointers to managed memory buffers * (::cuMemAllocManaged). * + * \note + * Warning: + * Improper use of these APIs may deadlock the application. Synchronization + * ordering established through these APIs is not visible to CUDA. CUDA tasks + * that are (even indirectly) ordered by these APIs should also have that order + * expressed with CUDA-visible dependencies such as events. This ensures that + * the scheduler does not serialize them in an improper order. + * * @{ */ @@ -13234,11 +14785,16 @@ CUresult CUDAAPI cuDestroyExternalSemaphore(CUexternalSemaphore extSem); * should be obtained with ::cuMemHostGetDevicePointer(). This function cannot * be used with managed memory (::cuMemAllocManaged). * - * Support for this can be queried with ::cuDeviceGetAttribute() and - * ::CU_DEVICE_ATTRIBUTE_CAN_USE_STREAM_MEM_OPS. - * * Support for CU_STREAM_WAIT_VALUE_NOR can be queried with ::cuDeviceGetAttribute() and - * ::CU_DEVICE_ATTRIBUTE_CAN_USE_STREAM_WAIT_VALUE_NOR. + * ::CU_DEVICE_ATTRIBUTE_CAN_USE_STREAM_WAIT_VALUE_NOR_V2. + * + * \note + * Warning: + * Improper use of this API may deadlock the application. Synchronization + * ordering established through this API is not visible to CUDA. CUDA tasks + * that are (even indirectly) ordered by this API should also have that order + * expressed with CUDA-visible dependencies such as events. This ensures that + * the scheduler does not serialize them in an improper order. * * \param stream The stream to synchronize on the memory location. * \param addr The memory location to wait on. @@ -13275,6 +14831,14 @@ CUresult CUDAAPI cuStreamWaitValue32(CUstream stream, CUdeviceptr addr, cuuint32 * Support for this can be queried with ::cuDeviceGetAttribute() and * ::CU_DEVICE_ATTRIBUTE_CAN_USE_64_BIT_STREAM_MEM_OPS. * + * \note + * Warning: + * Improper use of this API may deadlock the application. Synchronization + * ordering established through this API is not visible to CUDA. CUDA tasks + * that are (even indirectly) ordered by this API should also have that order + * expressed with CUDA-visible dependencies such as events. This ensures that + * the scheduler does not serialize them in an improper order. + * * \param stream The stream to synchronize on the memory location. * \param addr The memory location to wait on. * \param value The value to compare with the memory location. @@ -13298,18 +14862,12 @@ CUresult CUDAAPI cuStreamWaitValue64(CUstream stream, CUdeviceptr addr, cuuint64 /** * \brief Write a value to memory * - * Write a value to memory. Unless the ::CU_STREAM_WRITE_VALUE_NO_MEMORY_BARRIER - * flag is passed, the write is preceded by a system-wide memory fence, - * equivalent to a __threadfence_system() but scoped to the stream - * rather than a CUDA thread. + * Write a value to memory. * * If the memory was registered via ::cuMemHostRegister(), the device pointer * should be obtained with ::cuMemHostGetDevicePointer(). This function cannot * be used with managed memory (::cuMemAllocManaged). * - * Support for this can be queried with ::cuDeviceGetAttribute() and - * ::CU_DEVICE_ATTRIBUTE_CAN_USE_STREAM_MEM_OPS. - * * \param stream The stream to do the write in. * \param addr The device address to write to. * \param value The value to write. @@ -13333,10 +14891,7 @@ CUresult CUDAAPI cuStreamWriteValue32(CUstream stream, CUdeviceptr addr, cuuint3 /** * \brief Write a value to memory * - * Write a value to memory. Unless the ::CU_STREAM_WRITE_VALUE_NO_MEMORY_BARRIER - * flag is passed, the write is preceded by a system-wide memory fence, - * equivalent to a __threadfence_system() but scoped to the stream - * rather than a CUDA thread. + * Write a value to memory. * * If the memory was registered via ::cuMemHostRegister(), the device pointer * should be obtained with ::cuMemHostGetDevicePointer(). @@ -13376,9 +14931,17 @@ CUresult CUDAAPI cuStreamWriteValue64(CUstream stream, CUdeviceptr addr, cuuint6 * ::cuStreamWaitValue32(), ::cuStreamWaitValue64(), ::cuStreamWriteValue32(), * and ::cuStreamWriteValue64() for details of specific operations. * - * Basic support for this can be queried with ::cuDeviceGetAttribute() and - * ::CU_DEVICE_ATTRIBUTE_CAN_USE_STREAM_MEM_OPS. See related APIs for details - * on querying support for specific operations. + * See related APIs for details on querying support for specific operations. + * + * \note + * Warning: + * Improper use of this API may deadlock the application. Synchronization + * ordering established through this API is not visible to CUDA. CUDA tasks + * that are (even indirectly) ordered by this API should also have that order + * expressed with CUDA-visible dependencies such as events. This ensures that + * the scheduler does not serialize them in an improper order. For more + * information, see the Stream Memory Operations section in the programming + * guide(https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html). * * \param stream The stream to enqueue the operations in. * \param count The number of operations in the array. Must be less than 256. @@ -13449,6 +15012,28 @@ CUresult CUDAAPI cuStreamBatchMemOp(CUstream stream, unsigned int count, CUstrea * dynamically-allocated shared memory. * - ::CU_FUNC_ATTRIBUTE_PREFERRED_SHARED_MEMORY_CARVEOUT: Preferred shared memory-L1 * cache split ratio in percent of total shared memory. + * - ::CU_FUNC_ATTRIBUTE_CLUSTER_SIZE_MUST_BE_SET: If this attribute is set, the + * kernel must launch with a valid cluster size specified. + * - ::CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_WIDTH: The required cluster width in + * blocks. + * - ::CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_HEIGHT: The required cluster height in + * blocks. + * - ::CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_DEPTH: The required cluster depth in + * blocks. + * - ::CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED: Indicates whether + * the function can be launched with non-portable cluster size. 1 is allowed, + * 0 is disallowed. A non-portable cluster size may only function on the + * specific SKUs the program is tested on. The launch might fail if the + * program is run on a different hardware platform. CUDA API provides + * cudaOccupancyMaxActiveClusters to assist with checking whether the desired + * size can be launched on the current device. A portable cluster size is + * guaranteed to be functional on all compute capabilities higher than the + * target compute capability. The portable cluster size for sm_90 is 8 blocks + * per cluster. This value may increase for future compute capabilities. The + * specific hardware unit may support higher cluster sizes that’s not + * guaranteed to be portable. + * - ::CU_FUNC_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE: The block + * scheduling policy of a function. The value type is CUclusterSchedulingPolicy. * * \param pi - Returned attribute value * \param attrib - Attribute requested @@ -13468,7 +15053,8 @@ CUresult CUDAAPI cuStreamBatchMemOp(CUstream stream, unsigned int count, CUstrea * ::cuFuncSetCacheConfig, * ::cuLaunchKernel, * ::cudaFuncGetAttributes, - * ::cudaFuncSetAttribute + * ::cudaFuncSetAttribute, + * ::cuKernelGetAttribute */ CUresult CUDAAPI cuFuncGetAttribute(int *pi, CUfunction_attribute attrib, CUfunction hfunc); @@ -13495,6 +15081,23 @@ CUresult CUDAAPI cuFuncGetAttribute(int *pi, CUfunction_attribute attrib, CUfunc * carveout preference, in percent of the total shared memory. * See ::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR * This is only a hint, and the driver can choose a different ratio if required to execute the function. + * - ::CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_WIDTH: The required cluster width in + * blocks. The width, height, and depth values must either all be 0 or all be + * positive. The validity of the cluster dimensions is checked at launch time. + * If the value is set during compile time, it cannot be set at runtime. + * Setting it at runtime will return CUDA_ERROR_NOT_PERMITTED. + * - ::CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_HEIGHT: The required cluster height in + * blocks. The width, height, and depth values must either all be 0 or all be + * positive. The validity of the cluster dimensions is checked at launch time. + * If the value is set during compile time, it cannot be set at runtime. + * Setting it at runtime will return CUDA_ERROR_NOT_PERMITTED. + * - ::CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_DEPTH: The required cluster depth in + * blocks. The width, height, and depth values must either all be 0 or all be + * positive. The validity of the cluster dimensions is checked at launch time. + * If the value is set during compile time, it cannot be set at runtime. + * Setting it at runtime will return CUDA_ERROR_NOT_PERMITTED. + * - ::CU_FUNC_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE: The block + * scheduling policy of a function. The value type is CUclusterSchedulingPolicy. * * \param hfunc - Function to query attribute of * \param attrib - Attribute requested @@ -13514,7 +15117,8 @@ CUresult CUDAAPI cuFuncGetAttribute(int *pi, CUfunction_attribute attrib, CUfunc * ::cuFuncSetCacheConfig, * ::cuLaunchKernel, * ::cudaFuncGetAttributes, - * ::cudaFuncSetAttribute + * ::cudaFuncSetAttribute, + * ::cuKernelSetAttribute */ CUresult CUDAAPI cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value); @@ -13558,7 +15162,8 @@ CUresult CUDAAPI cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attri * ::cuCtxSetCacheConfig, * ::cuFuncGetAttribute, * ::cuLaunchKernel, - * ::cudaFuncSetCacheConfig + * ::cudaFuncSetCacheConfig, + * ::cuKernelSetCacheConfig */ CUresult CUDAAPI cuFuncSetCacheConfig(CUfunction hfunc, CUfunc_cache config); @@ -13641,10 +15246,11 @@ CUresult CUDAAPI cuFuncSetSharedMemConfig(CUfunction hfunc, CUsharedconfig confi CUresult CUDAAPI cuFuncGetModule(CUmodule *hmod, CUfunction hfunc); /** - * \brief Launches a CUDA function + * \brief Launches a CUDA function ::CUfunction or a CUDA kernel ::CUkernel * - * Invokes the kernel \p f on a \p gridDimX x \p gridDimY x \p gridDimZ - * grid of blocks. Each block contains \p blockDimX x \p blockDimY x + * Invokes the function ::CUfunction or the kernel ::CUkernel \p f + * on a \p gridDimX x \p gridDimY x \p gridDimZ grid of blocks. + * Each block contains \p blockDimX x \p blockDimY x * \p blockDimZ threads. * * \p sharedMemBytes sets the amount of dynamic shared memory that will be @@ -13713,15 +15319,213 @@ CUresult CUDAAPI cuFuncGetModule(CUmodule *hmod, CUfunction hfunc); * If either of these conditions is not met, then ::cuLaunchKernel() will * return ::CUDA_ERROR_INVALID_IMAGE. * - * \param f - Kernel to launch - * \param gridDimX - Width of grid in blocks - * \param gridDimY - Height of grid in blocks - * \param gridDimZ - Depth of grid in blocks - * \param blockDimX - X dimension of each thread block - * \param blockDimY - Y dimension of each thread block - * \param blockDimZ - Z dimension of each thread block - * \param sharedMemBytes - Dynamic shared-memory size per thread block in bytes - * \param hStream - Stream identifier + * Note that the API can also be used to launch context-less kernel ::CUkernel + * by querying the handle using ::cuLibraryGetKernel() and then passing it + * to the API by casting to ::CUfunction. Here, the context to launch + * the kernel on will either be taken from the specified stream \p hStream + * or the current context in case of NULL stream. + * + * \param f - Function ::CUfunction or Kernel ::CUkernel to launch + * \param gridDimX - Width of grid in blocks + * \param gridDimY - Height of grid in blocks + * \param gridDimZ - Depth of grid in blocks + * \param blockDimX - X dimension of each thread block + * \param blockDimY - Y dimension of each thread block + * \param blockDimZ - Z dimension of each thread block + * \param sharedMemBytes - Dynamic shared-memory size per thread block in bytes + * \param hStream - Stream identifier + * \param kernelParams - Array of pointers to kernel parameters + * \param extra - Extra options + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_IMAGE, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_LAUNCH_FAILED, + * ::CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES, + * ::CUDA_ERROR_LAUNCH_TIMEOUT, + * ::CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING, + * ::CUDA_ERROR_SHARED_OBJECT_INIT_FAILED, + * ::CUDA_ERROR_NOT_FOUND + * \note_null_stream + * \notefnerr + * + * \sa ::cuCtxGetCacheConfig, + * ::cuCtxSetCacheConfig, + * ::cuFuncSetCacheConfig, + * ::cuFuncGetAttribute, + * ::cudaLaunchKernel, + * ::cuLibraryGetKernel, + * ::cuKernelSetCacheConfig, + * ::cuKernelGetAttribute, + * ::cuKernelSetAttribute + */ +CUresult CUDAAPI cuLaunchKernel(CUfunction f, + unsigned int gridDimX, + unsigned int gridDimY, + unsigned int gridDimZ, + unsigned int blockDimX, + unsigned int blockDimY, + unsigned int blockDimZ, + unsigned int sharedMemBytes, + CUstream hStream, + void **kernelParams, + void **extra); + +/** + * \brief Launches a CUDA function ::CUfunction or a CUDA kernel ::CUkernel with launch-time configuration + * + * Invokes the function ::CUfunction or the kernel ::CUkernel \p f with the specified launch-time configuration + * \p config. + * + * The ::CUlaunchConfig structure is defined as: + * \code + typedef struct CUlaunchConfig_st { + unsigned int gridDimX; + unsigned int gridDimY; + unsigned int gridDimZ; + unsigned int blockDimX; + unsigned int blockDimY; + unsigned int blockDimZ; + unsigned int sharedMemBytes; + CUstream hStream; + CUlaunchAttribute *attrs; + unsigned int numAttrs; + } CUlaunchConfig; + * \endcode + * where: + * - ::CUlaunchConfig::gridDimX is the width of the grid in blocks. + * - ::CUlaunchConfig::gridDimY is the height of the grid in blocks. + * - ::CUlaunchConfig::gridDimZ is the depth of the grid in blocks. + * - ::CUlaunchConfig::blockDimX is the X dimension of each thread block. + * - ::CUlaunchConfig::blockDimX is the Y dimension of each thread block. + * - ::CUlaunchConfig::blockDimZ is the Z dimension of each thread block. + * - ::CUlaunchConfig::sharedMemBytes is the dynamic shared-memory size per + * thread block in bytes. + * - ::CUlaunchConfig::hStream is the handle to the stream to perform the launch + * in. The CUDA context associated with this stream must match that associated + * with function f. + * - ::CUlaunchConfig::attrs is an array of ::CUlaunchConfig::numAttrs + * continguous ::CUlaunchAttribute elements. The value of this pointer is not + * considered if ::CUlaunchConfig::numAttrs is zero. However, in that case, it + * is recommended to set the pointer to NULL. + * - ::CUlaunchConfig::numAttrs is the numbers of attributes populating the + * first ::CUlaunchConfig::numAttrs positions of the ::CUlaunchConfig::attrs + * array. + * + * Launch-time configuration is specified by adding entries to + * ::CUlaunchConfig::attrs. Each entry is an attribute ID and a corresponding + * attribute value. + * + * The ::CUlaunchAttribute structure is defined as: + * \code + typedef struct CUlaunchAttribute_st { + CUlaunchAttributeID id; + CUlaunchAttributeValue value; + } CUlaunchAttribute; + * \endcode + * where: + * - ::CUlaunchAttribute::id is a unique enum identifying the attribute. + * - ::CUlaunchAttribute::value is a union that hold the attribute value. + * + * An example of using the \p config parameter: + * \code + CUlaunchAttribute coopAttr = {.id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE, + .value = 1}; + CUlaunchConfig config = {... // set block and grid dimensions + .attrs = &coopAttr, + .numAttrs = 1}; + + cuLaunchKernelEx(&config, kernel, NULL, NULL); + * \endcode + * + * The ::CUlaunchAttributeID enum is defined as: + * \code + typedef enum CUlaunchAttributeID_enum { + CU_LAUNCH_ATTRIBUTE_IGNORE = 0, + CU_LAUNCH_ATTRIBUTE_ACCESS_POLICY_WINDOW = 1, + CU_LAUNCH_ATTRIBUTE_COOPERATIVE = 2, + CU_LAUNCH_ATTRIBUTE_SYNCHRONIZATION_POLICY = 3, + CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION = 4, + CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE = 5, + CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION = 6, + CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_EVENT = 7, + } CUlaunchAttributeID; + * \endcode + * + * and the corresponding ::CUlaunchAttributeValue union as : + * \code + typedef union CUlaunchAttributeValue_union { + cuuint64_t pad[8]; + CUaccessPolicyWindow accessPolicyWindow; + int cooperative; + CUsynchronizationPolicy syncPolicy; + struct { + unsigned int x; + unsigned int y; + unsigned int z; + } clusterDim; + CUclusterSchedulingPolicy clusterSchedulingPolicyPreference; + int programmaticStreamSerializationAllowed; + struct { + CUevent event; + int flags; + int triggerAtBlockStart; + } programmaticEvent; + } CUlaunchAttributeValue; + * \endcode + * + * Setting ::CU_LAUNCH_ATTRIBUTE_COOPERATIVE to a non-zero value causes the + * kernel launch to be a cooperative launch, with exactly the same usage and + * semantics of ::cuLaunchCooperativeKernel. + * + * Setting ::CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION to a non-zero + * values causes the kernel to use programmatic means to resolve its stream + * dependency -- enabling the CUDA runtime to opportunistically allow the grid's + * execution to overlap with the previous kernel in the stream, if that kernel + * requests the overlap. + * + * ::CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_EVENT records an event along with the + * kernel launch. Event recorded through this launch attribute is guaranteed to + * only trigger after all block in the associated kernel trigger the event. A + * block can trigger the event through PTX launchdep.release or CUDA builtin + * function cudaTriggerProgrammaticLaunchCompletion(). A trigger can also be + * inserted at the beginning of each block's execution if triggerAtBlockStart is + * set to non-0. Note that dependents (including the CPU thread calling + * cuEventSynchronize()) are not guaranteed to observe the release precisely + * when it is released. For example, cuEventSynchronize() may only observe the + * event trigger long after the associated kernel has completed. This recording + * type is primarily meant for establishing programmatic dependency between + * device tasks. The event supplied must not be an interprocess or interop + * event. The event must disable timing (i.e. created with + * ::CU_EVENT_DISABLE_TIMING flag set). + * + * The effect of other attributes is consistent with their effect when set via + * persistent APIs. + * + * See ::cuStreamSetAttribute for + * - ::CU_LAUNCH_ATTRIBUTE_ACCESS_POLICY_WINDOW + * - ::CU_LAUNCH_ATTRIBUTE_SYNCHRONIZATION_POLICY + * + * See ::cuFunctionSetAttribute for + * - ::CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION + * - ::CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE + * + * Kernel parameters to \p f can be specified in the same ways that they can be + * using ::cuLaunchKernel. + * + * Note that the API can also be used to launch context-less kernel ::CUkernel + * by querying the handle using ::cuLibraryGetKernel() and then passing it + * to the API by casting to ::CUfunction. Here, the context to launch + * the kernel on will either be taken from the specified stream ::CUlaunchConfig::hStream + * or the current context in case of NULL stream. + * + * \param config - Config to launch + * \param f - Function ::CUfunction or Kernel ::CUkernel to launch * \param kernelParams - Array of pointers to kernel parameters * \param extra - Extra options * @@ -13737,7 +15541,9 @@ CUresult CUDAAPI cuFuncGetModule(CUmodule *hmod, CUfunction hfunc); * ::CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES, * ::CUDA_ERROR_LAUNCH_TIMEOUT, * ::CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING, - * ::CUDA_ERROR_SHARED_OBJECT_INIT_FAILED + * ::CUDA_ERROR_COOPERATIVE_LAUNCH_TOO_LARGE, + * ::CUDA_ERROR_SHARED_OBJECT_INIT_FAILED, + * ::CUDA_ERROR_NOT_FOUND * \note_null_stream * \notefnerr * @@ -13745,34 +15551,32 @@ CUresult CUDAAPI cuFuncGetModule(CUmodule *hmod, CUfunction hfunc); * ::cuCtxSetCacheConfig, * ::cuFuncSetCacheConfig, * ::cuFuncGetAttribute, - * ::cudaLaunchKernel + * ::cudaLaunchKernel, + * ::cudaLaunchKernelEx, + * ::cuLibraryGetKernel, + * ::cuKernelSetCacheConfig, + * ::cuKernelGetAttribute, + * ::cuKernelSetAttribute */ -CUresult CUDAAPI cuLaunchKernel(CUfunction f, - unsigned int gridDimX, - unsigned int gridDimY, - unsigned int gridDimZ, - unsigned int blockDimX, - unsigned int blockDimY, - unsigned int blockDimZ, - unsigned int sharedMemBytes, - CUstream hStream, - void **kernelParams, - void **extra); - - - - - - - +CUresult CUDAAPI cuLaunchKernelEx(const CUlaunchConfig *config, + CUfunction f, + void **kernelParams, + void **extra); /** - * \brief Launches a CUDA function where thread blocks can cooperate and synchronize as they execute + * \brief Launches a CUDA function ::CUfunction or a CUDA kernel ::CUkernel where thread blocks + * can cooperate and synchronize as they execute * - * Invokes the kernel \p f on a \p gridDimX x \p gridDimY x \p gridDimZ + * Invokes the function ::CUfunction or the kernel ::CUkernel \p f on a \p gridDimX x \p gridDimY x \p gridDimZ * grid of blocks. Each block contains \p blockDimX x \p blockDimY x * \p blockDimZ threads. * + * Note that the API can also be used to launch context-less kernel ::CUkernel + * by querying the handle using ::cuLibraryGetKernel() and then passing it + * to the API by casting to ::CUfunction. Here, the context to launch + * the kernel on will either be taken from the specified stream \p hStream + * or the current context in case of NULL stream. + * * \p sharedMemBytes sets the amount of dynamic shared memory that will be * available to each thread block. * @@ -13807,7 +15611,13 @@ CUresult CUDAAPI cuLaunchKernel(CUfunction f, * If either of these conditions is not met, then ::cuLaunchCooperativeKernel() will * return ::CUDA_ERROR_INVALID_IMAGE. * - * \param f - Kernel to launch + * Note that the API can also be used to launch context-less kernel ::CUkernel + * by querying the handle using ::cuLibraryGetKernel() and then passing it + * to the API by casting to ::CUfunction. Here, the context to launch + * the kernel on will either be taken from the specified stream \p hStream + * or the current context in case of NULL stream. + * + * \param f - Function ::CUfunction or Kernel ::CUkernel to launch * \param gridDimX - Width of grid in blocks * \param gridDimY - Height of grid in blocks * \param gridDimZ - Depth of grid in blocks @@ -13831,7 +15641,8 @@ CUresult CUDAAPI cuLaunchKernel(CUfunction f, * ::CUDA_ERROR_LAUNCH_TIMEOUT, * ::CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING, * ::CUDA_ERROR_COOPERATIVE_LAUNCH_TOO_LARGE, - * ::CUDA_ERROR_SHARED_OBJECT_INIT_FAILED + * ::CUDA_ERROR_SHARED_OBJECT_INIT_FAILED, + * ::CUDA_ERROR_NOT_FOUND * \note_null_stream * \notefnerr * @@ -13840,7 +15651,11 @@ CUresult CUDAAPI cuLaunchKernel(CUfunction f, * ::cuFuncSetCacheConfig, * ::cuFuncGetAttribute, * ::cuLaunchCooperativeKernelMultiDevice, - * ::cudaLaunchCooperativeKernel + * ::cudaLaunchCooperativeKernel, + * ::cuLibraryGetKernel, + * ::cuKernelSetCacheConfig, + * ::cuKernelGetAttribute, + * ::cuKernelSetAttribute */ CUresult CUDAAPI cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, @@ -13870,7 +15685,7 @@ CUresult CUDAAPI cuLaunchCooperativeKernel(CUfunction f, * All kernels launched must be identical with respect to the compiled code. Note that * any __device__, __constant__ or __managed__ variables present in the module that owns * the kernel launched on each device, are independently instantiated on every device. - * It is the application's responsiblity to ensure these variables are initialized and + * It is the application's responsibility to ensure these variables are initialized and * used appropriately. * * The size of the grids as specified in blocks, the size of the blocks themselves @@ -13910,6 +15725,9 @@ CUresult CUDAAPI cuLaunchCooperativeKernel(CUfunction f, * where: * - ::CUDA_LAUNCH_PARAMS::function specifies the kernel to be launched. All functions must * be identical with respect to the compiled code. + * Note that you can also specify context-less kernel ::CUkernel by querying the handle + * using ::cuLibraryGetKernel() and then casting to ::CUfunction. In this case, the context to + * launch the kernel on be taken from the specified stream ::CUDA_LAUNCH_PARAMS::hStream. * - ::CUDA_LAUNCH_PARAMS::gridDimX is the width of the grid in blocks. This must match across * all kernels launched. * - ::CUDA_LAUNCH_PARAMS::gridDimY is the height of the grid in blocks. This must match across @@ -15077,7 +16895,7 @@ CUresult CUDAAPI cuGraphAddEmptyNode(CUgraphNode *phGraphNode, CUgraph hGraph, c * ::cuGraphAddEmptyNode, * ::cuGraphAddKernelNode, * ::cuGraphAddMemcpyNode, - * ::cuGraphAddMemsetNode, + * ::cuGraphAddMemsetNode */ CUresult CUDAAPI cuGraphAddEventRecordNode(CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, size_t numDependencies, CUevent event); @@ -15169,7 +16987,7 @@ CUresult CUDAAPI cuGraphEventRecordNodeSetEvent(CUgraphNode hNode, CUevent event * ::cuGraphAddEmptyNode, * ::cuGraphAddKernelNode, * ::cuGraphAddMemcpyNode, - * ::cuGraphAddMemsetNode, + * ::cuGraphAddMemsetNode */ CUresult CUDAAPI cuGraphAddEventWaitNode(CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, size_t numDependencies, CUevent event); @@ -15267,7 +17085,7 @@ CUresult CUDAAPI cuGraphEventWaitNodeSetEvent(CUgraphNode hNode, CUevent event); * ::cuGraphAddEmptyNode, * ::cuGraphAddKernelNode, * ::cuGraphAddMemcpyNode, - * ::cuGraphAddMemsetNode, + * ::cuGraphAddMemsetNode */ CUresult CUDAAPI cuGraphAddExternalSemaphoresSignalNode(CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, size_t numDependencies, const CUDA_EXT_SEM_SIGNAL_NODE_PARAMS *nodeParams); @@ -15371,7 +17189,7 @@ CUresult CUDAAPI cuGraphExternalSemaphoresSignalNodeSetParams(CUgraphNode hNode, * ::cuGraphAddEmptyNode, * ::cuGraphAddKernelNode, * ::cuGraphAddMemcpyNode, - * ::cuGraphAddMemsetNode, + * ::cuGraphAddMemsetNode */ CUresult CUDAAPI cuGraphAddExternalSemaphoresWaitNode(CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, size_t numDependencies, const CUDA_EXT_SEM_WAIT_NODE_PARAMS *nodeParams); @@ -15431,6 +17249,161 @@ CUresult CUDAAPI cuGraphExternalSemaphoresWaitNodeGetParams(CUgraphNode hNode, C */ CUresult CUDAAPI cuGraphExternalSemaphoresWaitNodeSetParams(CUgraphNode hNode, const CUDA_EXT_SEM_WAIT_NODE_PARAMS *nodeParams); +/** + * \brief Creates a batch memory operation node and adds it to a graph + * + * Creates a new batch memory operation node and adds it to \p hGraph with \p + * numDependencies dependencies specified via \p dependencies and arguments specified in \p nodeParams. + * It is possible for \p numDependencies to be 0, in which case the node will be placed + * at the root of the graph. \p dependencies may not have any duplicate entries. + * A handle to the new node will be returned in \p phGraphNode. + * + * When the node is added, the paramArray inside \p nodeParams is copied and therefore it can be + * freed after the call returns. + * + * \note + * Warning: + * Improper use of this API may deadlock the application. Synchronization + * ordering established through this API is not visible to CUDA. CUDA tasks + * that are (even indirectly) ordered by this API should also have that order + * expressed with CUDA-visible dependencies such as events. This ensures that + * the scheduler does not serialize them in an improper order. For more + * information, see the Stream Memory Operations section in the programming + * guide(https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html). + * + * \param phGraphNode - Returns newly created node + * \param hGraph - Graph to which to add the node + * \param dependencies - Dependencies of the node + * \param numDependencies - Number of dependencies + * \param nodeParams - Parameters for the node + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuStreamBatchMemOp, + * ::cuStreamWaitValue32, + * ::cuStreamWriteValue32, + * ::cuStreamWaitValue64, + * ::cuStreamWriteValue64, + * ::cuGraphBatchMemOpNodeGetParams, + * ::cuGraphBatchMemOpNodeSetParams, + * ::cuGraphCreate, + * ::cuGraphDestroyNode, + * ::cuGraphAddChildGraphNode, + * ::cuGraphAddEmptyNode, + * ::cuGraphAddKernelNode, + * ::cuGraphAddMemcpyNode, + * ::cuGraphAddMemsetNode + */ +CUresult CUDAAPI cuGraphAddBatchMemOpNode(CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, size_t numDependencies, const CUDA_BATCH_MEM_OP_NODE_PARAMS *nodeParams); + +/** + * \brief Returns a batch mem op node's parameters + * + * Returns the parameters of batch mem op node \p hNode in \p nodeParams_out. + * The \p paramArray returned in \p nodeParams_out is owned by the node. + * This memory remains valid until the node is destroyed or its + * parameters are modified, and should not be modified + * directly. Use ::cuGraphBatchMemOpNodeSetParams to update the + * parameters of this node. + * + * \param hNode - Node to get the parameters for + * \param nodeParams_out - Pointer to return the parameters + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuStreamBatchMemOp, + * ::cuGraphAddBatchMemOpNode, + * ::cuGraphBatchMemOpNodeSetParams + */ +CUresult CUDAAPI cuGraphBatchMemOpNodeGetParams(CUgraphNode hNode, CUDA_BATCH_MEM_OP_NODE_PARAMS *nodeParams_out); + +/** + * \brief Sets a batch mem op node's parameters + * + * Sets the parameters of batch mem op node \p hNode to \p nodeParams. + * + * The paramArray inside \p nodeParams is copied and therefore it can be + * freed after the call returns. + * + * \param hNode - Node to set the parameters for + * \param nodeParams - Parameters to copy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuStreamBatchMemOp, + * ::cuGraphAddBatchMemOpNode, + * ::cuGraphBatchMemOpNodeGetParams + */ +CUresult CUDAAPI cuGraphBatchMemOpNodeSetParams(CUgraphNode hNode, const CUDA_BATCH_MEM_OP_NODE_PARAMS *nodeParams); + +/** + * \brief Sets the parameters for a batch mem op node in the given graphExec + * + * Sets the parameters of a batch mem op node in an executable graph \p hGraphExec. + * The node is identified by the corresponding node \p hNode in the + * non-executable graph, from which the executable graph was instantiated. + * + * The following fields on operations may be modified on an executable graph: + * + * op.waitValue.address + * op.waitValue.value[64] + * op.waitValue.flags bits corresponding to wait type (i.e. CU_STREAM_WAIT_VALUE_FLUSH bit cannot be modified) + * op.writeValue.address + * op.writeValue.value[64] + * + * Other fields, such as the context, count or type of operations, and other types of operations such as membars, + * may not be modified. + * + * \p hNode must not have been removed from the original graph. + * + * The modifications only affect future launches of \p hGraphExec. Already + * enqueued or running launches of \p hGraphExec are not affected by this call. + * \p hNode is also not modified by this call. + * + * The paramArray inside \p nodeParams is copied and therefore it can be + * freed after the call returns. + * + * \param hGraphExec - The executable graph in which to set the specified node + * \param hNode - Batch mem op node from the graph from which graphExec was instantiated + * \param nodeParams - Updated Parameters to set + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuStreamBatchMemOp, + * ::cuGraphAddBatchMemOpNode, + * ::cuGraphBatchMemOpNodeGetParams, + * ::cuGraphBatchMemOpNodeSetParams, + * ::cuGraphInstantiate + */ +CUresult CUDAAPI cuGraphExecBatchMemOpNodeSetParams(CUgraphExec hGraphExec, CUgraphNode hNode, const CUDA_BATCH_MEM_OP_NODE_PARAMS *nodeParams); + /** * \brief Creates an allocation node and adds it to a graph * @@ -16019,18 +17992,50 @@ CUresult CUDAAPI cuGraphDestroyNode(CUgraphNode hNode); * validated. If instantiation is successful, a handle to the instantiated graph * is returned in \p phGraphExec. * - * If there are any errors, diagnostic information may be returned in \p errorNode and - * \p logBuffer. This is the primary way to inspect instantiation errors. The output - * will be null terminated unless the diagnostics overflow - * the buffer. In this case, they will be truncated, and the last byte can be - * inspected to determine if truncation occurred. + * The \p flags parameter controls the behavior of instantiation and subsequent + * graph launches. Valid flags are: + * + * - ::CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH, which configures a + * graph containing memory allocation nodes to automatically free any + * unfreed memory allocations before the graph is relaunched. + * + * - ::CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH, which configures the graph for launch + * from the device. If this flag is passed, the executable graph handle returned can be + * used to launch the graph from both the host and device. This flag can only be used + * on platforms which support unified addressing. This flag cannot be used in + * conjunction with ::CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH. + * + * - ::CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY, which causes the graph + * to use the priorities from the per-node attributes rather than the priority + * of the launch stream during execution. Note that priorities are only available + * on kernel nodes, and are copied from stream priority during stream capture. + * + * If \p hGraph contains any allocation or free nodes, there can be at most one + * executable graph in existence for that graph at a time. An attempt to instantiate + * a second executable graph before destroying the first with ::cuGraphExecDestroy + * will result in an error. + * + * If \p hGraph contains kernels which call device-side cudaGraphLaunch() from multiple + * contexts, this will result in an error. + * + * Graphs instantiated for launch on the device have additional restrictions which do not + * apply to host graphs: + * + * - The graph's nodes must reside on a single context. + * - The graph can only contain kernel nodes, memcpy nodes, memset nodes, and child graph nodes. + * Operation-specific restrictions are outlined below. + * - Kernel nodes: + * - Use of CUDA Dynamic Parallelism is not permitted. + * - Cooperative launches are permitted as long as MPS is not in use. + * - Memcpy nodes: + * - Only copies involving device memory and/or pinned device-mapped host memory are permitted. + * - Copies involving CUDA arrays are not permitted. + * - Both operands must be accessible from the current context, and the current context must + * match the context of other nodes in the graph. * * \param phGraphExec - Returns instantiated graph * \param hGraph - Graph to instantiate - * \param phErrorNode - In case of an instantiation error, this may be modified to - * indicate a node contributing to the error - * \param logBuffer - A character buffer to store diagnostic messages - * \param bufferSize - Size of the log buffer in bytes + * \param flags - Flags to control instantiation. See ::CUgraphInstantiate_flags. * * \return * ::CUDA_SUCCESS, @@ -16041,55 +18046,136 @@ CUresult CUDAAPI cuGraphDestroyNode(CUgraphNode hNode); * \notefnerr * * \sa - * ::cuGraphInstantiateWithFlags, + * ::cuGraphInstantiate, * ::cuGraphCreate, * ::cuGraphUpload, * ::cuGraphLaunch, * ::cuGraphExecDestroy */ -CUresult CUDAAPI cuGraphInstantiate(CUgraphExec *phGraphExec, CUgraph hGraph, CUgraphNode *phErrorNode, char *logBuffer, size_t bufferSize); +CUresult CUDAAPI cuGraphInstantiate(CUgraphExec *phGraphExec, CUgraph hGraph, unsigned long long flags); /** * \brief Creates an executable graph from a graph * - * Instantiates \p hGraph as an executable graph. The graph is validated for any - * structural constraints or intra-node constraints which were not previously - * validated. If instantiation is successful, a handle to the instantiated graph - * is returned in \p phGraphExec. + * Instantiates \p hGraph as an executable graph according to the \p instantiateParams structure. + * The graph is validated for any structural constraints or intra-node constraints + * which were not previously validated. If instantiation is successful, a handle to + * the instantiated graph is returned in \p phGraphExec. * - * The \p flags parameter controls the behavior of instantiation and subsequent - * graph launches. Valid flags are: + * \p instantiateParams controls the behavior of instantiation and subsequent + * graph launches, as well as returning more detailed information in the event of an error. + * ::CUDA_GRAPH_INSTANTIATE_PARAMS is defined as: + * + * \code + typedef struct { + cuuint64_t flags; + CUstream hUploadStream; + CUgraphNode hErrNode_out; + CUgraphInstantiateResult result_out; + } CUDA_GRAPH_INSTANTIATE_PARAMS; + * \endcode + * + * The \p flags field controls the behavior of instantiation and subsequent + * graph launches. Valid flags are: * * - ::CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH, which configures a * graph containing memory allocation nodes to automatically free any * unfreed memory allocations before the graph is relaunched. * + * - ::CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD, which will perform an upload of the graph + * into \p hUploadStream once the graph has been instantiated. + * + * - ::CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH, which configures the graph for launch + * from the device. If this flag is passed, the executable graph handle returned can be + * used to launch the graph from both the host and device. This flag can only be used + * on platforms which support unified addressing. This flag cannot be used in + * conjunction with ::CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH. + * + * - ::CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY, which causes the graph + * to use the priorities from the per-node attributes rather than the priority + * of the launch stream during execution. Note that priorities are only available + * on kernel nodes, and are copied from stream priority during stream capture. + * * If \p hGraph contains any allocation or free nodes, there can be at most one - * executable graph in existence for that graph at a time. + * executable graph in existence for that graph at a time. An attempt to instantiate a + * second executable graph before destroying the first with ::cuGraphExecDestroy will + * result in an error. * - * An attempt to instantiate a second executable graph before destroying the first - * with ::cuGraphExecDestroy will result in an error. + * If \p hGraph contains kernels which call device-side cudaGraphLaunch() from multiple + * contexts, this will result in an error. * - * \param phGraphExec - Returns instantiated graph - * \param hGraph - Graph to instantiate - * \param flags - Flags to control instantiation. See ::CUgraphInstantiate_flags. + * Graphs instantiated for launch on the device have additional restrictions which do not + * apply to host graphs: + * + * - The graph's nodes must reside on a single context. + * - The graph can only contain kernel nodes, memcpy nodes, memset nodes, and child graph nodes. + * Operation-specific restrictions are outlined below. + * - Kernel nodes: + * - Use of CUDA Dynamic Parallelism is not permitted. + * - Cooperative launches are permitted as long as MPS is not in use. + * - Memcpy nodes: + * - Only copies involving device memory and/or pinned device-mapped host memory are permitted. + * - Copies involving CUDA arrays are not permitted. + * - Both operands must be accessible from the current context, and the current context must + * match the context of other nodes in the graph. + * + * In the event of an error, the \p result_out and \p hErrNode_out fields will contain more + * information about the nature of the error. Possible error reporting includes: + * + * - ::CUDA_GRAPH_INSTANTIATE_ERROR, if passed an invalid value or if an unexpected error occurred + * which is described by the return value of the function. \p hErrNode_out will be set to NULL. + * - ::CUDA_GRAPH_INSTANTIATE_INVALID_STRUCTURE, if the graph structure is invalid. \p hErrNode_out + * will be set to one of the offending nodes. + * - ::CUDA_GRAPH_INSTANTIATE_NODE_OPERATION_NOT_SUPPORTED, if the graph is instantiated for device + * launch but contains a node of an unsupported node type, or a node which performs unsupported + * operations, such as use of CUDA dynamic parallelism within a kernel node. \p hErrNode_out will + * be set to this node. + * - ::CUDA_GRAPH_INSTANTIATE_MULTIPLE_CTXS_NOT_SUPPORTED, if the graph is instantiated for device + * launch but a node’s context differs from that of another node. This error can also be returned + * if a graph is not instantiated for device launch and it contains kernels which call device-side + * cudaGraphLaunch() from multiple contexts. \p hErrNode_out will be set to this node. + * + * If instantiation is successful, \p result_out will be set to ::CUDA_GRAPH_INSTANTIATE_SUCCESS, + * and \p hErrNode_out will be set to NULL. + * + * \param phGraphExec - Returns instantiated graph + * \param hGraph - Graph to instantiate + * \param instantiateParams - Instantiation parameters * * \return * ::CUDA_SUCCESS, - * ::CUDA_ERROR_DEINITIALIZED, - * ::CUDA_ERROR_NOT_INITIALIZED, - * ::CUDA_ERROR_INVALID_VALUE + * ::CUDA_ERROR_INVALID_VALUE, * \note_graph_thread_safety * \notefnerr * * \sa - * ::cuGraphInstantiate, * ::cuGraphCreate, - * ::cuGraphUpload, - * ::cuGraphLaunch, + * ::cuGraphInstantiate, * ::cuGraphExecDestroy */ -CUresult CUDAAPI cuGraphInstantiateWithFlags(CUgraphExec *phGraphExec, CUgraph hGraph, unsigned long long flags); +CUresult CUDAAPI cuGraphInstantiateWithParams(CUgraphExec *phGraphExec, CUgraph hGraph, CUDA_GRAPH_INSTANTIATE_PARAMS *instantiateParams); + +/** + * \brief Query the instantiation flags of an executable graph + * + * Returns the flags that were passed to instantiation for the given executable graph. + * ::CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD will not be returned by this API as it does + * not affect the resulting executable graph. + * + * \param hGraphExec - The executable graph to query + * \param flags - Returns the instantiation flags + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphInstantiate, + * ::cuGraphInstantiateWithParams + */ +CUresult CUDAAPI cuGraphExecGetFlags(CUgraphExec hGraphExec, cuuint64_t *flags); /** * \brief Sets the parameters for a kernel node in the given graphExec @@ -16104,6 +18190,11 @@ CUresult CUDAAPI cuGraphInstantiateWithFlags(CUgraphExec *phGraphExec, CUgraph h * - The owning context of the function cannot change. * - A node whose function originally did not use CUDA dynamic parallelism cannot be updated * to a function which uses CDP + * - If \p hGraphExec was not instantiated for device launch, a node whose function originally + * did not use device-side cudaGraphLaunch() cannot be updated to a function which uses + * device-side cudaGraphLaunch() unless the node resides on the same context as nodes which + * contained such calls at instantiate-time. If no such calls were present at instantiation, + * these updates cannot be performed at all. * * The modifications only affect future launches of \p hGraphExec. Already * enqueued or running launches of \p hGraphExec are not affected by this call. @@ -16489,7 +18580,7 @@ CUresult CUDAAPI cuGraphExecExternalSemaphoresWaitNodeSetParams(CUgraphExec hGra * enqueued or running launches of \p hGraphExec are not affected by this call. * \p hNode is also not modified by this call. * - * \note Currently only kernel nodes are supported. + * \note Currently only kernel, memset and memcpy nodes are supported. * * \param hGraphExec - The executable graph in which to set the specified node * \param hNode - Node from the graph from which graphExec was instantiated @@ -16519,7 +18610,7 @@ CUresult CUDAAPI cuGraphNodeSetEnabled(CUgraphExec hGraphExec, CUgraphNode hNode * * \p hNode must not have been removed from the original graph. * - * \note Currently only kernel nodes are supported. + * \note Currently only kernel, memset and memcpy nodes are supported. * * \param hGraphExec - The executable graph in which to set the specified node * \param hNode - Node from the graph from which graphExec was instantiated @@ -16653,6 +18744,14 @@ CUresult CUDAAPI cuGraphDestroy(CUgraph hGraph); * - A node whose function originally did not use CUDA dynamic parallelism cannot be updated * to a function which uses CDP. * - A cooperative node cannot be updated to a non-cooperative node, and vice-versa. + * - If the graph was instantiated with CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY, the + * priority attribute cannot change. Equality is checked on the originally requested + * priority values, before they are clamped to the device's supported range. + * - If \p hGraphExec was not instantiated for device launch, a node whose function originally + * did not use device-side cudaGraphLaunch() cannot be updated to a function which uses + * device-side cudaGraphLaunch() unless the node resides on the same context as nodes which + * contained such calls at instantiate-time. If no such calls were present at instantiation, + * these updates cannot be performed at all. * - Memset and memcpy nodes: * - The CUDA device(s) to which the operand(s) was allocated/mapped cannot change. * - The source/destination memory must be allocated from the same contexts as the original @@ -16666,18 +18765,20 @@ CUresult CUDAAPI cuGraphDestroy(CUgraph hGraph); * * Note: The API may add further restrictions in future releases. The return code should always be checked. * - * cuGraphExecUpdate sets \p updateResult_out to CU_GRAPH_EXEC_UPDATE_ERROR_TOPOLOGY_CHANGED under - * the following conditions: - * - * - The count of nodes directly in \p hGraphExec and \p hGraph differ, in which case \p hErrorNode_out - * is NULL. - * - A node is deleted in \p hGraph but not not its pair from \p hGraphExec, in which case \p hErrorNode_out - * is NULL. - * - A node is deleted in \p hGraphExec but not its pair from \p hGraph, in which case \p hErrorNode_out is - * the pairless node from \p hGraph. - * - The dependent nodes of a pair differ, in which case \p hErrorNode_out is the node from \p hGraph. - * - * cuGraphExecUpdate sets \p updateResult_out to: + * cuGraphExecUpdate sets the result member of \p resultInfo to CU_GRAPH_EXEC_UPDATE_ERROR_TOPOLOGY_CHANGED + * under the following conditions: + * - The count of nodes directly in \p hGraphExec and \p hGraph differ, in which case resultInfo->errorNode + * is set to NULL. + * - \p hGraph has more exit nodes than \p hGraph, in which case resultInfo->errorNode is set to one of + * the exit nodes in hGraph. + * - A node in \p hGraph has a different number of dependencies than the node from \p hGraphExec it is paired with, + * in which case resultInfo->errorNode is set to the node from \p hGraph. + * - A node in \p hGraph has a dependency that does not match with the corresponding dependency of the paired node + * from \p hGraphExec. resultInfo->errorNode will be set to the node from \p hGraph. resultInfo->errorFromNode + * will be set to the mismatched dependency. The dependencies are paired based on edge order and a dependency + * does not match when the nodes are already paired based on other edges examined in the graph. + * + * cuGraphExecUpdate sets the result member of \p resultInfo to: * - CU_GRAPH_EXEC_UPDATE_ERROR if passed an invalid value. * - CU_GRAPH_EXEC_UPDATE_ERROR_TOPOLOGY_CHANGED if the graph topology changed * - CU_GRAPH_EXEC_UPDATE_ERROR_NODE_TYPE_CHANGED if the type of a node changed, in which case @@ -16691,10 +18792,8 @@ CUresult CUDAAPI cuGraphDestroy(CUgraph hGraph); * - CU_GRAPH_EXEC_UPDATE_ERROR_NOT_SUPPORTED if something about a node is unsupported, like * the node's type or configuration, in which case \p hErrorNode_out is set to the node from \p hGraph * - * If \p updateResult_out isn't set in one of the situations described above, the update check passes - * and cuGraphExecUpdate updates \p hGraphExec to match the contents of \p hGraph. If an error happens - * during the update, \p updateResult_out will be set to CU_GRAPH_EXEC_UPDATE_ERROR; otherwise, - * \p updateResult_out is set to CU_GRAPH_EXEC_UPDATE_SUCCESS. + * If the update fails for a reason not listed above, the result member of \p resultInfo will be set + * to CU_GRAPH_EXEC_UPDATE_ERROR. If the update succeeds, the result member will be set to CU_GRAPH_EXEC_UPDATE_SUCCESS. * * cuGraphExecUpdate returns CUDA_SUCCESS when the updated was performed successfully. It returns * CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE if the graph update was not performed because it included @@ -16702,8 +18801,7 @@ CUresult CUDAAPI cuGraphDestroy(CUgraph hGraph); * * \param hGraphExec The instantiated graph to be updated * \param hGraph The graph containing the updated parameters - * \param hErrorNode_out The node which caused the permissibility check to forbid the update, if any - * \param updateResult_out Whether the graph update was permitted. If was forbidden, the reason why + * \param resultInfo the error info structure * * \return * ::CUDA_SUCCESS, @@ -16712,9 +18810,9 @@ CUresult CUDAAPI cuGraphDestroy(CUgraph hGraph); * \notefnerr * * \sa - * ::cuGraphInstantiate, + * ::cuGraphInstantiate */ -CUresult CUDAAPI cuGraphExecUpdate(CUgraphExec hGraphExec, CUgraph hGraph, CUgraphNode *hErrorNode_out, CUgraphExecUpdateResult *updateResult_out); +CUresult CUDAAPI cuGraphExecUpdate(CUgraphExec hGraphExec, CUgraph hGraph, CUgraphExecUpdateResultInfo *resultInfo); /** * \brief Copies attributes from source node to destination node. @@ -17123,12 +19221,79 @@ CUresult CUDAAPI cuOccupancyMaxPotentialBlockSizeWithFlags(int *minGridSize, int /** * \brief Returns dynamic shared memory available per block when launching \p numBlocks blocks on SM * - * Returns in \p *dynamicSmemSize the maximum size of dynamic shared memory to allow \p numBlocks blocks per SM. + * Returns in \p *dynamicSmemSize the maximum size of dynamic shared memory to allow \p numBlocks blocks per SM. + * + * \param dynamicSmemSize - Returned maximum dynamic shared memory + * \param func - Kernel function for which occupancy is calculated + * \param numBlocks - Number of blocks to fit on SM + * \param blockSize - Size of the blocks + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_UNKNOWN + * \notefnerr + */ +CUresult CUDAAPI cuOccupancyAvailableDynamicSMemPerBlock(size_t *dynamicSmemSize, CUfunction func, int numBlocks, int blockSize); + +/** + * \brief Given the kernel function (\p func) and launch configuration + * (\p config), return the maximum cluster size in \p *clusterSize. + * + * The cluster dimensions in \p config are ignored. If func has a required + * cluster size set (see ::cudaFuncGetAttributes / ::cuFuncGetAttribute),\p + * *clusterSize will reflect the required cluster size. + * + * By default this function will always return a value that's portable on + * future hardware. A higher value may be returned if the kernel function + * allows non-portable cluster sizes. + * + * This function will respect the compile time launch bounds. + * + * \param clusterSize - Returned maximum cluster size that can be launched + * for the given kernel function and launch configuration + * \param func - Kernel function for which maximum cluster + * size is calculated + * \param config - Launch configuration for the given kernel function + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_UNKNOWN + * \notefnerr + * + * \sa + * ::cudaFuncGetAttributes, + * ::cuFuncGetAttribute + */ +CUresult CUDAAPI cuOccupancyMaxPotentialClusterSize(int *clusterSize, CUfunction func, const CUlaunchConfig *config); + +/** + * \brief Given the kernel function (\p func) and launch configuration + * (\p config), return the maximum number of clusters that could co-exist + * on the target device in \p *numClusters. + * + * If the function has required cluster size already set (see + * ::cudaFuncGetAttributes / ::cuFuncGetAttribute), the cluster size + * from config must either be unspecified or match the required size. + * Without required sizes, the cluster size must be specified in config, + * else the function will return an error. + * + * Note that various attributes of the kernel function may affect occupancy + * calculation. Runtime environment may affect how the hardware schedules + * the clusters, so the calculated occupancy is not guaranteed to be achievable. * - * \param dynamicSmemSize - Returned maximum dynamic shared memory - * \param func - Kernel function for which occupancy is calculated - * \param numBlocks - Number of blocks to fit on SM - * \param blockSize - Size of the blocks + * \param numClusters - Returned maximum number of clusters that + * could co-exist on the target device + * \param func - Kernel function for which maximum number + * of clusters are calculated + * \param config - Launch configuration for the given kernel function * * \return * ::CUDA_SUCCESS, @@ -17136,13 +19301,15 @@ CUresult CUDAAPI cuOccupancyMaxPotentialBlockSizeWithFlags(int *minGridSize, int * ::CUDA_ERROR_NOT_INITIALIZED, * ::CUDA_ERROR_INVALID_CONTEXT, * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_CLUSTER_SIZE, * ::CUDA_ERROR_UNKNOWN * \notefnerr * * \sa + * ::cudaFuncGetAttributes, + * ::cuFuncGetAttribute */ -CUresult CUDAAPI cuOccupancyAvailableDynamicSMemPerBlock(size_t *dynamicSmemSize, CUfunction func, int numBlocks, int blockSize); - +CUresult CUDAAPI cuOccupancyMaxActiveClusters(int *numClusters, CUfunction func, const CUlaunchConfig *config); /** @} */ /* END CUDA_OCCUPANCY */ /** @@ -17179,12 +19346,12 @@ CUresult CUDAAPI cuOccupancyAvailableDynamicSMemPerBlock(size_t *dynamicSmemSize * ::CUDA_ERROR_INVALID_CONTEXT, * ::CUDA_ERROR_INVALID_VALUE * - * \sa ::cuTexRefSetAddress, + * \sa + * ::cuTexRefSetAddress, * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, * ::cuTexRefSetFilterMode, ::cuTexRefSetFlags, ::cuTexRefSetFormat, * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, - * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat, - * ::cudaBindTextureToArray + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat */ __CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetArray(CUtexref hTexRef, CUarray hArray, unsigned int Flags); @@ -17209,12 +19376,12 @@ __CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetArray(CUtexref hTexRef, CUarray hA * ::CUDA_ERROR_INVALID_CONTEXT, * ::CUDA_ERROR_INVALID_VALUE * - * \sa ::cuTexRefSetAddress, + * \sa + * ::cuTexRefSetAddress, * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, * ::cuTexRefSetFilterMode, ::cuTexRefSetFlags, ::cuTexRefSetFormat, * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, - * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat, - * ::cudaBindTextureToMipmappedArray + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat */ __CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetMipmappedArray(CUtexref hTexRef, CUmipmappedArray hMipmappedArray, unsigned int Flags); @@ -17256,11 +19423,11 @@ __CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetMipmappedArray(CUtexref hTexRef, C * ::CUDA_ERROR_INVALID_CONTEXT, * ::CUDA_ERROR_INVALID_VALUE * - * \sa ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * \sa + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, * ::cuTexRefSetFilterMode, ::cuTexRefSetFlags, ::cuTexRefSetFormat, * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, - * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat, - * ::cudaBindTexture + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat */ __CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetAddress(size_t *ByteOffset, CUtexref hTexRef, CUdeviceptr dptr, size_t bytes); @@ -17310,12 +19477,12 @@ __CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetAddress(size_t *ByteOffset, CUtexr * ::CUDA_ERROR_INVALID_CONTEXT, * ::CUDA_ERROR_INVALID_VALUE * - * \sa ::cuTexRefSetAddress, + * \sa + * ::cuTexRefSetAddress, * ::cuTexRefSetAddressMode, ::cuTexRefSetArray, * ::cuTexRefSetFilterMode, ::cuTexRefSetFlags, ::cuTexRefSetFormat, * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, - * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat, - * ::cudaBindTexture2D + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat */ __CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetAddress2D(CUtexref hTexRef, const CUDA_ARRAY_DESCRIPTOR *desc, CUdeviceptr dptr, size_t Pitch); @@ -17341,16 +19508,13 @@ __CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetAddress2D(CUtexref hTexRef, const * ::CUDA_ERROR_INVALID_CONTEXT, * ::CUDA_ERROR_INVALID_VALUE * - * \sa ::cuTexRefSetAddress, + * \sa + * ::cuTexRefSetAddress, * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, * ::cuTexRefSetFilterMode, ::cuTexRefSetFlags, * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat, - * ::cudaCreateChannelDesc, - * ::cudaBindTexture, - * ::cudaBindTexture2D, - * ::cudaBindTextureToArray, - * ::cudaBindTextureToMipmappedArray + * ::cudaCreateChannelDesc */ __CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetFormat(CUtexref hTexRef, CUarray_format fmt, int NumPackedComponents); @@ -17388,15 +19552,12 @@ __CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetFormat(CUtexref hTexRef, CUarray_f * ::CUDA_ERROR_INVALID_CONTEXT, * ::CUDA_ERROR_INVALID_VALUE * - * \sa ::cuTexRefSetAddress, + * \sa + * ::cuTexRefSetAddress, * ::cuTexRefSetAddress2D, ::cuTexRefSetArray, * ::cuTexRefSetFilterMode, ::cuTexRefSetFlags, ::cuTexRefSetFormat, * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, - * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat, - * ::cudaBindTexture, - * ::cudaBindTexture2D, - * ::cudaBindTextureToArray, - * ::cudaBindTextureToMipmappedArray + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat */ __CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetAddressMode(CUtexref hTexRef, int dim, CUaddress_mode am); @@ -17427,12 +19588,12 @@ __CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetAddressMode(CUtexref hTexRef, int * ::CUDA_ERROR_INVALID_CONTEXT, * ::CUDA_ERROR_INVALID_VALUE * - * \sa ::cuTexRefSetAddress, + * \sa + * ::cuTexRefSetAddress, * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, * ::cuTexRefSetFlags, ::cuTexRefSetFormat, * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, - * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat, - * ::cudaBindTextureToArray + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat */ __CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetFilterMode(CUtexref hTexRef, CUfilter_mode fm); @@ -17463,12 +19624,12 @@ __CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetFilterMode(CUtexref hTexRef, CUfil * ::CUDA_ERROR_INVALID_CONTEXT, * ::CUDA_ERROR_INVALID_VALUE * - * \sa ::cuTexRefSetAddress, + * \sa + * ::cuTexRefSetAddress, * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, * ::cuTexRefSetFlags, ::cuTexRefSetFormat, * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, - * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat, - * ::cudaBindTextureToMipmappedArray + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat */ __CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetMipmapFilterMode(CUtexref hTexRef, CUfilter_mode fm); @@ -17492,12 +19653,12 @@ __CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetMipmapFilterMode(CUtexref hTexRef, * ::CUDA_ERROR_INVALID_CONTEXT, * ::CUDA_ERROR_INVALID_VALUE * - * \sa ::cuTexRefSetAddress, + * \sa + * ::cuTexRefSetAddress, * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, * ::cuTexRefSetFlags, ::cuTexRefSetFormat, * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, - * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat, - * ::cudaBindTextureToMipmappedArray + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat */ __CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetMipmapLevelBias(CUtexref hTexRef, float bias); @@ -17523,12 +19684,12 @@ __CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetMipmapLevelBias(CUtexref hTexRef, * ::CUDA_ERROR_INVALID_CONTEXT, * ::CUDA_ERROR_INVALID_VALUE * - * \sa ::cuTexRefSetAddress, + * \sa + * ::cuTexRefSetAddress, * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, * ::cuTexRefSetFlags, ::cuTexRefSetFormat, * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, - * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat, - * ::cudaBindTextureToMipmappedArray + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat */ __CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetMipmapLevelClamp(CUtexref hTexRef, float minMipmapLevelClamp, float maxMipmapLevelClamp); @@ -17552,13 +19713,12 @@ __CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetMipmapLevelClamp(CUtexref hTexRef, * ::CUDA_ERROR_INVALID_CONTEXT, * ::CUDA_ERROR_INVALID_VALUE * - * \sa ::cuTexRefSetAddress, + * \sa + * ::cuTexRefSetAddress, * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, * ::cuTexRefSetFlags, ::cuTexRefSetFormat, * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, - * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat, - * ::cudaBindTextureToArray, - * ::cudaBindTextureToMipmappedArray + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat */ __CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetMaxAnisotropy(CUtexref hTexRef, unsigned int maxAniso); @@ -17589,12 +19749,9 @@ __CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetMaxAnisotropy(CUtexref hTexRef, un * ::CUDA_ERROR_INVALID_CONTEXT, * ::CUDA_ERROR_INVALID_VALUE * - * \sa ::cuTexRefSetAddressMode, - * ::cuTexRefGetAddressMode, ::cuTexRefGetBorderColor, - * ::cudaBindTexture, - * ::cudaBindTexture2D, - * ::cudaBindTextureToArray, - * ::cudaBindTextureToMipmappedArray + * \sa + * ::cuTexRefSetAddressMode, + * ::cuTexRefGetAddressMode, ::cuTexRefGetBorderColor */ __CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetBorderColor(CUtexref hTexRef, float *pBorderColor); @@ -17631,15 +19788,12 @@ __CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetBorderColor(CUtexref hTexRef, floa * ::CUDA_ERROR_INVALID_CONTEXT, * ::CUDA_ERROR_INVALID_VALUE * - * \sa ::cuTexRefSetAddress, + * \sa + * ::cuTexRefSetAddress, * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, * ::cuTexRefSetFilterMode, ::cuTexRefSetFormat, * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, - * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat, - * ::cudaBindTexture, - * ::cudaBindTexture2D, - * ::cudaBindTextureToArray, - * ::cudaBindTextureToMipmappedArray + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat */ __CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetFlags(CUtexref hTexRef, unsigned int Flags); @@ -18049,8 +20203,7 @@ __CUDA_DEPRECATED CUresult CUDAAPI cuTexRefDestroy(CUtexref hTexRef); * * \sa * ::cuModuleGetSurfRef, - * ::cuSurfRefGetArray, - * ::cudaBindSurfaceToArray + * ::cuSurfRefGetArray */ __CUDA_DEPRECATED CUresult CUDAAPI cuSurfRefSetArray(CUsurfref hSurfRef, CUarray hArray, unsigned int Flags); @@ -18487,6 +20640,348 @@ CUresult CUDAAPI cuSurfObjectGetResourceDesc(CUDA_RESOURCE_DESC *pResDesc, CUsur /** @} */ /* END CUDA_SURFOBJECT */ +/** + * \defgroup CUDA_TENSOR_MEMORY Tensor Core Managment + * + * ___MANBRIEF___ tensor core management functions of the low-level CUDA + * driver API (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the tensor core management functions of the + * low-level CUDA driver application programming interface. The tensor + * core API is only supported on devices of compute capability 9.0 or higher. + * + * @{ + */ + +/** + * \brief Create a tensor map descriptor object representing tiled memory region + * + * Creates a descriptor for Tensor Memory Access (TMA) object specified + * by the parameters describing a tiled region and returns it in \p tensorMap. + * + * Tensor map objects are only supported on devices of compute capability 9.0 or higher. + * Additionally, a tensor map object is an opaque value, and, as such, should only be + * accessed through CUDA API calls. + * + * The parameters passed are bound to the following requirements: + * + * - \p tensorMap address must be aligned to 64 bytes. + * + * - \p tensorDataType has to be an enum from ::CUtensorMapDataType which is defined as: + * \code + typedef enum CUtensorMapDataType_enum { + CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0, // 1 byte + CU_TENSOR_MAP_DATA_TYPE_UINT16, // 2 bytes + CU_TENSOR_MAP_DATA_TYPE_UINT32, // 4 bytes + CU_TENSOR_MAP_DATA_TYPE_INT32, // 4 bytes + CU_TENSOR_MAP_DATA_TYPE_UINT64, // 8 bytes + CU_TENSOR_MAP_DATA_TYPE_INT64, // 8 bytes + CU_TENSOR_MAP_DATA_TYPE_FLOAT16, // 2 bytes + CU_TENSOR_MAP_DATA_TYPE_FLOAT32, // 4 bytes + CU_TENSOR_MAP_DATA_TYPE_FLOAT64, // 8 bytes + CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, // 2 bytes + CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ, // 4 bytes + CU_TENSOR_MAP_DATA_TYPE_TFLOAT32, // 4 bytes + CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ // 4 bytes + } CUtensorMapDataType; + * \endcode + * + * - \p tensorRank must be non-zero and less than or equal to the maximum supported dimensionality of 5. If \p interleave is not + * ::CU_TENSOR_MAP_INTERLEAVE_NONE, then \p tensorRank must additionally be greater than or equal to 3. + * + * - \p globalAddress, which specifies the starting address of the memory region described, must be 32 byte aligned when \p interleave is + * ::CU_TENSOR_MAP_INTERLEAVE_32B and 16 byte aligned otherwise. + * + * - \p globalDim array, which specifies tensor size of each of the \p tensorRank dimensions, must be non-zero and less than or + * equal to 2^32. + * + * - \p globalStrides array, which specifies tensor stride of each of the lower \p tensorRank - 1 dimensions in bytes, must be a + * multiple of 16 and less than 2^40. Additionally, the stride must be a multiple of 32 when \p interleave is ::CU_TENSOR_MAP_INTERLEAVE_32B. + * Each following dimension specified includes previous dimension stride: + * \code + globalStrides[0] = globalDim[0] * elementSizeInBytes(tensorDataType) + padding[0]; + for (i = 1; i < tensorRank - 1; i++) + globalStrides[i] = globalStrides[i – 1] * globalStrides[i] + padding[i]; + assert(globalStrides[i] >= globalDim[i]); + * \endcode + * + * - \p boxDim array, which specifies number of elements to be traversed along each of the \p tensorRank dimensions, must be less + * than or equal to 8. + * When \p interleave is ::CU_TENSOR_MAP_INTERLEAVE_NONE, { \p boxDim[0] * elementSizeInBytes( \p tensorDataType ) } must be a multiple + * of 16 bytes. + * + * - \p elementStrides array, which specifies the iteration step along each of the \p tensorRank dimensions, must be non-zero and less + * than or equal to 8. Note that when \p interleave is ::CU_TENSOR_MAP_INTERLEAVE_NONE, the first element of this array is ignored since + * TMA doesn’t support the stride for dimension zero. + * When all elemets of \p elementStrides array is one, \p boxDim specifies the number of elements to load. However, if the \p elementStrides[i] + * is not equal to one, then TMA loads ceil( \p boxDim[i] / \p elementStrides[i]) number of elements along i-th dimension. To load N elements along + * i-th dimension, \p boxDim[i] must be set to N * \p elementStrides[i]. + * + * - \p interleave specifies the interleaved layout of type ::CUtensorMapInterleave, which is defined as: + * \code + typedef enum CUtensorMapInterleave_enum { + CU_TENSOR_MAP_INTERLEAVE_NONE = 0, + CU_TENSOR_MAP_INTERLEAVE_16B, + CU_TENSOR_MAP_INTERLEAVE_32B + } CUtensorMapInterleave; + * \endcode + * TMA supports interleaved layouts like NC/8HWC8 where C8 utilizes 16 bytes in memory assuming 2 byte per channel or NC/16HWC16 where C16 + * uses 32 bytes. + * When \p interleave is ::CU_TENSOR_MAP_INTERLEAVE_NONE and \p swizzle is not ::CU_TENSOR_MAP_SWIZZLE_NONE, the bounding box inner dimension + * (computed as \p boxDim[0] multiplied by element size derived from \p tensorDataType) must be less than or equal to the swizzle size. + * - CU_TENSOR_MAP_SWIZZLE_32B implies the bounding box inner dimension will be <= 32. + * - CU_TENSOR_MAP_SWIZZLE_64B implies the bounding box inner dimension will be <= 64. + * - CU_TENSOR_MAP_SWIZZLE_128B implies the bounding box inner dimension will be <= 128. + * + * - \p swizzle, which specifies the shared memory bank swizzling pattern, has to be of type ::CUtensorMapSwizzle which is defined as: + * \code + typedef enum CUtensorMapSwizzle_enum { + CU_TENSOR_MAP_SWIZZLE_NONE = 0, + CU_TENSOR_MAP_SWIZZLE_32B, + CU_TENSOR_MAP_SWIZZLE_64B, + CU_TENSOR_MAP_SWIZZLE_128B + } CUtensorMapSwizzle; + * \endcode + * Data is organized in specific order in global memory; however, it may not match the order in which data are accessed by application in + * the shared memory. This difference in data organization may cause bank conflicts when shared memory is accessed. In order to avoid this + * problem, data can be loaded to shard memory with shuffling across shared memory banks. + * Note that it’s expected that when \p interleave is ::CU_TENSOR_MAP_INTERLEAVE_32B, \p swizzle should be ::CU_TENSOR_MAP_SWIZZLE_32B mode. + * Other interleave modes can have any swizzling patterns. + * + * - \p l2Promotion specifies L2 fetch size which indicates the byte granurality at which L2 requests is filled from DRAM. It must be of + * type ::CUtensorMapL2promotion, which is defined as: + * \code + typedef enum CUtensorMapL2promotion_enum { + CU_TENSOR_MAP_L2_PROMOTION_NONE = 0, + CU_TENSOR_MAP_L2_PROMOTION_L2_64B, + CU_TENSOR_MAP_L2_PROMOTION_L2_128B, + CU_TENSOR_MAP_L2_PROMOTION_L2_256B + } CUtensorMapL2promotion; + * \endcode + * + * - \p oobFill, which indicates whether zero or a special NaN constant should be used to fill out-of-bound elements, must be of type + * ::CUtensorMapFloatOOBfill which is defined as: + * \code + typedef enum CUtensorMapFloatOOBfill_enum { + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE = 0, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA + } CUtensorMapFloatOOBfill; + * \endcode + * Note that ::CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA can only be used when \p tensorDataType represents a floating data type. + * + * \param tensorMap - Tensor map object to create + * \param tensorDataType - Tensor data type + * \param tensorRank - Dimensionality of tensor + * \param globalAddress - Starting address of memory region described by tensor + * \param globalDim - Array containing tensor size (number of elements) along each of the \p tensorRank dimensions + * \param globalStrides - Array containing stride size (in bytes) along each of the \p tensorRank - 1 dimensions + * \param boxDim - Array containing traversal box size (number of elments) along each of the \p tensorRank dimensions. Specifies how many elements to be traversed along each tensor dimension. + * \param elementStrides - Array containing traversal stride in each of the \p tensorRank dimensions + * \param interleave - Type of interleaved layout the tensor addresses + * \param swizzle - Bank swizzling pattern inside shared memory + * \param l2Promotion - L2 promotion size + * \param oobFill - Indicate whether zero or special NaN constant must be used to fill out-of-bound elements + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuTensorMapEncodeIm2col, + * ::cuTensorMapReplaceAddress + */ +CUresult CUDAAPI cuTensorMapEncodeTiled(CUtensorMap *tensorMap, CUtensorMapDataType tensorDataType, cuuint32_t tensorRank, void *globalAddress, const cuuint64_t *globalDim, const cuuint64_t *globalStrides, const cuuint32_t *boxDim, const cuuint32_t *elementStrides, CUtensorMapInterleave interleave, CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill); + + +/** + * \brief Create a tensor map descriptor object representing im2col memory region + * + * Creates a descriptor for Tensor Memory Access (TMA) object specified + * by the parameters describing a im2col memory layout and returns it in \p tensorMap. + * + * Tensor map objects are only supported on devices of compute capability 9.0 or higher. + * Additionally, a tensor map object is an opaque value, and, as such, should only be + * accessed through CUDA API calls. + * + * The parameters passed are bound to the following requirements: + * + * - \p tensorMap address must be aligned to 64 bytes. + * + * - \p tensorDataType has to be an enum from ::CUtensorMapDataType which is defined as: + * \code + typedef enum CUtensorMapDataType_enum { + CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0, // 1 byte + CU_TENSOR_MAP_DATA_TYPE_UINT16, // 2 bytes + CU_TENSOR_MAP_DATA_TYPE_UINT32, // 4 bytes + CU_TENSOR_MAP_DATA_TYPE_INT32, // 4 bytes + CU_TENSOR_MAP_DATA_TYPE_UINT64, // 8 bytes + CU_TENSOR_MAP_DATA_TYPE_INT64, // 8 bytes + CU_TENSOR_MAP_DATA_TYPE_FLOAT16, // 2 bytes + CU_TENSOR_MAP_DATA_TYPE_FLOAT32, // 4 bytes + CU_TENSOR_MAP_DATA_TYPE_FLOAT64, // 8 bytes + CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, // 2 bytes + CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ, // 4 bytes + CU_TENSOR_MAP_DATA_TYPE_TFLOAT32, // 4 bytes + CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ // 4 bytes + } CUtensorMapDataType; + * \endcode + * + * - \p tensorRank must be one of dimensions 3, 4, or 5. + * + * - \p globalAddress, which specifies the starting address of the memory region described, must be 32 byte aligned when \p interleave is + * ::CU_TENSOR_MAP_INTERLEAVE_32B and 16 byte aligned otherwise. + * + * - \p globalDim array, which specifies tensor size of each of the \p tensorRank dimensions, must be non-zero and less than or + * equal to 2^32. + * + * - \p globalStrides array, which specifies tensor stride of each of the lower \p tensorRank - 1 dimensions in bytes, must be a + * multiple of 16 and less than 2^40. Additionally, the stride must be a multiple of 32 when \p interleave is ::CU_TENSOR_MAP_INTERLEAVE_32B. + * Each following dimension specified includes previous dimension stride: + * \code + globalStrides[0] = globalDim[0] * elementSizeInBytes(tensorDataType) + padding[0]; + for (i = 1; i < tensorRank - 1; i++) + globalStrides[i] = globalStrides[i – 1] * globalStrides[i] + padding[i]; + assert(globalStrides[i] >= globalDim[i]); + * \endcode + * + * - \p pixelBoxLowerCorner array specifies the coordinate offsets {D, H, W} of the bounding box from top/left/front corner. The number of + * offsets and their precision depends on the tensor dimensionality: + * - When \p tensorRank is 3, one signed offset within range [-32768, 32767] is supported. + * - When \p tensorRank is 4, two signed offsets each within range [-128, 127] are supported. + * - When \p tensorRank is 5, three offsets each within range [-16, 15] are supported. + * + * - \p pixelBoxUpperCorner array specifies the coordinate offsets {D, H, W} of the bounding box from bottom/right/back corner. The number of + * offsets and their precision depends on the tensor dimensionality: + * - When \p tensorRank is 3, one signed offset within range [-32768, 32767] is supported. + * - When \p tensorRank is 4, two signed offsets each within range [-128, 127] are supported. + * - When \p tensorRank is 5, three offsets each within range [-16, 15] are supported. + * The bounding box specified by \p pixelBoxLowerCorner and \p pixelBoxUpperCorner must have non-zero area. + * + * - \p channelsPerPixel, which specifies the number of elements which must be accessed along C dimension, must be less than or equal to 256. + * + * - \p pixelsPerColumn, which specifies the number of elements that must be accessed along the {N, D, H, W} dimensions, must be less than or + * equal to 1024. + * + * - \p elementStrides array, which specifies the iteration step along each of the \p tensorRank dimensions, must be non-zero and less + * than or equal to 8. Note that when \p interleave is ::CU_TENSOR_MAP_INTERLEAVE_NONE, the first element of this array is ignored since + * TMA doesn’t support the stride for dimension zero. + * When all elemets of \p elementStrides array is one, \p boxDim specifies the number of elements to load. However, if the \p elementStrides[i] + * is not equal to one, then TMA loads ceil( \p boxDim[i] / \p elementStrides[i]) number of elements along i-th dimension. To load N elements along + * i-th dimension, \p boxDim[i] must be set to N * \p elementStrides[i]. + * + * - \p interleave specifies the interleaved layout of type ::CUtensorMapInterleave, which is defined as: + * \code + typedef enum CUtensorMapInterleave_enum { + CU_TENSOR_MAP_INTERLEAVE_NONE = 0, + CU_TENSOR_MAP_INTERLEAVE_16B, + CU_TENSOR_MAP_INTERLEAVE_32B + } CUtensorMapInterleave; + * \endcode + * TMA supports interleaved layouts like NC/8HWC8 where C8 utilizes 16 bytes in memory assuming 2 byte per channel or NC/16HWC16 where C16 + * uses 32 bytes. + * When \p interleave is ::CU_TENSOR_MAP_INTERLEAVE_NONE and \p swizzle is not ::CU_TENSOR_MAP_SWIZZLE_NONE, the bounding box inner dimension + * (computed as \p boxDim[0] multiplied by element size derived from \p tensorDataType) must be less than or equal to the swizzle size. + * - CU_TENSOR_MAP_SWIZZLE_32B implies the bounding box inner dimension will be <= 32. + * - CU_TENSOR_MAP_SWIZZLE_64B implies the bounding box inner dimension will be <= 64. + * - CU_TENSOR_MAP_SWIZZLE_128B implies the bounding box inner dimension will be <= 128. + * + * - \p swizzle, which specifies the shared memory bank swizzling pattern, has to be of type ::CUtensorMapSwizzle which is defined as: + * \code + typedef enum CUtensorMapSwizzle_enum { + CU_TENSOR_MAP_SWIZZLE_NONE = 0, + CU_TENSOR_MAP_SWIZZLE_32B, + CU_TENSOR_MAP_SWIZZLE_64B, + CU_TENSOR_MAP_SWIZZLE_128B + } CUtensorMapSwizzle; + * \endcode + * Data is organized in specific order in global memory; however, it may not match the order in which data are accessed by application in + * the shared memory. This difference in data organization may cause bank conflicts when shared memory is accessed. In order to avoid this + * problem, data can be loaded to shard memory with shuffling across shared memory banks. + * Note that it’s expected that when \p interleave is ::CU_TENSOR_MAP_INTERLEAVE_32B, \p swizzle should be ::CU_TENSOR_MAP_SWIZZLE_32B mode. + * Other interleave modes can have any swizzling patterns. + * + * - \p l2Promotion specifies L2 fetch size which indicates the byte granurality at which L2 requests is filled from DRAM. It must be of + * type ::CUtensorMapL2promotion, which is defined as: + * \code + typedef enum CUtensorMapL2promotion_enum { + CU_TENSOR_MAP_L2_PROMOTION_NONE = 0, + CU_TENSOR_MAP_L2_PROMOTION_L2_64B, + CU_TENSOR_MAP_L2_PROMOTION_L2_128B, + CU_TENSOR_MAP_L2_PROMOTION_L2_256B + } CUtensorMapL2promotion; + * \endcode + * + * - \p oobFill, which indicates whether zero or a special NaN constant should be used to fill out-of-bound elements, must be of type + * ::CUtensorMapFloatOOBfill which is defined as: + * \code + typedef enum CUtensorMapFloatOOBfill_enum { + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE = 0, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA + } CUtensorMapFloatOOBfill; + * \endcode + * Note that ::CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA can only be used when \p tensorDataType represents a floating data type. + * + * \param tensorMap - Tensor map object to create + * \param tensorDataType - Tensor data type + * \param tensorRank - Dimensionality of tensor, needs to be at least of dimension 3 + * \param globalAddress - Starting address of memory region described by tensor + * \param globalDim - Array containing tensor size (number of elements) along each of the \p tensorRank dimensions + * \param globalStrides - Array containing stride size (in bytes) along each of the \p tensorRank - 1 dimensions + * \param pixelBoxLowerCorner - Array containing DHW dimentions of lower box corner + * \param pixelBoxUpperCorner - Array containing DHW dimentions of upper box corner + * \param channelsPerPixel - Number of channels per pixel + * \param pixelsPerColumn - Number of pixels per column + * \param elementStrides - Array containing traversal stride in each of the \p tensorRank dimensions + * \param interleave - Type of interleaved layout the tensor addresses + * \param swizzle - Bank swizzling pattern inside shared memory + * \param l2Promotion - L2 promotion size + * \param oobFill - Indicate whether zero or special NaN constant must be used to fill out-of-bound elements + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuTensorMapEncodeTiled, + * ::cuTensorMapReplaceAddress + */ +CUresult CUDAAPI cuTensorMapEncodeIm2col(CUtensorMap *tensorMap, CUtensorMapDataType tensorDataType, cuuint32_t tensorRank, void *globalAddress, const cuuint64_t *globalDim, const cuuint64_t *globalStrides, const int *pixelBoxLowerCorner, const int *pixelBoxUpperCorner, cuuint32_t channelsPerPixel, cuuint32_t pixelsPerColumn, const cuuint32_t *elementStrides, CUtensorMapInterleave interleave, CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill); + +/** + * \brief Modify an existing tensor map descriptor with an updated global address + * + * Modifies the descriptor for Tensor Memory Access (TMA) object passed in \p tensorMap with + * an updated \p globalAddress. + * + * Tensor map objects are only supported on devices of compute capability 9.0 or higher. + * Additionally, a tensor map object is an opaque value, and, as such, should only be + * accessed through CUDA API calls. + * + * \param tensorMap - Tensor map object to modify + * \param globalAddress - Starting address of memory region described by tensor, must follow previous alignment requirements + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuTensorMapEncodeTiled, + * ::cuTensorMapEncodeIm2col + */ +CUresult CUDAAPI cuTensorMapReplaceAddress(CUtensorMap *tensorMap, void *globalAddress); + +/** @} */ +/* END CUDA_TENSOR_MEMORY */ + /** * \defgroup CUDA_PEER_ACCESS Peer Context Memory Access * @@ -18945,9 +21440,18 @@ CUresult CUDAAPI cuGraphicsUnmapResources(unsigned int count, CUgraphicsResource * typedef can be picked up from the corresponding typedefs header file. For example, * cudaTypedefs.h consists of function pointer typedefs for driver APIs defined in cuda.h. * - * The API will return ::CUDA_ERROR_NOT_FOUND if the requested driver function is not - * supported on the platform, no ABI compatible driver function exists for the specified - * \p cudaVersion or if the driver symbol is invalid. + * The API will return ::CUDA_SUCCESS and set the returned \p pfn to NULL if the + * requested driver function is not supported on the platform, no ABI + * compatible driver function exists for the specified \p cudaVersion or if the + * driver symbol is invalid. + * + * It will also set the optional \p symbolStatus to one of the values in + * ::CUdriverProcAddressQueryResult with the following meanings: + * - ::CU_GET_PROC_ADDRESS_SUCCESS - The requested symbol was succesfully found based + * on input arguments and \p pfn is valid + * - ::CU_GET_PROC_ADDRESS_SYMBOL_NOT_FOUND - The requested symbol was not found + * - ::CU_GET_PROC_ADDRESS_VERSION_NOT_SUFFICIENT - The requested symbol was found but is + * not supported by cudaVersion specified * * The requested flags can be: * - ::CU_GET_PROC_ADDRESS_DEFAULT: This is the default mode. This is equivalent to @@ -18967,21 +21471,255 @@ CUresult CUDAAPI cuGraphicsUnmapResources(unsigned int count, CUgraphicsResource * \param pfn - Location to return the function pointer to the requested driver function * \param cudaVersion - The CUDA version to look for the requested driver symbol * \param flags - Flags to specify search options. + * \param symbolStatus - Optional location to store the status of the search for + * \p symbol based on \p cudaVersion. See ::CUdriverProcAddressQueryResult + * for possible values. * * \return * ::CUDA_SUCCESS, * ::CUDA_ERROR_INVALID_VALUE, - * ::CUDA_ERROR_NOT_SUPPORTED, - * ::CUDA_ERROR_NOT_FOUND + * ::CUDA_ERROR_NOT_SUPPORTED * \note_version_mixing * * \sa * ::cudaGetDriverEntryPoint */ -CUresult CUDAAPI cuGetProcAddress(const char *symbol, void **pfn, int cudaVersion, cuuint64_t flags); +CUresult CUDAAPI cuGetProcAddress(const char *symbol, void **pfn, int cudaVersion, cuuint64_t flags, CUdriverProcAddressQueryResult *symbolStatus); /** @} */ /* END CUDA_DRIVER_ENTRY_POINT */ +/** + * \defgroup CUDA_COREDUMP Coredump Attributes Control API + * + * ___MANBRIEF___ coredump attribute control functions for the low-level CUDA API + * (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the coredump attribute control functions of the low-level CUDA + * driver application programming interface. + * + * @{ + */ + +/** + * Flags for choosing a coredump attribute to get/set + */ +typedef enum CUcoredumpSettings_enum { + CU_COREDUMP_ENABLE_ON_EXCEPTION = 1, + CU_COREDUMP_TRIGGER_HOST, + CU_COREDUMP_LIGHTWEIGHT, + CU_COREDUMP_ENABLE_USER_TRIGGER, + CU_COREDUMP_FILE, + CU_COREDUMP_PIPE, + CU_COREDUMP_MAX +} CUcoredumpSettings; + +/** + * \brief Allows caller to fetch a coredump attribute value for the current context + * + * Returns in \p *value the requested value specified by \p attrib. It is up to the caller + * to ensure that the data type and size of \p *value matches the request. + * + * If the caller calls this function with \p *value equal to NULL, the size of the memory + * region (in bytes) expected for \p attrib will be placed in \p size. + * + * The supported attributes are: + * - ::CU_COREDUMP_ENABLE_ON_EXCEPTION: Bool where ::true means that GPU exceptions from + * this context will create a coredump at the location specified by ::CU_COREDUMP_FILE. + * The default value is ::false unless set to ::true globally or locally, or the + * CU_CTX_USER_COREDUMP_ENABLE flag was set during context creation. + * - ::CU_COREDUMP_TRIGGER_HOST: Bool where ::true means that the host CPU will + * also create a coredump. The default value is ::true unless set to ::false globally or + * or locally. + * - ::CU_COREDUMP_LIGHTWEIGHT: Bool where ::true means that any resulting coredumps + * will not have a dump of GPU memory or non-reloc ELF images. The default value is + * ::false unless set to ::true globally or locally. + * - ::CU_COREDUMP_ENABLE_USER_TRIGGER: Bool where ::true means that a coredump can be + * created by writing to the system pipe specified by ::CU_COREDUMP_PIPE. The default + * value is ::false unless set to ::true globally or locally. + * - ::CU_COREDUMP_FILE: String of up to 1023 characters that defines the location where + * any coredumps generated by this context will be written. The default value is + * ::core.cuda.HOSTNAME.PID where ::HOSTNAME is the host name of the machine running + * the CUDA applications and ::PID is the process ID of the CUDA application. + * - ::CU_COREDUMP_PIPE: String of up to 1023 characters that defines the name of the pipe + * that will be monitored if user-triggered coredumps are enabled. The default value is + * ::corepipe.cuda.HOSTNAME.PID where ::HOSTNAME is the host name of the machine running + * the CUDA application and ::PID is the process ID of the CUDA application. + * + * \param attrib - The enum defining which value to fetch. + * \param value - void* containing the requested data. + * \param size - The size of the memory region \p value points to. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_PERMITTED, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_CONTEXT_IS_DESTROYED + * + * \sa + * ::cuCoredumpGetAttributeGlobal, + * ::cuCoredumpSetAttribute, + * ::cuCoredumpSetAttributeGlobal + */ +CUresult CUDAAPI cuCoredumpGetAttribute(CUcoredumpSettings attrib, void* value, size_t *size); + +/** + * \brief Allows caller to fetch a coredump attribute value for the entire application + * + * Returns in \p *value the requested value specified by \p attrib. It is up to the caller + * to ensure that the data type and size of \p *value matches the request. + * + * If the caller calls this function with \p *value equal to NULL, the size of the memory + * region (in bytes) expected for \p attrib will be placed in \p size. + * + * The supported attributes are: + * - ::CU_COREDUMP_ENABLE_ON_EXCEPTION: Bool where ::true means that GPU exceptions from + * this context will create a coredump at the location specified by ::CU_COREDUMP_FILE. + * The default value is ::false. + * - ::CU_COREDUMP_TRIGGER_HOST: Bool where ::true means that the host CPU will + * also create a coredump. The default value is ::true. + * - ::CU_COREDUMP_LIGHTWEIGHT: Bool where ::true means that any resulting coredumps + * will not have a dump of GPU memory or non-reloc ELF images. The default value is + * ::false. + * - ::CU_COREDUMP_ENABLE_USER_TRIGGER: Bool where ::true means that a coredump can be + * created by writing to the system pipe specified by ::CU_COREDUMP_PIPE. The default + * value is ::false. + * - ::CU_COREDUMP_FILE: String of up to 1023 characters that defines the location where + * any coredumps generated by this context will be written. The default value is + * ::core.cuda.HOSTNAME.PID where ::HOSTNAME is the host name of the machine running + * the CUDA applications and ::PID is the process ID of the CUDA application. + * - ::CU_COREDUMP_PIPE: String of up to 1023 characters that defines the name of the pipe + * that will be monitored if user-triggered coredumps are enabled. The default value is + * ::corepipe.cuda.HOSTNAME.PID where ::HOSTNAME is the host name of the machine running + * the CUDA application and ::PID is the process ID of the CUDA application. + * + * \param attrib - The enum defining which value to fetch. + * \param value - void* containing the requested data. + * \param size - The size of the memory region \p value points to. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuCoredumpGetAttribute, + * ::cuCoredumpSetAttribute, + * ::cuCoredumpSetAttributeGlobal + */ +CUresult CUDAAPI cuCoredumpGetAttributeGlobal(CUcoredumpSettings attrib, void *value, size_t *size); + +/** + * \brief Allows caller to set a coredump attribute value for the current context + * + * This function should be considered an alternate interface to the CUDA-GDB environment + * variables defined in this document: https://docs.nvidia.com/cuda/cuda-gdb/index.html#gpu-coredump + * + * An important design decision to note is that any coredump environment variable values + * set before CUDA initializes will take permanent precedence over any values set with this + * this function. This decision was made to ensure no change in behavior for any users that + * may be currently using these variables to get coredumps. + * + * \p *value shall contain the requested value specified by \p set. It is up to the caller + * to ensure that the data type and size of \p *value matches the request. + * + * If the caller calls this function with \p *value equal to NULL, the size of the memory + * region (in bytes) expected for \p set will be placed in \p size. + * + * ::CU_COREDUMP_ENABLE_USER_TRIGGER and ::CU_COREDUMP_PIPE cannot be set on a per-context basis. + * + * The supported attributes are: + * - ::CU_COREDUMP_ENABLE_ON_EXCEPTION: Bool where ::true means that GPU exceptions from + * this context will create a coredump at the location specified by ::CU_COREDUMP_FILE. + * The default value is ::false. + * - ::CU_COREDUMP_TRIGGER_HOST: Bool where ::true means that the host CPU will + * also create a coredump. The default value is ::true. + * - ::CU_COREDUMP_LIGHTWEIGHT: Bool where ::true means that any resulting coredumps + * will not have a dump of GPU memory or non-reloc ELF images. The default value is + * ::false. + * - ::CU_COREDUMP_FILE: String of up to 1023 characters that defines the location where + * any coredumps generated by this context will be written. The default value is + * ::core.cuda.HOSTNAME.PID where ::HOSTNAME is the host name of the machine running + * the CUDA applications and ::PID is the process ID of the CUDA application. + * + * \param attrib - The enum defining which value to set. + * \param value - void* containing the requested data. + * \param size - The size of the memory region \p value points to. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_PERMITTED, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_CONTEXT_IS_DESTROYED + * + * \sa + * ::cuCoredumpGetAttributeGlobal, + * ::cuCoredumpGetAttribute, + * ::cuCoredumpSetAttributeGlobal + */ +CUresult CUDAAPI cuCoredumpSetAttribute(CUcoredumpSettings attrib, void* value, size_t *size); + +/** + * \brief Allows caller to set a coredump attribute value globally + * + * This function should be considered an alternate interface to the CUDA-GDB environment + * variables defined in this document: https://docs.nvidia.com/cuda/cuda-gdb/index.html#gpu-coredump + * + * An important design decision to note is that any coredump environment variable values + * set before CUDA initializes will take permanent precedence over any values set with this + * this function. This decision was made to ensure no change in behavior for any users that + * may be currently using these variables to get coredumps. + * + * \p *value shall contain the requested value specified by \p set. It is up to the caller + * to ensure that the data type and size of \p *value matches the request. + * + * If the caller calls this function with \p *value equal to NULL, the size of the memory + * region (in bytes) expected for \p set will be placed in \p size. + * + * The supported attributes are: + * - ::CU_COREDUMP_ENABLE_ON_EXCEPTION: Bool where ::true means that GPU exceptions from + * this context will create a coredump at the location specified by ::CU_COREDUMP_FILE. + * The default value is ::false. + * - ::CU_COREDUMP_TRIGGER_HOST: Bool where ::true means that the host CPU will + * also create a coredump. The default value is ::true. + * - ::CU_COREDUMP_LIGHTWEIGHT: Bool where ::true means that any resulting coredumps + * will not have a dump of GPU memory or non-reloc ELF images. The default value is + * ::false. + * - ::CU_COREDUMP_ENABLE_USER_TRIGGER: Bool where ::true means that a coredump can be + * created by writing to the system pipe specified by ::CU_COREDUMP_PIPE. The default + * value is ::false. + * - ::CU_COREDUMP_FILE: String of up to 1023 characters that defines the location where + * any coredumps generated by this context will be written. The default value is + * ::core.cuda.HOSTNAME.PID where ::HOSTNAME is the host name of the machine running + * the CUDA applications and ::PID is the process ID of the CUDA application. + * - ::CU_COREDUMP_PIPE: String of up to 1023 characters that defines the name of the pipe + * that will be monitored if user-triggered coredumps are enabled. This value may not be + * changed after ::CU_COREDUMP_ENABLE_USER_TRIGGER is set to ::true. The default + * value is ::corepipe.cuda.HOSTNAME.PID where ::HOSTNAME is the host name of the machine + * running the CUDA application and ::PID is the process ID of the CUDA application. + * + * \param attrib - The enum defining which value to set. + * \param value - void* containing the requested data. + * \param size - The size of the memory region \p value points to. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_PERMITTED + * + * \sa + * ::cuCoredumpGetAttribute, + * ::cuCoredumpGetAttributeGlobal, + * ::cuCoredumpSetAttribute + */ +CUresult CUDAAPI cuCoredumpSetAttributeGlobal(CUcoredumpSettings attrib, void *value, size_t *size); + +/** @} */ /* END CUDA_COREDUMP */ + CUresult CUDAAPI cuGetExportTable(const void **ppExportTable, const CUuuid *pExportTableId); /** @@ -19053,6 +21791,7 @@ CUresult CUDAAPI cuGetExportTable(const void **ppExportTable, const CUuuid *pExp #undef cuMemsetD2D16Async #undef cuMemsetD2D32Async #undef cuStreamGetPriority + #undef cuStreamGetId #undef cuStreamGetFlags #undef cuStreamGetCtx #undef cuStreamWaitEvent @@ -19063,9 +21802,7 @@ CUresult CUDAAPI cuGetExportTable(const void **ppExportTable, const CUuuid *pExp #undef cuEventRecord #undef cuEventRecordWithFlags #undef cuLaunchKernel - - - + #undef cuLaunchKernelEx #undef cuLaunchHostFunc #undef cuGraphicsMapResources #undef cuGraphicsUnmapResources @@ -19074,6 +21811,11 @@ CUresult CUDAAPI cuGetExportTable(const void **ppExportTable, const CUuuid *pExp #undef cuStreamWriteValue64 #undef cuStreamWaitValue64 #undef cuStreamBatchMemOp + #undef cuStreamWriteValue32_v2 + #undef cuStreamWaitValue32_v2 + #undef cuStreamWriteValue64_v2 + #undef cuStreamWaitValue64_v2 + #undef cuStreamBatchMemOp_v2 #undef cuMemPrefetchAsync #undef cuLaunchCooperativeKernel #undef cuSignalExternalSemaphoresAsync @@ -19083,6 +21825,8 @@ CUresult CUDAAPI cuGetExportTable(const void **ppExportTable, const CUuuid *pExp #undef cuStreamIsCapturing #undef cuStreamGetCaptureInfo #undef cuStreamGetCaptureInfo_v2 + #undef cuGraphInstantiateWithParams + #undef cuGraphExecUpdate #undef cuGraphUpload #undef cuGraphLaunch #undef cuDevicePrimaryCtxRelease @@ -19093,11 +21837,16 @@ CUresult CUDAAPI cuGetExportTable(const void **ppExportTable, const CUuuid *pExp #undef cuStreamSetAttribute #undef cuStreamGetAttribute #undef cuGraphInstantiate + #undef cuGraphAddKernelNode + #undef cuGraphKernelNodeGetParams + #undef cuGraphKernelNodeSetParams + #undef cuGraphExecKernelNodeSetParams #undef cuMemMapArrayAsync #undef cuMemFreeAsync #undef cuMemAllocAsync #undef cuMemAllocFromPoolAsync #undef cuStreamUpdateCaptureDependencies + #undef cuGetProcAddress CUresult CUDAAPI cuMemHostRegister(void *p, size_t bytesize, unsigned int Flags); CUresult CUDAAPI cuGraphicsResourceSetMapFlags(CUgraphicsResource resource, unsigned int flags); @@ -19274,6 +22023,7 @@ CUresult CUDAAPI cuGetExportTable(const void **ppExportTable, const CUuuid *pExp CUresult CUDAAPI cuMemsetD2D32Async(CUdeviceptr dstDevice, size_t dstPitch, unsigned int ui, size_t Width, size_t Height, CUstream hStream); CUresult CUDAAPI cuStreamGetPriority(CUstream hStream, int *priority); + CUresult CUDAAPI cuStreamGetId(CUstream hStream, unsigned long long *streamId); CUresult CUDAAPI cuStreamGetFlags(CUstream hStream, unsigned int *flags); CUresult CUDAAPI cuStreamGetCtx(CUstream hStream, CUcontext *pctx); CUresult CUDAAPI cuStreamWaitEvent(CUstream hStream, CUevent hEvent, unsigned int Flags); @@ -19284,9 +22034,7 @@ CUresult CUDAAPI cuGetExportTable(const void **ppExportTable, const CUuuid *pExp CUresult CUDAAPI cuEventRecord(CUevent hEvent, CUstream hStream); CUresult CUDAAPI cuEventRecordWithFlags(CUevent hEvent, CUstream hStream, unsigned int flags); CUresult CUDAAPI cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams, void **extra); - - - + CUresult CUDAAPI cuLaunchKernelEx(const CUlaunchConfig *config, CUfunction f, void **kernelParams, void **extra); CUresult CUDAAPI cuLaunchHostFunc(CUstream hStream, CUhostFn fn, void *userData); CUresult CUDAAPI cuGraphicsMapResources(unsigned int count, CUgraphicsResource *resources, CUstream hStream); CUresult CUDAAPI cuGraphicsUnmapResources(unsigned int count, CUgraphicsResource *resources, CUstream hStream); @@ -19295,6 +22043,18 @@ CUresult CUDAAPI cuGetExportTable(const void **ppExportTable, const CUuuid *pExp CUresult CUDAAPI cuStreamWriteValue64(CUstream stream, CUdeviceptr addr, cuuint64_t value, unsigned int flags); CUresult CUDAAPI cuStreamWaitValue64(CUstream stream, CUdeviceptr addr, cuuint64_t value, unsigned int flags); CUresult CUDAAPI cuStreamBatchMemOp(CUstream stream, unsigned int count, CUstreamBatchMemOpParams *paramArray, unsigned int flags); + + CUresult CUDAAPI cuStreamWriteValue32_ptsz(CUstream stream, CUdeviceptr addr, cuuint32_t value, unsigned int flags); + CUresult CUDAAPI cuStreamWaitValue32_ptsz(CUstream stream, CUdeviceptr addr, cuuint32_t value, unsigned int flags); + CUresult CUDAAPI cuStreamWriteValue64_ptsz(CUstream stream, CUdeviceptr addr, cuuint64_t value, unsigned int flags); + CUresult CUDAAPI cuStreamWaitValue64_ptsz(CUstream stream, CUdeviceptr addr, cuuint64_t value, unsigned int flags); + CUresult CUDAAPI cuStreamBatchMemOp_ptsz(CUstream stream, unsigned int count, CUstreamBatchMemOpParams *paramArray, unsigned int flags); + + CUresult CUDAAPI cuStreamWriteValue32_v2(CUstream stream, CUdeviceptr addr, cuuint32_t value, unsigned int flags); + CUresult CUDAAPI cuStreamWaitValue32_v2(CUstream stream, CUdeviceptr addr, cuuint32_t value, unsigned int flags); + CUresult CUDAAPI cuStreamWriteValue64_v2(CUstream stream, CUdeviceptr addr, cuuint64_t value, unsigned int flags); + CUresult CUDAAPI cuStreamWaitValue64_v2(CUstream stream, CUdeviceptr addr, cuuint64_t value, unsigned int flags); + CUresult CUDAAPI cuStreamBatchMemOp_v2(CUstream stream, unsigned int count, CUstreamBatchMemOpParams *paramArray, unsigned int flags); CUresult CUDAAPI cuMemPrefetchAsync(CUdeviceptr devPtr, size_t count, CUdevice dstDevice, CUstream hStream); CUresult CUDAAPI cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams); CUresult CUDAAPI cuSignalExternalSemaphoresAsync(const CUexternalSemaphore *extSemArray, const CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS *paramsArray, unsigned int numExtSems, CUstream stream); @@ -19305,7 +22065,14 @@ CUresult CUDAAPI cuGetExportTable(const void **ppExportTable, const CUuuid *pExp CUresult CUDAAPI cuStreamEndCapture(CUstream hStream, CUgraph *phGraph); CUresult CUDAAPI cuStreamIsCapturing(CUstream hStream, CUstreamCaptureStatus *captureStatus); CUresult CUDAAPI cuStreamGetCaptureInfo(CUstream hStream, CUstreamCaptureStatus *captureStatus_out, cuuint64_t *id_out); + CUresult CUDAAPI cuStreamGetCaptureInfo_ptsz(CUstream hStream, CUstreamCaptureStatus *captureStatus_out, cuuint64_t *id_out); CUresult CUDAAPI cuStreamGetCaptureInfo_v2(CUstream hStream, CUstreamCaptureStatus *captureStatus_out, cuuint64_t *id_out, CUgraph *graph_out, const CUgraphNode **dependencies_out, size_t *numDependencies_out); + CUresult CUDAAPI cuGraphAddKernelNode(CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, size_t numDependencies, const CUDA_KERNEL_NODE_PARAMS_v1 *nodeParams); + CUresult CUDAAPI cuGraphKernelNodeGetParams(CUgraphNode hNode, CUDA_KERNEL_NODE_PARAMS_v1 *nodeParams); + CUresult CUDAAPI cuGraphKernelNodeSetParams(CUgraphNode hNode, const CUDA_KERNEL_NODE_PARAMS_v1 *nodeParams); + CUresult CUDAAPI cuGraphExecKernelNodeSetParams(CUgraphExec hGraphExec, CUgraphNode hNode, const CUDA_KERNEL_NODE_PARAMS_v1 *nodeParams); + CUresult CUDAAPI cuGraphInstantiateWithParams(CUgraphExec *phGraphExec, CUgraph hGraph, CUDA_GRAPH_INSTANTIATE_PARAMS *instantiateParams); + CUresult CUDAAPI cuGraphExecUpdate(CUgraphExec hGraphExec, CUgraph hGraph, CUgraphNode *hErrorNode_out, CUgraphExecUpdateResult *updateResult_out); CUresult CUDAAPI cuGraphUpload(CUgraphExec hGraph, CUstream hStream); CUresult CUDAAPI cuGraphLaunch(CUgraphExec hGraph, CUstream hStream); CUresult CUDAAPI cuStreamCopyAttributes(CUstream dstStream, CUstream srcStream); @@ -19314,6 +22081,8 @@ CUresult CUDAAPI cuGetExportTable(const void **ppExportTable, const CUuuid *pExp CUresult CUDAAPI cuIpcOpenMemHandle(CUdeviceptr *pdptr, CUipcMemHandle handle, unsigned int Flags); CUresult CUDAAPI cuGraphInstantiate(CUgraphExec *phGraphExec, CUgraph hGraph, CUgraphNode *phErrorNode, char *logBuffer, size_t bufferSize); + CUresult CUDAAPI cuGraphInstantiate_v2(CUgraphExec *phGraphExec, CUgraph hGraph, CUgraphNode *phErrorNode, char *logBuffer, size_t bufferSize); + CUresult CUDAAPI cuMemMapArrayAsync(CUarrayMapInfo *mapInfoList, unsigned int count, CUstream hStream); CUresult CUDAAPI cuMemFreeAsync(CUdeviceptr dptr, CUstream hStream); @@ -19321,16 +22090,18 @@ CUresult CUDAAPI cuGetExportTable(const void **ppExportTable, const CUuuid *pExp CUresult CUDAAPI cuMemAllocFromPoolAsync(CUdeviceptr *dptr, size_t bytesize, CUmemoryPool pool, CUstream hStream); CUresult CUDAAPI cuStreamUpdateCaptureDependencies(CUstream hStream, CUgraphNode *dependencies, size_t numDependencies, unsigned int flags); + CUresult CUDAAPI cuGetProcAddress(const char *symbol, void **pfn, int cudaVersion, cuuint64_t flags); + #elif defined(__CUDA_API_PER_THREAD_DEFAULT_STREAM) -static inline CUresult cuGetProcAddress_ptsz(const char *symbol, void **funcPtr, int driverVersion, cuuint64_t flags) { +static inline CUresult cuGetProcAddress_v2_ptsz(const char *symbol, void **funcPtr, int driverVersion, cuuint64_t flags, CUdriverProcAddressQueryResult *symbolStatus) { const int procAddressMask = (CU_GET_PROC_ADDRESS_LEGACY_STREAM| CU_GET_PROC_ADDRESS_PER_THREAD_DEFAULT_STREAM); if ((flags & procAddressMask) == 0) { flags |= CU_GET_PROC_ADDRESS_PER_THREAD_DEFAULT_STREAM; } - return cuGetProcAddress(symbol, funcPtr, driverVersion, flags); + return cuGetProcAddress_v2(symbol, funcPtr, driverVersion, flags, symbolStatus); } -#define cuGetProcAddress cuGetProcAddress_ptsz +#define cuGetProcAddress_v2 cuGetProcAddress_v2_ptsz #endif #ifdef __cplusplus diff --git a/python/triton/tools/compile.c b/python/triton/tools/compile.c index fecb87f24777..971bf61912a7 100644 --- a/python/triton/tools/compile.c +++ b/python/triton/tools/compile.c @@ -54,9 +54,12 @@ void load_{kernel_name}() {{ /* {kernel_docstring} */ -CUresult {kernel_name}(CUstream stream, unsigned int gX, unsigned int gY, unsigned int gZ, {signature}) {{ +CUresult {kernel_name}(CUstream stream, {signature}) {{ if ({kernel_name}_func == NULL) load_{kernel_name}(); + unsigned int gX = {gridX}; + unsigned int gY = {gridY}; + unsigned int gZ = {gridZ}; void *args[{num_args}] = {{ {arg_pointers} }}; // TODO: shared memory if(gX * gY * gZ > 0) diff --git a/python/triton/tools/compile.h b/python/triton/tools/compile.h index 6e60df249e9f..d98b7063b6ae 100644 --- a/python/triton/tools/compile.h +++ b/python/triton/tools/compile.h @@ -10,7 +10,5 @@ void unload_{kernel_name}(void); void load_{kernel_name}(void); -// tt-linker: {kernel_name}:{full_signature} -CUresult{_placeholder} {kernel_name}(CUstream stream, unsigned int gX, - unsigned int gY, unsigned int gZ, - {signature}); +// tt-linker: {kernel_name}:{full_signature}:{algo_info} +CUresult{_placeholder} {kernel_name}(CUstream stream, {signature}); diff --git a/python/triton/tools/compile.py b/python/triton/tools/compile.py index 46f98db432f0..32138e8740a0 100644 --- a/python/triton/tools/compile.py +++ b/python/triton/tools/compile.py @@ -43,9 +43,11 @@ parser.add_argument("path", help="Path to Python source containing desired kernel in its scope. File will be executed.") parser.add_argument("--kernel-name", "-n", type=str, default="", help="Name of the kernel to compile", required=True) parser.add_argument("--num-warps", "-w", type=int, default=1, help="Number of warps to launch the kernel") + parser.add_argument("--num-stages", "-ns", type=int, default=3, help="Number of stages (meta-parameter of the kernel)") parser.add_argument("--out-name", "-on", type=str, default=None, help="Out name for the compiled kernel") parser.add_argument("--out-path", "-o", type=Path, default=None, help="Out filename") parser.add_argument("--signature", "-s", type=str, help="Signature of the kernel", required=True) + parser.add_argument("--grid", "-g", type=str, help="Launch grid of the kernel", required=True) args = parser.parse_args() out_name = args.out_name if args.out_name else args.kernel_name @@ -58,6 +60,8 @@ mod = importlib.util.module_from_spec(spec) spec.loader.exec_module(mod) kernel = getattr(mod, args.kernel_name) + grid = args.grid.split(",") + assert len(grid) == 3 # validate and parse signature signature = list(map(lambda s: s.strip(" "), args.signature.split(","))) @@ -67,7 +71,8 @@ def hash_signature(signature: List[str]): m.update(" ".join(signature).encode()) return m.hexdigest()[:8] - sig_hash = hash_signature(signature) + meta_sig = f"warps{args.num_warps}xstages{args.num_stages}" + sig_hash = hash_signature(signature + [meta_sig]) def constexpr(s): try: @@ -87,6 +92,9 @@ def constexpr(s): constexprs = {i: constexpr(s) for i, s in enumerate(signature)} constexprs = {k: v for k, v in constexprs.items() if v is not None} signature = {i: s.split(":")[0] for i, s in enumerate(signature) if i not in constexprs} + const_sig = 'x'.join([str(v) for v in constexprs.values()]) + doc_string = [f"{kernel.arg_names[i]}={constexprs[i]}" for i in constexprs.keys()] + doc_string += [f"num_warps={args.num_warps}", f"num_stages={args.num_stages}"] # compile ast into cubin for h in hints.values(): @@ -96,7 +104,7 @@ def constexpr(s): config = triton.compiler.instance_descriptor(divisible_by_16=divisible_by_16, equal_to_1=equal_to_1) for i in equal_to_1: constexprs.update({i: 1}) - ccinfo = triton.compile(kernel, signature=signature, constants=constexprs, configs=[config], num_warps=args.num_warps) + ccinfo = triton.compile(kernel, signature=signature, constants=constexprs, configs=[config], num_warps=args.num_warps, num_stages=args.num_stages) arg_names = [] arg_types = [] for i in signature.keys(): @@ -118,9 +126,13 @@ def constexpr(s): "full_signature": ", ".join([f"{ty_to_cpp(signature[i])} {kernel.arg_names[i]}" for i in signature.keys()]), "arg_pointers": ", ".join([f"&{arg}" for arg in arg_names]), "num_args": len(arg_names), - "kernel_docstring": "", + "kernel_docstring": doc_string, "shared": ccinfo.shared, "num_warps": args.num_warps, + "algo_info": '_'.join([const_sig, meta_sig]), + "gridX": grid[0], + "gridY": grid[1], + "gridZ": grid[2], "_placeholder": "", } for ext in ['h', 'c']: diff --git a/python/triton/tools/link.py b/python/triton/tools/link.py index 3376a664806c..836c89c5fb19 100644 --- a/python/triton/tools/link.py +++ b/python/triton/tools/link.py @@ -15,10 +15,12 @@ class LinkerError(Exception): @dataclass class KernelLinkerMeta: + orig_kernel_name: str arg_names: Sequence[str] arg_ctypes: Sequence[str] sizes: Sequence[Union[int, None]] sig_hash: str + triton_suffix: str suffix: str num_specs: int """ number of specialized arguments """ @@ -29,8 +31,8 @@ def __init__(self) -> None: import re # [kernel_name, c signature] - self.linker_directives = re.compile("//[\\s]*tt-linker:[\\s]*([\\w]+):(.+)") - # [name, suffix] + self.linker_directives = re.compile("//[\\s]*tt-linker:[\\s]*([\\w]+):(.+):(.+)") + # [name, hash, suffix] self.kernel_name = re.compile("^([\\w]+)_([\\w]+)_([\\w]+)$") # [(type, name)] self.c_sig = re.compile("[\\s]*(\\w+)\\s(\\w+)[,]?") @@ -45,17 +47,19 @@ def extract_linker_meta(self, header: str): if ln.startswith("//"): m = self.linker_directives.match(ln) if _exists(m): - ker_name, c_sig = m.group(1), m.group(2) + ker_name, c_sig, algo_info = m.group(1), m.group(2), m.group(3) name, sig_hash, suffix = self._match_name(ker_name) c_types, arg_names = self._match_c_sig(c_sig) num_specs, sizes = self._match_suffix(suffix, c_sig) self._add_kernel( - name, + "_".join([name, algo_info]), KernelLinkerMeta( + orig_kernel_name=name, arg_names=arg_names, arg_ctypes=c_types, sizes=sizes, sig_hash=sig_hash, + triton_suffix=suffix, suffix=suffix, num_specs=num_specs, ), @@ -126,28 +130,48 @@ def gen_signature(m): return sig -def make_decls(name: str, metas: Sequence[KernelLinkerMeta]) -> str: +# generate declarations of kernels with meta-parameter and constant values +def make_algo_decls(name: str, metas: Sequence[KernelLinkerMeta]) -> str: return f""" -CUresult {name}(CUstream stream, unsigned int gX, unsigned int gY, unsigned int gZ, {gen_signature_with_full_args(metas[-1])}); +CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}); void load_{name}(); void unload_{name}(); """ -def make_kernel_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -> str: +# generate declarations of kernels with meta-parameter and constant values +def make_global_decl(meta: KernelLinkerMeta) -> str: + return f""" +CUresult {meta.orig_kernel_name}_default(CUstream stream, {gen_signature_with_full_args(meta)}); +CUresult {meta.orig_kernel_name}(CUstream stream, {gen_signature_with_full_args(meta)}, int algo_id); +void load_{meta.orig_kernel_name}(); +void unload_{meta.orig_kernel_name}(); + """ + + +# generate dispatcher function for kernels with different meta-parameter and constant values +def make_default_algo_kernel(meta: KernelLinkerMeta) -> str: + src = f"CUresult {meta.orig_kernel_name}_default(CUstream stream, {gen_signature_with_full_args(meta)}){{\n" + src += f" return {meta.orig_kernel_name}(stream, {', '.join(meta.arg_names)}, 0);\n" + src += "}\n" + return src + + +# generate dispatcher function for kernels with different integer value hints +def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -> str: src = f"// launcher for: {name}\n" for meta in sorted(metas, key=lambda m: -m.num_specs): - src += f"CUresult {name}_{meta.sig_hash}_{meta.suffix}(CUstream stream, unsigned int gX, unsigned int gY, unsigned int gZ, {gen_signature(meta)});\n" + src += f"CUresult {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(CUstream stream, {gen_signature(meta)});\n" src += "\n" - src += f"CUresult {name}(CUstream stream, unsigned int gX, unsigned int gY, unsigned int gZ, {gen_signature_with_full_args(metas[-1])}){{" + src += f"CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}){{" src += "\n" for meta in sorted(metas, key=lambda m: -m.num_specs): cond_fn = lambda val, hint: f"({val} % {hint} == 0)" if hint == 16 else f"({val} == {hint})" if hint == 1 else None conds = " && ".join([cond_fn(val, hint) for val, hint in zip(meta.arg_names, meta.sizes) if hint is not None]) src += f" if ({conds})\n" arg_names = [arg for arg, hint in zip(meta.arg_names, meta.sizes) if hint != 1] - src += f" return {name}_{meta.sig_hash}_{meta.suffix}(stream, gX, gY, gZ, {', '.join(arg_names)});\n" + src += f" return {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(stream, {', '.join(arg_names)});\n" src += "\n" src += " return CUDA_ERROR_INVALID_VALUE;\n" src += "}\n" @@ -155,15 +179,58 @@ def make_kernel_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -> str: for mode in ["load", "unload"]: src += f"\n// {mode} for: {name}\n" for meta in sorted(metas, key=lambda m: -m.num_specs): - src += f"void {mode}_{name}_{meta.sig_hash}_{meta.suffix}();\n" + src += f"void {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n" src += f"void {mode}_{name}() {{" src += "\n" for meta in sorted(metas, key=lambda m: -m.num_specs): - src += f" {mode}_{name}_{meta.sig_hash}_{meta.suffix}();\n" + src += f" {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n" src += "}\n" return src +# generate dispatcher function for kernels with different meta-parameter and constant values +def make_kernel_meta_const_dispatcher(meta: KernelLinkerMeta) -> str: + src = f"CUresult {meta.orig_kernel_name}(CUstream stream, {gen_signature_with_full_args(meta)}, int algo_id){{\n" + src += f" assert (algo_id < (int)sizeof({meta.orig_kernel_name}_kernels));\n" + src += f" return {meta.orig_kernel_name}_kernels[algo_id](stream, {', '.join(meta.arg_names)});\n" + src += "}\n" + return src + + +# generate definition of function pointers of kernel dispatchers based on meta-parameter and constant values +def make_func_pointers(names: str, meta: KernelLinkerMeta) -> str: + # the table of hint dispatchers + src = f"typedef CUresult (*kernel_func_t)(CUstream stream, {gen_signature_with_full_args(meta)});\n" + src += f"kernel_func_t {meta.orig_kernel_name}_kernels[] = {{\n" + for name in names: + src += f" {name},\n" + src += "};\n" + return src + + +# generate definition for load/unload functions for kernels with different meta-parameter and constant values +def make_kernel_load_def(names: str, meta: KernelLinkerMeta) -> str: + src = "" + for mode in ["load", "unload"]: + src += f"void {mode}_{meta.orig_kernel_name}(void){{\n" + for name in names: + src += f" {mode}_{name}();\n" + src += "}\n\n" + return src + + +def make_get_num_algos_decl(meta: KernelLinkerMeta) -> str: + src = f"int {meta.orig_kernel_name}_get_num_algos(void);" + return src + + +def make_get_num_algos_def(meta: KernelLinkerMeta) -> str: + src = f"int {meta.orig_kernel_name}_get_num_algos(void){{\n" + src += f" return (int)sizeof({meta.orig_kernel_name}_kernels);\n" + src += "}\n" + return src + + desc = """ Triton ahead-of-time linker: @@ -198,16 +265,43 @@ def make_kernel_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -> str: parser.extract_linker_meta(h_str) # generate headers - decls = [make_decls(name, meta) for name, meta in parser.kernels.items()] + algo_decls = [make_algo_decls(name, meta) for name, meta in parser.kernels.items()] + meta_lists = [meta for name, meta in parser.kernels.items()] + meta = meta_lists[0][0] + get_num_algos_decl = make_get_num_algos_decl(meta) + global_decl = make_global_decl(meta) with args.out.with_suffix(".h").open("w") as fp: - fp.write("#include \n" + "\n".join(decls)) + out = "#include \n" + out += "\n".join(algo_decls) + out += "\n" + out += get_num_algos_decl + out += "\n" + out += global_decl + fp.write(out) # generate source - defs = [make_kernel_dispatcher(name, meta) for name, meta in parser.kernels.items()] + defs = [make_kernel_hints_dispatcher(name, meta) for name, meta in parser.kernels.items()] + names = [name for name in parser.kernels.keys()] + func_pointers_def = make_func_pointers(names, meta) + meta_const_def = make_kernel_meta_const_dispatcher(meta) + load_unload_def = make_kernel_load_def(names, meta) + get_num_algos_def = make_get_num_algos_def(meta) + default_algo_kernel = make_default_algo_kernel(meta) with args.out.with_suffix(".c").open("w") as fp: out = "" out += "#include \n" out += "#include \n" + out += "#include \n" out += "\n" out += "\n".join(defs) + out += "\n" + out += func_pointers_def + out += "\n" + out += get_num_algos_def + out += "\n" + out += meta_const_def + out += "\n" + out += load_unload_def + out += "\n" + out += default_algo_kernel fp.write(out) diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index 4ba9c2af9ad1..ea72b426eb6a 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -174,14 +174,21 @@ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), ] if torch.version.hip is None else [ - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=1, num_warps=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=1, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 16}, num_warps=4, num_stages=0), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4}, num_warps=4, num_stages=0), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=0), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 32}, num_warps=4, num_stages=0), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 32}, num_warps=4, num_stages=0), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 16}, num_warps=4, num_stages=0), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 32}, num_warps=4, num_stages=0), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_warps=4, num_stages=0), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=0), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4}, num_warps=4, num_stages=0), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=0), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 16}, num_warps=4, num_stages=0), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 32}, num_warps=4, num_stages=0), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_warps=4, num_stages=0), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=0), ], key=['M', 'N', 'K'], ) diff --git a/python/tutorials/05-layer-norm.py b/python/tutorials/05-layer-norm.py index 9ffa13e8f4b0..6f131c963834 100644 --- a/python/tutorials/05-layer-norm.py +++ b/python/tutorials/05-layer-norm.py @@ -260,7 +260,7 @@ def forward(ctx, x, normalized_shape, weight, bias, eps): # enqueue kernel _layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd, x_arg.stride(0), N, eps, - BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps) + BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1) ctx.save_for_backward(x, weight, bias, mean, rstd) ctx.BLOCK_SIZE = BLOCK_SIZE ctx.num_warps = num_warps @@ -296,7 +296,7 @@ def backward(ctx, dy): # accumulate partial sums in separate kernel _layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N, BLOCK_SIZE_M=32, - BLOCK_SIZE_N=128) + BLOCK_SIZE_N=128, num_ctas=1) return dx, None, dw, db, None @@ -356,19 +356,21 @@ def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='c quantiles = [0.5, 0.2, 0.8] # utility functions if provider == 'triton': - y_fwd = lambda: layer_norm(x, w_shape, weight, bias, eps) + def y_fwd(): return layer_norm(x, w_shape, weight, bias, eps) # noqa: F811, E704 if provider == 'torch': - y_fwd = lambda: torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps) + def y_fwd(): return torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps) # noqa: F811, E704 if provider == 'apex': - apex_layer_norm = apex.normalization.FusedLayerNorm(w_shape).to(x.device).to(x.dtype) - y_fwd = lambda: apex_layer_norm(x) + apex_layer_norm = apex.normalization.FusedLayerNorm( + w_shape).to(x.device).to(x.dtype) + + def y_fwd(): return apex_layer_norm(x) # noqa: F811, E704 # forward pass if mode == 'forward': gbps = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6 ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=500) # backward pass if mode == 'backward': - gbps = lambda ms: 3 * x.numel() * x.element_size() / ms * 1e-6 + def gbps(ms): return 3 * x.numel() * x.element_size() / ms * 1e-6 # noqa: F811, E704 y = y_fwd() ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True), quantiles=quantiles, grad_to_none=[x], rep=500) diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index 74e6d39aed07..789a68a981f7 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -23,27 +23,93 @@ def max_fn(x, y): return tl.math.max(x, y) +@triton.jit +def _attn_fwd_inner( + acc, l_i, m_i, q, + K_block_ptr, V_block_ptr, + start_m, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + STAGE: tl.constexpr, + offs_m: tl.constexpr, + offs_n: tl.constexpr, + N_CTX, + pre_load_v: tl.constexpr, +): + # range of values handled by this stage + if STAGE == 1: + lo, hi = 0, start_m * BLOCK_M + elif STAGE == 2: + lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M + lo = tl.multiple_of(lo, BLOCK_M) + K_block_ptr = tl.advance(K_block_ptr, (0, lo)) + V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) + # causal = False + else: + lo, hi = 0, N_CTX + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(K_block_ptr) + if pre_load_v: + v = tl.load(V_block_ptr) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + if STAGE == 2: + mask = offs_m[:, None] >= (start_n + offs_n[None, :]) + qk = tl.where(mask, qk, float("-inf")) + qk += tl.dot(q, k) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + p = tl.math.exp2(qk) + # -- update output accumulator -- + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + if not pre_load_v: + v = tl.load(V_block_ptr) + acc += tl.dot(p.to(tl.float16), v) + # -- update m_i and l_i + l_ij = tl.sum(p, 1) + l_i = l_i * alpha + l_ij + # update m_i and l_i + m_i = m_ij + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + return acc, l_i, m_i + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': True}, num_stages=1, num_warps=4), # d64-False + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': False}, num_stages=1, num_warps=4), # d64-True + ], + key=['N_CTX', 'STAGE', 'BLOCK_DMODEL'], +) + @triton.jit -def _fwd_kernel( - Q, K, V, sm_scale, - L, - Out, +def _attn_fwd( + Q, K, V, sm_scale, M, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om, stride_on, - Z, H, N_CTX, P_SEQ, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + Z, H, + N_CTX, + BLOCK_DMODEL: tl.constexpr, + STAGE: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, - IS_CAUSAL: tl.constexpr, + pre_load_v: tl.constexpr, ): start_m = tl.program_id(0) off_hz = tl.program_id(1) - q_offset = off_hz * stride_qh - kv_offset = off_hz * stride_kh + qkv_offset = off_hz * stride_qh Q_block_ptr = tl.make_block_ptr( - base=Q + q_offset, + base=Q + qkv_offset, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_qm, stride_qk), offsets=(start_m * BLOCK_M, 0), @@ -51,16 +117,16 @@ def _fwd_kernel( order=(1, 0) ) K_block_ptr = tl.make_block_ptr( - base=K + kv_offset, - shape=(BLOCK_DMODEL, N_CTX + P_SEQ), + base=K + qkv_offset, + shape=(BLOCK_DMODEL, N_CTX), strides=(stride_kk, stride_kn), offsets=(0, 0), block_shape=(BLOCK_DMODEL, BLOCK_N), order=(0, 1) ) V_block_ptr = tl.make_block_ptr( - base=V + kv_offset, - shape=(N_CTX + P_SEQ, BLOCK_DMODEL), + base=V + qkv_offset, + shape=(N_CTX, BLOCK_DMODEL), strides=(stride_vk, stride_vn), offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_DMODEL), @@ -71,55 +137,53 @@ def _fwd_kernel( offs_n = tl.arange(0, BLOCK_N) # initialize pointer to m and l m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) # scale sm_scale by log_2(e) and use # 2^x instead of exp in the loop because CSE and LICM # don't work as expected with `exp` in the loop qk_scale = sm_scale * 1.44269504 - # load q: it will stay in SRAM throughout + # load q: it will stay in SRAM throughout on NV GPUs but in VGPRs on AMD GPUs q = tl.load(Q_block_ptr) q = (q * qk_scale).to(tl.float16) - # loop over k, v and update accumulator - lo = 0 - hi = P_SEQ + (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX + P_SEQ - for start_n in range(lo, hi, BLOCK_N): - # -- load k, v -- - k = tl.load(K_block_ptr) - v = tl.load(V_block_ptr) - # -- compute qk --- - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - if IS_CAUSAL: - qk = tl.where(P_SEQ + offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) - qk += tl.dot(q, k) - # -- compute scaling constant --- - m_i_new = tl.maximum(m_i, tl.max(qk, 1)) - alpha = tl.math.exp2(m_i - m_i_new) - p = tl.math.exp2(qk - m_i_new[:, None]) - # -- scale and update acc -- - acc_scale = l_i * 0 + alpha # workaround some compiler bug - acc *= acc_scale[:, None] - acc += tl.dot(p.to(tl.float16), v) - # -- update m_i and l_i -- - l_i = l_i * alpha + tl.sum(p, 1) - m_i = m_i_new - # update pointers - K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) - V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) - # write back l and m + # stage 1: off-band + # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE + # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE + if STAGE & 1: + acc, l_i, m_i = _attn_fwd_inner( + acc, l_i, m_i, q, K_block_ptr, V_block_ptr, + start_m, + BLOCK_M, BLOCK_DMODEL, BLOCK_N, + 4 - STAGE, offs_m, offs_n, + N_CTX, pre_load_v, + ) + # stage 2: on-band + if STAGE & 2: + # barrier makes it easier for compielr to schedule the + # two loops independently + tl.debug_barrier() + acc, l_i, m_i = _attn_fwd_inner( + acc, l_i, m_i, q, K_block_ptr, V_block_ptr, + start_m, + BLOCK_M, BLOCK_DMODEL, BLOCK_N, + 2, offs_m, offs_n, + N_CTX, pre_load_v, + ) + # epilogue + # write back m acc = acc / l_i[:, None] - l_ptrs = L + off_hz * N_CTX + offs_m - tl.store(l_ptrs, m_i + tl.math.log2(l_i)) + m_ptrs = M + off_hz * N_CTX + offs_m + tl.store(m_ptrs, m_i + tl.math.log2(l_i)) # write back O O_block_ptr = tl.make_block_ptr( - base=Out + q_offset, + base=Out + qkv_offset, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_om, stride_on), offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0) ) - tl.store(O_block_ptr, acc.to(tl.float16)) + tl.store(O_block_ptr, acc.to(Out.type.element_ty)) @triton.jit @@ -149,8 +213,8 @@ def _bwd_kernel( stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, - Z, H, N_CTX, - num_block, + Z, H, N_CTX, P_SEQ, + num_block_q, num_block_kv, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, CAUSAL: tl.constexpr, @@ -161,17 +225,17 @@ def _bwd_kernel( qk_scale = sm_scale * 1.44269504 # offset pointers for batch/head Q += off_z * stride_qz + off_h * stride_qh - K += off_z * stride_qz + off_h * stride_qh - V += off_z * stride_qz + off_h * stride_qh + K += off_z * stride_kz + off_h * stride_kh + V += off_z * stride_vz + off_h * stride_vh DO += off_z * stride_qz + off_h * stride_qh DQ += off_z * stride_qz + off_h * stride_qh - DK += off_z * stride_qz + off_h * stride_qh - DV += off_z * stride_qz + off_h * stride_qh + DK += off_z * stride_kz + off_h * stride_kh + DV += off_z * stride_vz + off_h * stride_vh # See fwd pass above for explanation. qk_scale = sm_scale * 1.44269504 - for start_n in range(0, num_block): + for start_n in range(0, num_block_kv): if CAUSAL: - lo = start_n * BLOCK_M + lo = tl.math.max(start_n * BLOCK_M - P_SEQ, 0) else: lo = 0 # initialize row/col offsets @@ -188,23 +252,23 @@ def _bwd_kernel( # pointer to row-wise quantities in value-like data D_ptrs = D + off_hz * N_CTX l_ptrs = L + off_hz * N_CTX - # initialize dv amd dk - dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # initialize dk amd dv dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) # k and v stay in SRAM throughout k = tl.load(k_ptrs) v = tl.load(v_ptrs) # loop over rows - for start_m in range(lo, num_block * BLOCK_M, BLOCK_M): + for start_m in range(lo, num_block_q * BLOCK_M, BLOCK_M): offs_m_curr = start_m + offs_m # load q, k, v, do on-chip q = tl.load(q_ptrs) # recompute p = softmax(qk, dim=-1).T if CAUSAL: - qk = tl.dot(q, tl.trans(k), out_dtype=tl.float32) - qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) + qk = tl.where(P_SEQ + offs_m_curr[:, None] >= (offs_n[None, :]), float(0.), float("-inf")) else: - qk = tl.dot(q, tl.trans(k), out_dtype=tl.float32) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, tl.trans(k)) l_i = tl.load(l_ptrs + offs_m_curr) p = tl.math.exp2(qk * qk_scale - l_i[:, None]) # compute dv @@ -227,10 +291,10 @@ def _bwd_kernel( q_ptrs += BLOCK_M * stride_qm do_ptrs += BLOCK_M * stride_qm # write-back - dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) - tl.store(dv_ptrs, dv) + dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) tl.store(dk_ptrs, dk) + tl.store(dv_ptrs, dv) @triton.jit def _bwd_kernel_dk_dv( @@ -456,40 +520,43 @@ def forward(ctx, q, k, v, causal, sm_scale, split_kernel=False): assert Lq == Lk and Lk == Lv assert Lk in {16, 32, 64, 128} o = torch.empty_like(q) - BLOCK_M = 128 if torch.version.hip is None: + BLOCK_M = 128 BLOCK_N = 64 if Lk <= 64 else 32 num_stages = 4 if Lk <= 64 else 3 - else: - BLOCK_N = 64 - num_stages = 1 - num_warps = 4 - grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1) - L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) - P_SEQ = 0 if q.shape[-2] == k.shape[-2] else k.shape[-2] - q.shape[-2] - - num_warps = 4 if Lk <= 64 else 8 - _fwd_kernel[grid]( - q, k, v, sm_scale, - L, - o, + num_warps = 4 if Lk <= 64 else 8 + + stage = 3 if causal else 1 + grid = lambda META: ( + triton.cdiv(q.shape[2], META['BLOCK_M']), + q.shape[0] * q.shape[1], + 1 + ) + M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + + _attn_fwd[grid]( + q, k, v, sm_scale, M, o, q.stride(0), q.stride(1), q.stride(2), q.stride(3), k.stride(0), k.stride(1), k.stride(2), k.stride(3), v.stride(0), v.stride(1), v.stride(2), v.stride(3), o.stride(0), o.stride(1), o.stride(2), o.stride(3), - q.shape[0], q.shape[1], q.shape[2], P_SEQ, - BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, - IS_CAUSAL=causal, - num_warps=num_warps, - num_stages=num_stages) + q.shape[0], q.shape[1], + N_CTX=q.shape[2], + BLOCK_DMODEL=Lk, + STAGE=stage, + ) + + ## restore the grid for bwd kernel + best_config = _attn_fwd.get_best_config(N_CTX = q.shape[2], STAGE = stage, BLOCK_DMODEL=Lk) + block_m = int(best_config.__str__().split(",")[0].split("BLOCK_M:")[1]) + grid = (triton.cdiv(q.shape[2], block_m), q.shape[0] * q.shape[1], 1) - ctx.save_for_backward(q, k, v, o, L) + ctx.save_for_backward(q, k, v, o, M) ctx.grid = grid ctx.sm_scale = sm_scale ctx.BLOCK_DMODEL = Lk ctx.causal = causal ctx.split_kernel = split_kernel - ctx.P_SEQ = P_SEQ return o @staticmethod @@ -560,7 +627,7 @@ def backward(ctx, do): v.stride(0), v.stride(1), v.stride(2), v.stride(3), q.shape[0], q.shape[1], q.shape[2], BLOCK_M=2*BLOCK, BLOCK_N=BLOCK, - BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=4, + BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=4, waves_per_eu=1, num_stages=1, ) # print(h.asm["ttgir"]) @@ -569,23 +636,38 @@ def backward(ctx, do): attention = _attention.apply -@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD, P_SEQ', - [(4, 48, 1024, 64, 128), - (4, 48, 2048, 64, 128), - (4, 48, 4096, 64, 128), - (4, 48, 8192, 64, 128), - (4, 48, 16384, 64, 128) +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', + [(4, 48, 1024, 64), + (4, 48, 2048, 64), + (4, 48, 4096, 64), + (4, 48, 1024, 128), + (4, 48, 2048, 128), + (4, 48, 4096, 128), + #(4, 48, 8192, 64), + #(4, 48, 16384, 64) ]) @pytest.mark.parametrize('causal', [False, True]) -def test_op_fwd(Z, H, N_CTX, D_HEAD, P_SEQ, causal, dtype=torch.float16): +def test_op_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): torch.manual_seed(20) - q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() - k = torch.empty((Z, H, N_CTX + P_SEQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() - v = torch.empty((Z, H, N_CTX + P_SEQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() - sm_scale = q.shape[-1] ** (-0.5) + q = ( + torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") + .normal_(mean=0., std=0.5) + .requires_grad_() + ) + k = ( + torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") + .normal_(mean=0., std=0.5) + .requires_grad_() + ) + v = ( + torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") + .normal_(mean=0., std=0.5) + .requires_grad_() + ) + sm_scale = 0.5 dout = torch.randn_like(q) # reference implementation - M = torch.tril(torch.ones((N_CTX, N_CTX + P_SEQ), device="cuda"), diagonal=P_SEQ) + M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) p = torch.matmul(q, k.transpose(2, 3)) * sm_scale if causal: p[:, :, M == 0] = float("-inf") @@ -597,23 +679,23 @@ def test_op_fwd(Z, H, N_CTX, D_HEAD, P_SEQ, causal, dtype=torch.float16): assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0) -@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD, P_SEQ', - [(4, 48, 1024, 64, 0), - (4, 48, 2048, 64, 0), - (4, 48, 4096, 64, 0), - (1, 16, 8192, 64, 0), +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', + [(4, 48, 1024, 64), + (4, 48, 2048, 64), + (4, 48, 4096, 64), + (1, 16, 8192, 64), ]) -def test_op_bwd(Z, H, N_CTX, D_HEAD, P_SEQ, dtype=torch.float16): +def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16): torch.manual_seed(20) causal = True q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() - k = torch.empty((Z, H, N_CTX + P_SEQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() - v = torch.empty((Z, H, N_CTX + P_SEQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() - sm_scale = q.shape[-1] ** (-0.5) + k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + sm_scale = 0,5 split_kernel = True dout = torch.randn_like(q) # reference implementation - M = torch.tril(torch.ones((N_CTX, N_CTX + P_SEQ), device="cuda"), diagonal=P_SEQ) + M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) p = torch.matmul(q, k.transpose(2, 3)) * sm_scale if causal: p[:, :, M == 0] = float("-inf") @@ -653,19 +735,33 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, P_SEQ, dtype=torch.float16): FLASH_VER = None HAS_FLASH = FLASH_VER is not None -BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 +BATCH, N_HEADS, N_CTX= 4, 48, 4096 # vary seq length for fixed head and batch=4 -configs = [triton.testing.Benchmark( - x_names=['N_CTX'], - x_vals=[2**i for i in range(10, 15)], - line_arg='provider', - line_vals=['triton'] + (['flash'] if HAS_FLASH else []), - line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []), - styles=[('red', '-'), ('blue', '-')], - ylabel='ms', - plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}', - args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode, 'causal': causal} -) for mode in ['fwd', 'bwd'] for causal in [False, True]] +configs = [] +for mode in ['fwd', 'bwd']: + for causal in [False, True]: + if mode == 'bwd' and causal == False: + continue + for D_HEAD in [64, 128]: + if mode == 'bwd' and D_HEAD == 128: + continue + configs.append(triton.testing.Benchmark( + x_names=['N_CTX'], + x_vals=[2**i for i in range(10, 15)], + line_arg='provider', + line_vals=['triton'] + (['flash'] if HAS_FLASH else []), + line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []), + styles=[('red', '-'), ('blue', '-')], + ylabel='ms', + plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-causal={causal}', + args={ + 'H': N_HEADS, + 'BATCH': BATCH, + 'D_HEAD': D_HEAD, + 'dtype': torch.float16, + 'mode': mode, + 'causal': causal}) + ) @triton.testing.perf_report(configs) diff --git a/python/tutorials/09-experimental-tma-matrix-multiplication.py b/python/tutorials/09-experimental-tma-matrix-multiplication.py new file mode 100644 index 000000000000..8a79720c79a0 --- /dev/null +++ b/python/tutorials/09-experimental-tma-matrix-multiplication.py @@ -0,0 +1,200 @@ +""" +Matrix Multiplication with TMA (Experimental) +================================================ +In this tutorial, you will write a very short high-performance multiplication kernel that achieves +performance on parallel with cuBLAS. +""" + +# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files +# (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, +# publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import torch +from torch.testing import assert_close + +import triton +import triton.language as tl + +if torch.cuda.get_device_capability()[0] < 9: + import sys + print("Skipping TMA benchmark for GPU with compute capability < 9") + sys.exit(0) + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, num_warps=4, num_ctas=2), + # triton.Config({'BLOCK_SIZE_M': 512, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, num_warps=4, num_ctas=4), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, z_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_zm, stride_zn, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, + A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr, + B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr +): + pid = tl.program_id(axis=0) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + block_offset_m = pid_m * BLOCK_SIZE_M + block_offset_n = pid_n * BLOCK_SIZE_N + + a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), + offsets=(block_offset_m, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K), order=(A_ORDER_0, A_ORDER_1)) + b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), + offsets=(0, block_offset_n), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), order=(B_ORDER_0, B_ORDER_1)) + z = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + offs_m = block_offset_m + tl.arange(0, BLOCK_SIZE_M) + offs_n = block_offset_n + tl.arange(0, BLOCK_SIZE_N) + z_ptrs = z_ptr + offs_m[:, None] * stride_zm + offs_n[None, :] * stride_zn + mask = (offs_m < M)[:, None] & (offs_n < N)[None, :] + + for k in range(0, K, BLOCK_SIZE_K): + a = tl.load(a_tile_ptr) + b = tl.load(b_tile_ptr) + z += tl.dot(a, b) + a_tile_ptr = tl.advance(a_tile_ptr, [0, BLOCK_SIZE_K]) + b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_SIZE_K, 0]) + + z = z.to(tl.float16) + + tl.store(z_ptrs, z, mask=mask) + + +def matmul(a, b, a_order, b_order): + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + K, N = b.shape + + z = torch.empty((M, N), device=a.device, dtype=torch.float16) + + def grid(META): + return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),) + matmul_kernel[grid](a_ptr=a, b_ptr=b, z_ptr=z, + M=M, N=N, K=K, + stride_am=a.stride(0), stride_ak=a.stride(1), + stride_bk=b.stride(0), stride_bn=b.stride(1), + stride_zm=z.stride(0), stride_zn=z.stride(1), + A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], + B_ORDER_0=b_order[0], B_ORDER_1=b_order[1] + ) + return z + + +problem_list = [ + [2048, 512, 512, False, True], + [2048, 1024, 1024, False, False], + [2048, 2048, 2048, True, False], + [2048, 4096, 4096, True, True], +] + + +def test_matmul(): + for case in problem_list: + M, N, K, TRANS_A, TRANS_B = case + print(M, N, K, TRANS_A, TRANS_B) + if (TRANS_A): + a = torch.randn((K, M), device='cuda', dtype=torch.float16).T + a_order = [0, 1] + else: + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + a_order = [1, 0] + + if (TRANS_B): + b = torch.randn((N, K), device='cuda', dtype=torch.float16).T + b_order = [0, 1] + else: + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + b_order = [1, 0] + + golden = torch.matmul(a, b) + z = matmul(a, b, a_order, b_order) + + golden = torch.nn.functional.normalize(golden) + z = torch.nn.functional.normalize(z) + torch.set_printoptions(profile="full") + assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + # argument names to use as an x-axis for the plot + x_names=['M', 'N', 'K', 'TRANS_A', 'TRANS_B'], + x_vals=problem_list, # different possible values for `x_name` + line_arg='provider', + # argument name whose value corresponds to a different line in the plot + # possible values for `line_arg`` + line_vals=['cublas', 'triton'], + # label name for the lines + line_names=["cuBLAS", "Triton"], + # line styles + styles=[('green', '-'), ('green', '--'), + ('blue', '-'), ('blue', '--')], + ylabel="TFLOPS", # label name for the y-axis + plot_name="matmul-performance", + # name for the plot. Used also as a file name for saving the plot. + args={}, + ) +) +def benchmark(M, N, K, TRANS_A, TRANS_B, provider): + if (TRANS_A): + a = torch.randn((K, M), device='cuda', dtype=torch.float16).T + a_order = [0, 1] + else: + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + a_order = [1, 0] + + if (TRANS_B): + b = torch.randn((N, K), device='cuda', dtype=torch.float16).T + b_order = [0, 1] + else: + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + b_order = [1, 0] + + quantiles = [0.5, 0.2, 0.8] + if provider == 'cublas': + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: torch.matmul(a, b), rep=100, quantiles=quantiles, fast_flush=False) + if provider == 'triton': + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: matmul(a, b, a_order, b_order), rep=100, quantiles=quantiles, fast_flush=False) + + def perf(ms): + return 2 * M * N * K * 1e-12 / (ms * 1e-3) + return perf(ms), perf(max_ms), perf(min_ms) + + +test_matmul() +benchmark.run(show_plots=False, print_data=True) diff --git a/python/tutorials/10-experimental-tma-store-matrix-multiplication.py b/python/tutorials/10-experimental-tma-store-matrix-multiplication.py new file mode 100644 index 000000000000..37d58863d083 --- /dev/null +++ b/python/tutorials/10-experimental-tma-store-matrix-multiplication.py @@ -0,0 +1,179 @@ +""" +Matrix Multiplication with TMA Store (Experimental) +================================================ +In this tutorial, you will write a very short high-performance multiplication kernel that achieves +performance on parallel with cuBLAS. +""" + +# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files +# (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, +# publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import torch +from torch.testing import assert_close + +import triton +import triton.language as tl + +if torch.cuda.get_device_capability()[0] < 9: + import sys + print("Skipping TMA benchmark for GPU with compute capability < 9") + sys.exit(0) + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 512, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, num_warps=4, num_ctas=4), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + block_offset_m = pid_m * BLOCK_SIZE_M + block_offset_n = pid_n * BLOCK_SIZE_N + + a_tile_ptr = tl.make_block_ptr( + base=a_ptr, shape=( + M, K), strides=( + stride_am, stride_ak), offsets=( + block_offset_m, 0), block_shape=( + BLOCK_SIZE_M, BLOCK_SIZE_K), order=( + 1, 0)) + b_tile_ptr = tl.make_block_ptr( + base=b_ptr, shape=( + K, N), strides=( + stride_bk, stride_bn), offsets=( + 0, block_offset_n), block_shape=( + BLOCK_SIZE_K, BLOCK_SIZE_N), order=( + 0, 1)) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, K, BLOCK_SIZE_K): + a = tl.load(a_tile_ptr) + b = tl.load(b_tile_ptr) + accumulator += tl.dot(a, b) + a_tile_ptr = tl.advance(a_tile_ptr, [0, BLOCK_SIZE_K]) + b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_SIZE_K, 0]) + + c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), + offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0)) + + tl.store(c_block_ptr, accumulator) + + +def matmul(a, b): + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + K, N = b.shape + assert ( + K % 32 == 0 + ), "We don't check memory-out-of-bounds with K so K must be divisible by BLOCK_SIZE_K" + + c = torch.empty((M, N), device=a.device, dtype=torch.float32) + + def grid(META): + return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),) + + matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, + M=M, N=N, K=K, + stride_am=a.stride(0), stride_ak=a.stride(1), + stride_bk=b.stride(0), stride_bn=b.stride(1), + stride_cm=c.stride(0), stride_cn=c.stride(1)) + return c + + +a = torch.randn((512, 512), device='cuda', dtype=torch.float16) +b = torch.randn((512, 512), device='cuda', dtype=torch.float16).T +c = matmul(a, b) +c = torch.nn.functional.normalize(c) + +golden = torch.nn.functional.normalize(torch.matmul(a, b)) + +torch.set_printoptions(profile="full") +assert_close( + c, + golden, + rtol=1e-2, + atol=1e-3, + check_dtype=False) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + # argument names to use as an x-axis for the plot + x_names=['M', 'N', 'K'], + x_vals=[ + [2048, 512, 512], + [2048, 1024, 1024], + [2048, 2048, 2048], + [2048, 4096, 4096], + [2048, 8192, 8192] + ], # different possible values for `x_name` + line_arg='provider', + # argument name whose value corresponds to a different line in the plot + # possible values for `line_arg`` + line_vals=['cublas', 'triton'], + # label name for the lines + line_names=["cuBLAS", "Triton"], + # line styles + styles=[('green', '-'), ('green', '--'), + ('blue', '-'), ('blue', '--')], + ylabel="TFLOPS", # label name for the y-axis + plot_name="matmul-performance", + # name for the plot. Used also as a file name for saving the plot. + args={}, + ) +) +def benchmark(M, N, K, provider): + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + b = torch.randn((N, K), device='cuda', dtype=torch.float16).T + quantiles = [0.5, 0.2, 0.8] + if provider == 'cublas': + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: torch.matmul(a, b), rep=100, quantiles=quantiles, fast_flush=False) + if provider == 'triton': + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: matmul(a, b), rep=100, quantiles=quantiles, fast_flush=False) + + def perf(ms): + return 2 * M * N * K * 1e-12 / (ms * 1e-3) + return perf(ms), perf(max_ms), perf(min_ms) + + +benchmark.run(show_plots=False, print_data=True) diff --git a/scripts/amd/docker_run.sh b/scripts/amd/docker_run.sh index b8cf02a9ee94..202416c8f365 100755 --- a/scripts/amd/docker_run.sh +++ b/scripts/amd/docker_run.sh @@ -1,6 +1,6 @@ set -o xtrace -alias drun='sudo docker run -it --rm --network=host --user root --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined' +DRUN='sudo docker run -it --rm --network=host --user root --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined' # DEVICES="--gpus all" DEVICES="--device=/dev/kfd --device=/dev/dri" @@ -21,7 +21,7 @@ CONTAINER_NAME=triton # start new container docker stop $CONTAINER_NAME docker rm $CONTAINER_NAME -CONTAINER_ID=$(drun -d -w $WORK_DIR --name $CONTAINER_NAME $MEMORY $VOLUMES $DEVICES $IMAGE_NAME) +CONTAINER_ID=$($DRUN -d -w $WORK_DIR --name $CONTAINER_NAME $MEMORY $VOLUMES $DEVICES $IMAGE_NAME) echo "CONTAINER_ID: $CONTAINER_ID" # docker cp . $CONTAINER_ID:$WORK_DIR # docker exec $CONTAINER_ID bash -c "bash scripts/amd/run.sh" diff --git a/scripts/amd/fix_so.sh b/scripts/amd/fix_so.sh new file mode 100755 index 000000000000..3f5ca42c2410 --- /dev/null +++ b/scripts/amd/fix_so.sh @@ -0,0 +1,104 @@ +#!/bin/bash +#From https://github.com/pytorch/builder/blob/main/manywheel/build_common.sh +WHEELHOUSE_DIR=/artifacts +PATCHELF_BIN=patchelf +ROCM_LIB=third_party/rocm/lib +ROCM_LD=third_party/rocm/llvm/bin +PREFIX=triton + +fname_without_so_number() { + LINKNAME=$(echo $1 | sed -e 's/\.so.*/.so/g') + echo "$LINKNAME" +} + +replace_needed_sofiles() { + find $1 -name '*.so*' -o -name 'ld.lld' | while read sofile; do + origname=$2 + patchedname=$3 + if [[ "$origname" != "$patchedname" ]] || [[ "$DESIRED_CUDA" == *"rocm"* ]]; then + set +e + origname=$($PATCHELF_BIN --print-needed $sofile | grep "$origname.*") + ERRCODE=$? + set -e + if [ "$ERRCODE" -eq "0" ]; then + echo "patching $sofile entry $origname to $patchedname" + $PATCHELF_BIN --replace-needed $origname $patchedname $sofile + fi + fi + done +} + +mkdir -p "/tmp_dir" +pushd /tmp_dir +#for pkg in /$WHEELHOUSE_DIR/*triton*linux*.whl; do +for pkg in /$WHEELHOUSE_DIR/*triton*.whl; do + echo "Modifying $pkg" + rm -rf tmp + mkdir -p tmp + cd tmp + cp $pkg . + unzip -q $(basename $pkg) + rm -f $(basename $pkg) + $PATCHELF_BIN --set-rpath ${LD_SO_RPATH:-'$ORIGIN:$ORIGIN/../../lib'} $PREFIX/$ROCM_LD/ld.lld + $PATCHELF_BIN --print-rpath $PREFIX/$ROCM_LD/ld.lld + # Modify libtriton.so as it sits in _C directory apart from it'd dependencies + find $PREFIX/_C -type f -name "*.so*" | while read sofile; do + echo "Setting rpath of $sofile" + $PATCHELF_BIN --set-rpath ${C_SO_RPATH:-'$ORIGIN:$ORIGIN/'../$ROCM_LIB} ${FORCE_RPATH:-} $sofile + $PATCHELF_BIN --print-rpath $sofile + done + + # All included dependencies are included in a single lib directory + deps=() + deps_soname=() + while read sofile; do + echo "Setting rpath of $sofile to ${LIB_SO_RPATH:-'$ORIGIN'}" + $PATCHELF_BIN --set-rpath ${LIB_SO_RPATH:-'$ORIGIN'} ${FORCE_RPATH:-} $sofile + $PATCHELF_BIN --print-rpath $sofile + deps+=("$sofile") + deps_soname+=("$(basename $sofile)") + done < <(find $PREFIX/$ROCM_LIB -type f -name "*.so*") + + # Get list of patched names in our third_party/rocm/lib directory + patched=() + for filepath in "${deps[@]}"; do + filename=$(basename $filepath) + destpath=$PREFIX/$ROCM_LIB/$filename + if [[ "$filepath" != "$destpath" ]]; then + cp $filepath $destpath + fi + patchedpath=$(fname_without_so_number $destpath) + patchedname=$(basename $patchedpath) + if [[ "$destpath" != "$patchedpath" ]]; then + mv $destpath $patchedpath + fi + patched+=("$patchedname") + echo "Copied $filepath to $patchedpath" + done + + # Go through all required shared objects and see if any of our other objects are dependants. If so, replace so.ver wth so + for ((i=0;i<${#deps[@]};++i)); do + echo "replacing "${deps_soname[i]} ${patched[i]} + replace_needed_sofiles $PREFIX/$ROCM_LIB ${deps_soname[i]} ${patched[i]} + replace_needed_sofiles $PREFIX/_C ${deps_soname[i]} ${patched[i]} + replace_needed_sofiles $PREFIX/$ROCM_LD ${deps_soname[i]} ${patched[i]} + + done + + # Re-bundle whl with so adjustments + zip -rqy $(basename $pkg) * + + # Add manylinux2014 to whl name for pypi. I believe we have met the criteria for manylinux based on our + # toolchain and rpath changes to make each whl self contained and built with manylinux versions of python + if [[ -z "${MANYLINUX_VERSION}" ]]; then + newpkg=$pkg + else + newpkg=$(echo $pkg | sed -e "s/\linux_x86_64/${MANYLINUX_VERSION}/g") + fi + + # Remove original whl + rm -f $pkg + + # Move rebuilt whl to original location with new name. + mv $(basename $pkg) $newpkg +done diff --git a/scripts/amd/gemm/README.md b/scripts/amd/gemm/README.md new file mode 100644 index 000000000000..8b547af3f608 --- /dev/null +++ b/scripts/amd/gemm/README.md @@ -0,0 +1,71 @@ +# GEMM tuning script v2 + +This is the v2 version of the gemm tuning script, which is based on @scxiao's v1 (https://github.com/ROCmSoftwarePlatform/triton/pull/309) and @alefimov-amd's thread pool https://github.com/ROCmSoftwarePlatform/triton/pull/310 + +### Main features +- `rocprof` is used to measure the time for kernels in the full tuning space +- Each kernel is executed 10 times and the execution time of the last instance is used +- All kernels are compiled in parallel +- Two modes for correctness checking + - During tuning, check correctness with the best perf_config for the current gemm size + - Without tuning, check correctness based on the tuning results, which includes best perf_config for each gemm size +- The process takes about 30 - 40 minutes for the full tuning space with ~15000 configs +- Limitations + - For now, only support fp16 as inputs. It should be trivial to extend to other types, but may require some work for mixed inputs + +### Usage +Go to the script dir +```bash +cd triton/scripts/amd/gemm/ +``` + +1. Tune gemm sizes given in a yaml file and check correctness on the way +```bash +python tune_gemm.py --gemm_size_file input_gemm_sizes.yaml --compare +``` + +2. Tune a single gemm size +```bash +python tune_gemm.py -m 16 -n 16 -k 16 +``` + +3. Choose the file to store tuning results +```bash +python tune_gemm.py --gemm_size_file input_gemm_sizes.yaml --tuning_results_file output_tuning.yaml +``` + +4. Only check correctness given the tuning results +```bash +python tune_gemm.py --gemm_size_file output_tuning.yaml --compare_wo_tuning +``` +Note that the tuning results file are provided as the `gemm_size_file` in this scenario. + +### Overview of implementations + +Workflow of the tuning process +1. Generate the full tuning space. For now the `range`s for each tuning parameter are hard-coded +2. Prune the tuning space according to the current GEMM size and some rules + - BLOCK_SIZE must be equal or larger than the mfma instruction size. + - SPLIT_K * BLOCK_SIZE_K must divide K. Therefore, we do not need EVEN_K in the kernel. + - When split-k is not needed, i.e. both M and N are large, it must be 1 + - GROUP_M * BLOCK_SIZE_M must be smaller than M. Otherwise, GROUP_M must be 1 + - When BLOCK_SIZE_K = 128, neither BLOCK_SIZE_M or BLOCK_SIZE_N can be 128. Otherwise too much LDS will be required. **Needs further investigation** +3. Open a file `generated_kernel{M}{N}{K}.py` and write the following into the file + 1. For each config in the pruned space, generate a kernel with name `matmul_kernel_{configStr}`, where `configStr` contains the gemm size and the tuning parameters. + 2. Generate `matmul` function for each config in a similar way + 3. Generate `try_config` functions for each `matmul` function. + 4. Generate `test_gemm`, which does + 1. Add all `try_config` functions in the thread_pool by `thread_pool.apply_async(try_config)`. This is used to compile all kernels in parallel. + 2. Call each `matmul` function in a for loop of 10 iterations + 5. Generate `main` function +4. Run the generated script with 16 workers. This will compile all kernels in parallel. +5. Invoke `rocprof` on the generated script +6. Post process `results.csv` by extract the execution time of the last instance of each kernel. Pick the best one, write to file, and return. + +### Known issues +On some node, I saw the following runtime error +``` +:0:rocdevice.cpp :2776: 7321835745146 us: 1401 : [tid:0x7fc930830700] Callback: Queue 0x7fc9b7200000 aborting with error : HSA_STATUS_ERROR_INVALID_ISA: The instruction set architecture is invalid. code: 0x100f +``` +It's hard to reproduce the error. **Needs further investigation** +- https://github.com/ROCmSoftwarePlatform/frameworks-internal/issues/6011 diff --git a/scripts/amd/gemm/matmul.py b/scripts/amd/gemm/matmul.py index 11a937b3e3b0..e04073069783 100644 --- a/scripts/amd/gemm/matmul.py +++ b/scripts/amd/gemm/matmul.py @@ -12,8 +12,38 @@ import os import subprocess + + # global flag to indicate whether using the full tuing space -tuning_full_space = False +tuning_full_space = True + +# pruned some unreasonable config +def prune_configs(configs, named_args): + # call only for full tuning space + if not tuning_full_space: + return configs + + SIZE_M = named_args["a_ptr"].shape[0] + SIZE_N = named_args["b_ptr"].shape[1] + SIZE_K = named_args["a_ptr"].shape[1] + + pruned_configs = [] + for config in configs: + kw = config.kwargs + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K =\ + kw["BLOCK_SIZE_M"], kw["BLOCK_SIZE_N"], kw["BLOCK_SIZE_K"] + SPLIT_K = kw["SPLIT_K"] + if SIZE_M <=32 and BLOCK_SIZE_M != 32: + continue + if SIZE_N <=32 and BLOCK_SIZE_N != 32: + continue + # skip large split_k when not necessary + if SPLIT_K != 1 and not need_split_k(SIZE_M, SIZE_N, SIZE_K): + continue + pruned_configs.append(config) + + return pruned_configs + def get_full_tuning_space(use_split_k): configs = [] @@ -22,19 +52,23 @@ def get_full_tuning_space(use_split_k): block_mn_range = [32, 64, 128] block_k_range = [32, 64] - split_k_range = [2, 4, 5, 8, 10] + split_k_range = [1, 2, 4, 5, 8, 10] num_warps_range = [1, 2, 4, 8] group_m_range = [1, 4, 8] + # For now we see better perf with num_stages=0 for all gemm configs we care + # But keep this explicit so that we do not forget we may need to set it to + # other values in the future + num_stage_range = [0] for block_m in block_mn_range: for block_n in block_mn_range: for block_k in block_k_range: for num_warps in num_warps_range: for group_m in group_m_range: - configs.append(triton.Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, 'GROUP_SIZE_M': group_m}, num_stages=1, num_warps=num_warps)) - if use_split_k: - for split_k in split_k_range: - configs.append(triton.Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, 'GROUP_SIZE_M': group_m, 'SPLIT_K': split_k}, num_stages=1, num_warps=num_warps)) + for split_k in split_k_range: + for num_stages in num_stage_range: + configs.append(triton.Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, 'GROUP_SIZE_M': group_m, 'SPLIT_K': split_k}, num_stages=num_stages, num_warps=num_warps)) + return configs @@ -59,6 +93,11 @@ def get_full_tuning_space(use_split_k): triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'SPLIT_K': 10}, num_stages=1, num_warps=1), ], key=['M', 'N', 'K'], + prune_configs_by={ + 'early_config_prune': prune_configs, + 'perf_model': None, + "top_k": None + }, ) @triton.heuristics({ 'EVEN_K': lambda args: args['K'] % (args['BLOCK_SIZE_K'] * args['SPLIT_K']) == 0, @@ -106,7 +145,10 @@ def matmul_kernel_splitK( # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers # See above `Pointer Arithmetics` section for details - offs_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + if SPLIT_K == 1: + offs_k = tl.arange(0, BLOCK_SIZE_K) + else: + offs_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) if torch.version.hip is None: offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N @@ -157,111 +199,6 @@ def matmul_kernel_splitK( tl.atomic_add(c_ptrs, c, mask=c_mask) -# Kernel no split K -@triton.autotune( - configs= get_full_tuning_space(False) if tuning_full_space else [ - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=1, num_warps=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=1, num_warps=2), - ], - key=['M', 'N', 'K'], -) -@triton.heuristics({ - 'EVEN_K': lambda args: args['K'] % args['BLOCK_SIZE_K'] == 0, -}) -@triton.jit -def matmul_kernel( - # Pointers to matrices - a_ptr, b_ptr, c_ptr, - # Matrix dimensions - M, N, K, - # The stride variables represent how much to increase the ptr by when moving by 1 - # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` - # by to get the element one row down (A has M rows). - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, EVEN_K: tl.constexpr, - ACTIVATION: tl.constexpr, -): - """Kernel for computing the matmul C = A x B. - A has shape (M, K), B has shape (K, N) and C has shape (M, N) - """ - # ----------------------------------------------------------- - # Map program ids `pid` to the block of C it should compute. - # This is done in a grouped ordering to promote L2 data reuse. - # See above `L2 Cache Optimizations` section for details. - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - # ---------------------------------------------------------- - # Create pointers for the first blocks of A and B. - # We will advance this pointer as we move in the K direction - # and accumulate - # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers - # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers - # See above `Pointer Arithmetics` section for details - offs_k = tl.arange(0, BLOCK_SIZE_K) - if torch.version.hip is None: - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) - b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - else: - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) - a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak - b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn - - # ----------------------------------------------------------- - # Iterate to compute a block of the C matrix. - # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block - # of fp32 values for higher accuracy. - # `accumulator` will be converted back to fp16 after the loop. - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - # Load the next block of A and B, generate a mask by checking the K dimension. - # If it is out of bounds, set it to 0. - if EVEN_K: - a = tl.load(a_ptrs) - b = tl.load(b_ptrs) - else: - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) - # We accumulate along the K dimension. - accumulator += tl.dot(a, b) - # Advance the ptrs to the next K block. - a_ptrs += BLOCK_SIZE_K * stride_ak - b_ptrs += BLOCK_SIZE_K * stride_bk - # You can fuse arbitrary activation functions here - # while the accumulator is still in FP32! - if ACTIVATION == "leaky_relu": - accumulator = leaky_relu(accumulator) - c = accumulator.to(tl.float16) - - # ----------------------------------------------------------- - # Write back the block of the output matrix C with masks. - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - tl.store(c_ptrs, c, mask=c_mask) - - # We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`. @triton.jit def leaky_relu(x): @@ -284,40 +221,23 @@ def matmul(a, b, activation=""): c = torch.empty((M, N), device=a.device, dtype=a.dtype) # 1D launch kernel where each block gets its own program. - if need_split_k(M, N, K): - grid_splitK = lambda META: ( - triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), - META['SPLIT_K'] - ) - matmul_kernel_splitK[grid_splitK]( - a, b, c, - M, N, K, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0), c.stride(1), - ACTIVATION=activation - ) - - else: - grid = lambda META: ( - triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), - ) - matmul_kernel[grid]( - a, b, c, - M, N, K, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0), c.stride(1), - ACTIVATION=activation - ) + grid_splitK = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + META['SPLIT_K'] + ) + matmul_kernel_splitK[grid_splitK]( + a, b, c, + M, N, K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + ACTIVATION=activation + ) return c def get_best_config(M, N, K): - if need_split_k(M, N, K): - best_config = matmul_kernel_splitK.get_best_config(M = M, N = N, K = K) - else: - best_config = matmul_kernel.get_best_config(M = M, N = N, K = K) + best_config = matmul_kernel_splitK.get_best_config(M = M, N = N, K = K) return best_config @@ -429,7 +349,7 @@ def main(): block_n = best_config.kwargs['BLOCK_SIZE_N'] block_k = best_config.kwargs['BLOCK_SIZE_K'] group_m = best_config.kwargs['GROUP_SIZE_M'] - split_k = best_config.kwargs['SPLIT_K'] if 'SPLIT_K' in best_config.kwargs.keys() else 1 + split_k = best_config.kwargs['SPLIT_K'] # num_warps = best_config['num_warps'] num_warps = best_config.num_warps driver = 'rocprof_gemm.py' diff --git a/scripts/amd/gemm/matmul_kernel.py b/scripts/amd/gemm/matmul_kernel.py new file mode 100644 index 000000000000..c755804e012c --- /dev/null +++ b/scripts/amd/gemm/matmul_kernel.py @@ -0,0 +1,51 @@ +import triton +import triton.language as tl + +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + SPLIT_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(axis=0) + pid_z = tl.program_id(1) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + if GROUP_SIZE_M == 1: + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + else: + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + if SPLIT_K == 1: + offs_k = tl.arange(0, BLOCK_SIZE_K) + else: + offs_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) + a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak + b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak + b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk + c = accumulator.to(tl.float16) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + if SPLIT_K == 1: + tl.store(c_ptrs, c, mask=c_mask) + else: + tl.atomic_add(c_ptrs, c, mask=c_mask) diff --git a/scripts/amd/gemm/tune_gemm.py b/scripts/amd/gemm/tune_gemm.py new file mode 100644 index 000000000000..52fd47c89691 --- /dev/null +++ b/scripts/amd/gemm/tune_gemm.py @@ -0,0 +1,423 @@ +import argparse +import sys +import yaml +import os +import glob +import subprocess + +import torch +import triton +import triton.language as tl + +from matmul_kernel import matmul_kernel + +from datetime import datetime + + +def get_full_tuning_space(): + configs = [] + + block_mn_range = [16, 32, 64, 128, 256] + block_k_range = [16, 32, 64, 128, 256] + split_k_range = [1, 2, 4, 5, 6, 8, 10, 12, 16, 18, 24] + num_warps_range = [1, 2, 4, 8] + group_m_range = [1, 4, 8] + # For now we see better perf with num_stages=0 for all gemm configs we care + # But keep this explicit so that we do not forget we may need to set it to + # other values in the future + num_stage_range = [1, 0] + + for block_m in block_mn_range: + for block_n in block_mn_range: + for block_k in block_k_range: + for num_warps in num_warps_range: + for group_m in group_m_range: + for split_k in split_k_range: + for num_stages in num_stage_range: + configs.append({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, 'GROUP_SIZE_M': group_m, 'SPLIT_K': split_k, 'num_warps': num_warps, 'num_stages': num_stages}) + + return configs + + +def prune_configs(M, N, K, configs): + pruned_configs = [] + + ## TODO: improve how we deal with mfma16 vs mfma32 + ## after it becomes a tuning parameter + mfma_type = os.getenv('MFMA_TYPE') + if mfma_type == '16': + mfma = 16 + else: + mfma = 32 + + for config in configs: + BLOCK_SIZE_M = config.get("BLOCK_SIZE_M") + BLOCK_SIZE_N = config.get("BLOCK_SIZE_N") + BLOCK_SIZE_K = config.get("BLOCK_SIZE_K") + SPLIT_K = config.get("SPLIT_K") + GROUP_M = config.get("GROUP_SIZE_M") + if BLOCK_SIZE_M < mfma or BLOCK_SIZE_N < mfma: + continue + if M <= mfma and BLOCK_SIZE_M != mfma: + continue + if N <= mfma and BLOCK_SIZE_N != mfma: + continue + # skip large split_k when not necessary + if SPLIT_K != 1 and not need_split_k(M, N, K): + continue + # skip split_k that leads to EVEN_K = false + leap = SPLIT_K * BLOCK_SIZE_K + modv = K % leap + if modv != 0: + continue + # skip large GROUP_M + if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1: + continue + ## out of shared memory resource + LDS = BLOCK_SIZE_K * BLOCK_SIZE_M + BLOCK_SIZE_K * BLOCK_SIZE_N + if LDS * 2 > 65536: + continue + + pruned_configs.append(config) + + return pruned_configs + + +def need_split_k(SIZE_M, SIZE_N, SIZE_K): + return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024 + + +def run_bash_command(commandstring): + proc = subprocess.run(commandstring, shell=True, check=True, executable='/bin/bash', stdout = subprocess.PIPE) + return proc.stdout.splitlines() + + +def read_config(config): + block_m = config.get('BLOCK_SIZE_M') + block_n = config.get('BLOCK_SIZE_N') + block_k = config.get('BLOCK_SIZE_K') + group_m = config.get('GROUP_SIZE_M') + split_k = config.get('SPLIT_K') + num_warps = config.get('num_warps') + num_stages = config.get('num_stages') + return block_m, block_n, block_k, group_m, split_k, num_warps, num_stages + + +def gen_kernel_and_configStr_from_config(M, N, K, config): + block_m, block_n, block_k, group_m, split_k, num_warps, num_stages = read_config(config) + configStr = f"M{M}_N{N}_K{K}_BM{block_m}_BN{block_n}_BK{block_k}_GM{group_m}_SK{split_k}_nW{num_warps}_nS{num_stages}" + + matmul_def_str = f""" +def matmul_{configStr}(a, b, c): + M, K = a.shape + K, N = b.shape + grid = triton.cdiv(M, {block_m}) * triton.cdiv(N, {block_n}), {split_k} + print(f'config: matmul_kernel_{configStr}') + matmul_kernel_{configStr}[grid]( + a, b, c, + M, N, K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + BLOCK_SIZE_M = {block_m}, + BLOCK_SIZE_N = {block_n}, + BLOCK_SIZE_K = {block_k}, + GROUP_SIZE_M = {group_m}, + SPLIT_K = {split_k}, + num_warps = {num_warps}, + num_stages = {num_stages} + ) + return c + +def try_config_{configStr}(M, N, K, dtype): + a = torch.randn((M, K), device='cuda', dtype=dtype) + b = torch.randn((K, N), device='cuda', dtype=dtype) + c = torch.zeros((M, N), device=a.device, dtype=a.dtype) + try: + matmul_{configStr}(a, b, c) + except Exception: + print(f'invalid config {configStr}') +""" + return configStr, matmul_def_str + +## Open a file generated_kernelMNK.py and generate +## 1. matmul kernels of all configs +## 2. wrapper function matmul to invoke all the generated kernels +## 3. Another wraper function try_config to invoke matmul function +## 4. test_gemm to invoke +## 4.1 run try_config in parallel +## 4.2 matmul in a loop of 10 iterations +def generate_kernel(M, N, K, configs): + f_kernel = open(f'generated_kernel{M}{N}{K}.py', 'w') + + ### write imports + import_str = """import torch +import triton +import triton.language as tl +import argparse +import sys +import multiprocessing +""" + f_kernel.write(import_str + "\n") + + ### write definitions of matmul_kernel_xxx + ### and matmul_xxx and try_config + with open("matmul_kernel.py") as file: + matmul_kernel_code = file.read(); + for config in configs: + configStr, matmul_def_str = gen_kernel_and_configStr_from_config(M, N, K, config) + ## Copy the matmul_kernel with name replaced + matmul_kernel_config = matmul_kernel_code.replace("matmul_kernel", f"matmul_kernel_{configStr}") + matmul_kernel_config = matmul_kernel_config.replace("import triton.language as tl", "") + matmul_kernel_config = matmul_kernel_config.replace("import triton", "") + f_kernel.write(matmul_kernel_config + "\n\n") + f_kernel.write(matmul_def_str + "\n") + + ### write test_gemm + # pre string + test_gemm_pre_str = """def test_gemm(M, N, K, dtype, num_threads): + thread_pool = multiprocessing.Pool(processes=num_threads) + a = torch.randn((M, K), device='cuda', dtype=dtype) + b = torch.randn((K, N), device='cuda', dtype=dtype) + c = torch.zeros((M, N), device=a.device, dtype=a.dtype) + task_args = (M, N, K, dtype) +""" + f_kernel.write(test_gemm_pre_str + "\n") + + # warm up call of all matmul functions in parallel + for config in configs: + configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, config) + task_str = f" thread_pool.apply_async(try_config_{configStr}, args=task_args)\n" + f_kernel.write(task_str) + + # call all matmul_xxx functions + for config in configs: + configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, config) + matmul_call_str = f""" + for i in range(10): + d = matmul_{configStr}(a, b, c)""" + f_kernel.write(matmul_call_str + "\n") + # post string + f_kernel.write(" return d\n") + + ### def main and call test_gemm + def_main_str = """ +def main(): + parser = argparse.ArgumentParser( + prog="tune a specific gemm size", + allow_abbrev=False,) + parser.add_argument("-n", type=int, default=1, help='number of threads') + args = parser.parse_args() + numThreads = args.n + """ + test_gemm_call_str = f'test_gemm({M}, {N}, {K}, torch.float16, numThreads)' + f_kernel.write(def_main_str) + f_kernel.write(test_gemm_call_str + "\n\n") + f_kernel.write("""if __name__ == '__main__': + sys.exit(main())""") + f_kernel.close() + + +def tune_gemm_config(M, N, K, configs): + ## Generate kernel out of all configs + generate_kernel(M, N, K, configs) + + ## remove any compiled kernel in the cache + run_bash_command("rm -rf ~/.triton/cache") + + ## precompile the kernels in parallel + ## TODO: parameterize numThreads at this level + run_bash_command(f"python generated_kernel{M}{N}{K}.py -n 16") + + ## profile generated kernels + run_bash_command(f"rocprof --stats python generated_kernel{M}{N}{K}.py") + + ## post process results.csv to get the best config and minTime + ## TODO: process the file in parallel + minTime = 1024 * 1024 * 1024 + for config in configs: + configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, config) + parse_result_cmd = f'sed -n \'/matmul_kernel_{configStr}/p\' results.csv | awk -F \',\' \'{{print $NF}}\' | tail -n1' + parsed_outputs = run_bash_command(parse_result_cmd) + if parsed_outputs: + min_us = int(parsed_outputs[0]) / 1000 + if min_us < minTime: + minTime = min_us + bestConfig = config + else: + min_us = -1 + print(f"invalid config: SIZE {M} {N} {K}: {config}") + return minTime, bestConfig + + +def matmul(a, b, c, block_m, block_n, block_k, group_m, split_k, num_warps, num_stages): + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + assert b.is_contiguous(), "Matrix B must be contiguous" + M, K = a.shape + K, N = b.shape + # 1D launch kernel where each block gets its own program. + + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + META['SPLIT_K'] + ) + matmul_kernel[grid]( + a, b, c, + M, N, K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + BLOCK_SIZE_M = block_m, + BLOCK_SIZE_N = block_n, + BLOCK_SIZE_K = block_k, + GROUP_SIZE_M = group_m, + SPLIT_K = split_k, + num_warps = num_warps, + num_stages = num_stages, + ) + return c + + +def test_correctness(M, N, K, config, verbose, datatype = torch.float16): + block_m, block_n, block_k, group_m, split_k, num_warps, num_stages = read_config(config) + + torch.manual_seed(0) + a = torch.randn((M, K), device='cuda', dtype=datatype) + b = torch.randn((K, N), device='cuda', dtype=datatype) + # Allocates output. + c = torch.zeros((M, N), device=a.device, dtype=a.dtype) + triton_output = matmul(a, b, c, block_m, block_n, block_k, group_m, split_k, num_warps, num_stages) + torch_output = torch.matmul(a, b) + #print(f"triton_output={triton_output}") + #print(f"torch_output={torch_output}") + rtol = 0 if torch.version.hip is None else 1e-2 + size_str = '' + if verbose: + size_str = f'SIZE M: {M}, N: {N}, K: {K} ' + if torch.allclose(triton_output, torch_output, atol=1e-1, rtol=rtol): + print(f'{size_str}✅') + else: + print(f'{size_str}❌') + + +def get_default_tuning_result_filename(): + git_branch_name = run_bash_command("git rev-parse --abbrev-ref HEAD") + git_branch_name = git_branch_name[0].decode() + git_commit_hash = run_bash_command("git rev-parse --short HEAD") + git_commit_hash = git_commit_hash[0].decode() + + dt_string = datetime.now().strftime("%m-%d-%Y-%H:%M:%S") + defaultName = f"tuning_results_{git_branch_name}@{git_commit_hash}_{dt_string}.yaml" + return defaultName + + +def parse_args(): + parser = argparse.ArgumentParser( + prog="tune a specific gemm size", + allow_abbrev=False, + ) + + parser.add_argument("-m", type=int, default=0) + parser.add_argument("-n", type=int, default=0) + parser.add_argument("-k", type=int, default=0) + parser.add_argument("--gemm_size_file", type=str, default="", help='yaml file to indicate matrix size') + parser.add_argument("--tuning_results_file", type=str, default=get_default_tuning_result_filename(), help='yaml file to store tuning results') + parser.add_argument("--keep", action='store_true', default=False, help='keep generated files') + parser.add_argument("--compare", action='store_true', default=False, help="Whether check result correctness") + parser.add_argument("--compare_wo_tuning", action='store_true', default=False, help="Whether check result correctness") + args = parser.parse_args() + + return args + + +def main(): + args = parse_args() + matrix_size_file = args.gemm_size_file + tuning_output_file = args.tuning_results_file + keepTmp = args.keep + + mnks = [] + ## TODO: make it more robust to get user input + if matrix_size_file == "" or not os.path.isfile(matrix_size_file): + M = args.m + N = args.n + K = args.k + mnks = [(M, N, K)] + else: + with open(matrix_size_file) as file: + matrix_sizes = yaml.safe_load(file) + for sizes in matrix_sizes: + M = sizes['M'] + N = sizes['N'] + K = sizes['K'] + mnks.append((M, N, K)) + + ## Check correctness from given configs + if args.compare_wo_tuning: + for item in matrix_sizes: + M = item['M'] + N = item['N'] + K = item['K'] + del item['M'] + del item['N'] + del item['K'] + test_correctness(M, N, K, item, True) + return + + configs_full = get_full_tuning_space() + + start_time = datetime.now() + + f_results = open(tuning_output_file, 'w') + for (M, N, K) in mnks: + ## Obtain a pruned tuning space according to gemm size + pruned_configs = prune_configs(M, N, K, configs_full) + + size_str = f'SIZE: {M} {N} {K}' + print(f"{size_str} nConfigs: {len(pruned_configs)}", end=" ", flush=True) + + ## The main tuning funtion for one gemm size + minTime, bestConfig = tune_gemm_config(M, N, K, pruned_configs) + + ## post processing the numbers + perf_tflops = lambda us: 2 * M * N * K * 1e-12 / (us * 1e-6) + tri_tflops = perf_tflops(minTime) + if tri_tflops < 0.0001: + formatted_tflops = "{:.3e}".format(tri_tflops) + else: + formatted_tflops = "{:.2f}".format(tri_tflops) + print(f'TFLOPS: {formatted_tflops} time(us): {minTime}', end=" ") + + bestConfig_compact_str, _ = gen_kernel_and_configStr_from_config(M, N, K, bestConfig) + print(f'best_config: {bestConfig_compact_str}', end=" ") + + ## write best config to tuning_results.yaml + sizeDict = {'M': M, 'N': N, 'K': K} + sizeDict.update(bestConfig) + f_results.write("- " + str(sizeDict) + " ") + f_results.write(f'# TFLOPS: {formatted_tflops} time(us): {minTime:.2f}\n') + + ## remove generated files if asked to + if not keepTmp: + os.remove(f"generated_kernel{M}{N}{K}.py") + for f in glob.glob("results.*"): + os.remove(f) + + ## Check correctness if asked to + if args.compare: + print("correctness: ", end=" ") + test_correctness(M, N, K, bestConfig, False) + else: + print("") + + f_results.close() + + end_time = datetime.now() + tuning_time = end_time - start_time + print(f"Tuning time (h:m:s): {tuning_time}") + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/scripts/amd/setup_rocm_libs.sh b/scripts/amd/setup_rocm_libs.sh new file mode 100755 index 000000000000..687df59e4a63 --- /dev/null +++ b/scripts/amd/setup_rocm_libs.sh @@ -0,0 +1,73 @@ +#!/usr/bin/env bash + +set -ex + +# Set ROCM_HOME if not set +if [[ -z "${ROCM_HOME}" ]]; then + export ROCM_HOME=/opt/rocm +fi + +# Check TRITON_ROCM_DIR is set +if [[ -z "${TRITON_ROCM_DIR}" ]]; then + export TRITON_ROCM_DIR=python/triton/third_party/rocm +fi + +# Create triton lib directory +mkdir -p $TRITON_ROCM_DIR/lib + +LIBTINFO_PATH="/usr/lib64/libtinfo.so.5" +LIBNUMA_PATH="/usr/lib64/libnuma.so.1" +LIBELF_PATH="/usr/lib64/libelf.so.1" + +OS_SO_PATHS=( + $LIBELF_PATH + $LIBNUMA_PATH + $LIBTINFO_PATH +) + +for lib in "${OS_SO_PATHS[@]}" +do + cp $lib $TRITON_ROCM_DIR/lib/ +done + +# Required ROCm libraries - dynamically find so numbers +ROCM_SO=( + "libhsa-runtime64.so.1" + "libamdhip64.so.5" + "libamd_comgr.so.2" + "libdrm.so.2" + "libdrm_amdgpu.so.1" +) + +# Find the SO libs dynamically +for lib in "${ROCM_SO[@]}" +do + file_path=($(find $ROCM_HOME/lib/ -name "$lib")) # First search in lib + if [[ -z $file_path ]]; then + if [ -d "$ROCM_HOME/lib64/" ]; then + file_path=($(find $ROCM_HOME/lib64/ -name "$lib")) # Then search in lib64 + fi + fi + if [[ -z $file_path ]]; then + file_path=($(find $ROCM_HOME/ -name "$lib")) # Then search in ROCM_HOME + fi + if [[ -z $file_path ]]; then + file_path=($(find /opt/ -name "$lib")) # Then search in ROCM_HOME + fi + if [[ -z $file_path ]]; then + echo "Error: Library file $lib is not found." >&2 + exit 1 + fi + + cp $file_path $TRITON_ROCM_DIR/lib + # When running locally, and not building a wheel, we need to satisfy shared objects requests that don't look for versions + LINKNAME=$(echo $lib | sed -e 's/\.so.*/.so/g') + ln -sf $lib $TRITON_ROCM_DIR/lib/$LINKNAME +done + +# Copy Include Files +cp -r $ROCM_HOME/include $TRITON_ROCM_DIR/ + +# Copy linker +mkdir -p $TRITON_ROCM_DIR/llvm/bin +cp $ROCM_HOME/llvm/bin/ld.lld $TRITON_ROCM_DIR/llvm/bin/ diff --git a/test/Analysis/test-alias.mlir b/test/Analysis/test-alias.mlir index 6a5b2e5f5375..75740d929da2 100644 --- a/test/Analysis/test-alias.mlir +++ b/test/Analysis/test-alias.mlir @@ -75,6 +75,20 @@ tt.func @insert_slice_async(%A : !tt.ptr, %i1 : i1) { tt.return } +// CHECK-LABEL: insert_slice_async_v2 +tt.func @insert_slice_async_v2(%A : !tt.ptr, %i1 : i1) { + %mbar = triton_nvidia_gpu.alloc_mbarrier { count = 128 : i32 } : !tt.ptr + %a_ptr = tt.broadcast %A : (!tt.ptr) -> tensor<16x16x!tt.ptr, #AL> + %mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL> + %other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> + // CHECK: %cst_0 -> %cst_0 + %tensor = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED> + %index = arith.constant 0 : i32 + // CHECK: %3 -> %cst_0 + %a = triton_nvidia_gpu.insert_slice_async_v2 %a_ptr, %tensor, %index, %mbar, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : tensor<16x16x!tt.ptr, #AL>, tensor<1x16x16xf16, #A_SHARED>, i32, !tt.ptr, tensor<16x16xi1, #AL>, tensor<16x16xf16, #AL> -> tensor<1x16x16xf16, #A_SHARED> + tt.return +} + // CHECK-LABEL: insert_slice tt.func @insert_slice(%A : !tt.ptr, %i1 : i1) { %a_ptr = tt.broadcast %A : (!tt.ptr) -> tensor<16x16x!tt.ptr, #AL> @@ -99,6 +113,16 @@ tt.func @extract_slice(%A : !tt.ptr) { tt.return } +// CHECK-LABEL: extract_m_barrier +tt.func @extract_m_barrier() { + // CHECK: %0 -> %0 + %mbar = triton_nvidia_gpu.alloc_mbarrier { count = 128 : i32 } : tensor<2xi64, #A_SHARED> + %c0 = arith.constant 0 : i32 + // CHECK: %1 -> %0 + %mbar0 = triton_nvidia_gpu.extract_mbarrier %mbar[%c0] : tensor<2xi64, #A_SHARED>, i32 -> !tt.ptr + tt.return +} + // CHECK-LABEL: if_cat tt.func @if_cat(%i1 : i1) { // CHECK: %cst -> %cst diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index 6db70b2ac591..93d80448c998 100644 --- a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -1,14 +1,14 @@ // RUN: triton-opt %s -split-input-file --mlir-disable-threading -test-print-allocation 2>&1 | FileCheck %s -#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #sliceAd0 = #triton_gpu.slice<{dim = 0, parent = #AL}> -#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#A_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> -#A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}> -#B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> -#C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> -#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> -#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> +#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#A_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C}> +#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C}> module attributes {"triton_gpu.num-warps" = 4 : i32} { @@ -28,10 +28,10 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { %a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> - // CHECK: scratch offset = 0, size = 4608 + // CHECK: offset = 0, size = 4608 %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_DOT> %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> - // CHECK-NEXT: scratch offset = 0, size = 4224 + // CHECK-NEXT: offset = 0, size = 4224 %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B_DOT> %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> @@ -56,17 +56,17 @@ tt.func @reusable(%A : !tt.ptr) { %a_ptr = tt.broadcast %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> %b_ptr = tt.broadcast %A : (!tt.ptr) -> tensor<32x128x!tt.ptr, #AL> %a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> - // CHECK-NEXT: scratch offset = 0, size = 4608 + // CHECK-NEXT: offset = 0, size = 4608 %a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_DOT> %a2_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL> - // CHECK-NEXT: scratch offset = 0, size = 1152 + // CHECK-NEXT: offset = 0, size = 1152 %a2 = triton_gpu.convert_layout %a2_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #B_DOT> %a3_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> - // CHECK-NEXT: scratch offset = 0, size = 4608 + // CHECK-NEXT: offset = 0, size = 4608 %a3 = triton_gpu.convert_layout %a3_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_DOT> %c = tt.dot %a1, %a2, %c_init {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> %a4_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL> - // CHECK-NEXT: scratch offset = 0, size = 1152 + // CHECK-NEXT: offset = 0, size = 1152 %a4 = triton_gpu.convert_layout %a4_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #B_DOT> %c1 = tt.dot %a3, %a4, %c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> tt.return @@ -83,11 +83,11 @@ tt.func @preallocate(%A : !tt.ptr) { %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> // CHECK-NEXT: offset = 1024, size = 512 %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> - // CHECK-NEXT: offset = 1536, size = 512 + // CHECK-NEXT: offset = 2048, size = 512 %cst2 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> - // CHECK-NEXT: offset = 2048, size = 1024 - %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> // CHECK-NEXT: offset = 3072, size = 1024 + %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> + // CHECK-NEXT: offset = 4096, size = 1024 %b = tt.cat %cst0, %cst2 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> // CHECK-NEXT: offset = 0, size = 1024 %c = tt.cat %cst1, %cst2 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> @@ -118,12 +118,12 @@ tt.func @unused(%A : !tt.ptr) { %cst0 = arith.constant dense<0.000000e+00> : tensor<32x16xf16, #A_SHARED> // CHECK-NEXT: offset = 0, size = 512 %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> - // CHECK-NEXT: offset = 512, size = 512 + // CHECK-NEXT: offset = 1024, size = 512 %cst2 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> - // CHECK-NEXT: offset = 1024, size = 1024 + // CHECK-NEXT: offset = 2048, size = 1024 %a = tt.cat %cst1, %cst2 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> tt.return - // CHECK: size = 2048 + // CHECK: size = 3072 } // cst0 is alive through the entire function, it cannot be released before the end of the function @@ -131,28 +131,28 @@ tt.func @unused(%A : !tt.ptr) { tt.func @longlive(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> - // CHECK-NEXT: offset = 512, size = 512 - %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> // CHECK-NEXT: offset = 1024, size = 512 + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + // CHECK-NEXT: offset = 2048, size = 512 %cst2 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> - // CHECK-NEXT: offset = 1536, size = 1024 + // CHECK-NEXT: offset = 3072, size = 1024 %a = tt.cat %cst1, %cst2 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> - // CHECK-NEXT: offset = 512, size = 512 - %cst3 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> // CHECK-NEXT: offset = 1024, size = 512 + %cst3 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + // CHECK-NEXT: offset = 2048, size = 512 %cst4 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> - // CHECK-NEXT: offset = 1536, size = 1024 + // CHECK-NEXT: offset = 3072, size = 1024 %b = tt.cat %cst3, %cst4 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> - // CHECK-NEXT: offset = 1536, size = 512 + // CHECK-NEXT: offset = 3072, size = 512 %cst5 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> - // CHECK-NEXT: offset = 1536, size = 512 + // CHECK-NEXT: offset = 3072, size = 512 %cst6 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> - // CHECK-NEXT: offset = 1536, size = 1024 + // CHECK-NEXT: offset = 3072, size = 1024 %c = tt.cat %cst3, %cst4 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> - // CHECK-NEXT: offset = 512, size = 1024 + // CHECK-NEXT: offset = 1024, size = 1024 %d = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> tt.return - // CHECK-NEXT: size = 2560 + // CHECK-NEXT: size = 4096 } // This example triggers graph coloring with > 1 colors. @@ -160,12 +160,12 @@ tt.func @longlive(%A : !tt.ptr) { tt.func @multi_color(%A : !tt.ptr) { // CHECK: offset = 0, size = 64 %cst = arith.constant dense<0.000000e+00> : tensor<4x8xf16, #A_SHARED> - // CHECK-NEXT: offset = 1216, size = 32 + // CHECK-NEXT: offset = 1536, size = 32 %cst_0 = arith.constant dense<0.000000e+00> : tensor<4x4xf16, #A_SHARED> - // CHECK-NEXT: offset = 1248, size = 128 + // CHECK-NEXT: offset = 1664, size = 128 %cst_1 = arith.constant dense<0.000000e+00> : tensor<16x4xf16, #A_SHARED> %cst_2 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> - // CHECK-NEXT: scratch offset = 64, size = 1152 + // CHECK-NEXT: scratch offset = 128, size = 1152 %0 = triton_gpu.convert_layout %cst_2 : (tensor<16x32xf16, #AL>) -> tensor<16x32xf16, #AL> %1 = triton_gpu.convert_layout %cst : (tensor<4x8xf16, #A_SHARED>) -> tensor<4x8xf16, #AL> // CHECK-NEXT: offset = 0, size = 128 @@ -179,16 +179,16 @@ tt.func @multi_color(%A : !tt.ptr) { %cst_5 = arith.constant dense<0.000000e+00> : tensor<4x8xf16, #A_SHARED> %4 = triton_gpu.convert_layout %cst_5 : (tensor<4x8xf16, #A_SHARED>) -> tensor<4x8xf16, #AL> %5 = triton_gpu.convert_layout %cst_5 : (tensor<4x8xf16, #A_SHARED>) -> tensor<4x8xf16, #AL> - // CHECK-NEXT: offset = 256, size = 512 + // CHECK-NEXT: offset = 1024, size = 512 %cst_6 = arith.constant dense<0.000000e+00> : tensor<8x32xf16, #A_SHARED> - // CHECK-NEXT: offset = 2528, size = 128 + // CHECK-NEXT: offset = 3104, size = 128 %cst_7 = arith.constant dense<0.000000e+00> : tensor<2x32xf16, #A_SHARED> %6 = triton_gpu.convert_layout %cst_0 : (tensor<4x4xf16, #A_SHARED>) -> tensor<4x4xf16, #AL> - // CHECK-NEXT: offset = 256, size = 512 + // CHECK-NEXT: offset = 1024, size = 512 %cst_8 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> // CHECK-NEXT: offset = 256, size = 32 %cst_9 = arith.constant dense<0.000000e+00> : tensor<4x4xf16, #A_SHARED> - // CHECK-NEXT: offset = 256, size = 512 + // CHECK-NEXT: offset = 1024, size = 512 %cst_10 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> %7 = triton_gpu.convert_layout %cst_1 : (tensor<16x4xf16, #A_SHARED>) -> tensor<16x4xf16, #AL> %8 = triton_gpu.convert_layout %cst_4 : (tensor<4x32xf16, #A_SHARED>) -> tensor<4x32xf16, #AL> @@ -198,7 +198,7 @@ tt.func @multi_color(%A : !tt.ptr) { %10 = triton_gpu.convert_layout %cst_7 : (tensor<2x32xf16, #A_SHARED>) -> tensor<2x32xf16, #AL> %cst_12 = arith.constant dense<0.000000e+00> : tensor<4x16xf16, #AL> %cst_13 = arith.constant dense<0.000000e+00> : tensor<8x32xf16, #AL> - // CHECK-NEXT: size = 2656 + // CHECK-NEXT: size = 3232 tt.return } @@ -207,15 +207,15 @@ tt.func @multi_color(%A : !tt.ptr) { tt.func @multi_color_multi_rounds(%arg0: !tt.ptr) { // CHECK: offset = 0, size = 32 %cst = arith.constant dense<0.000000e+00> : tensor<4x4xf16, #A_SHARED> - // CHECK-NEXT: offset = 1184, size = 128 + // CHECK-NEXT: offset = 1280, size = 128 %cst_0 = arith.constant dense<0.000000e+00> : tensor<16x4xf16, #A_SHARED> - // CHECK-NEXT: offset = 1312, size = 8192 + // CHECK-NEXT: offset = 2048, size = 8192 %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x4xf16, #A_SHARED> %cst_2 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> - // CHECK-NEXT: scratch offset = 32, size = 1152 + // CHECK-NEXT: scratch offset = 128, size = 1152 %0 = triton_gpu.convert_layout %cst_2 : (tensor<16x32xf16, #AL>) -> tensor<16x32xf16, #AL> %1 = triton_gpu.convert_layout %cst : (tensor<4x4xf16, #A_SHARED>) -> tensor<4x4xf16, #AL> - // CHECK-NEXT: offset = 11968, size = 128 + // CHECK-NEXT: offset = 1152, size = 128 %cst_3 = arith.constant dense<0.000000e+00> : tensor<2x32xf16, #A_SHARED> %2 = triton_gpu.convert_layout %cst : (tensor<4x4xf16, #A_SHARED>) -> tensor<4x4xf16, #AL> // CHECK-NEXT: offset = 0, size = 512 @@ -225,7 +225,7 @@ tt.func @multi_color_multi_rounds(%arg0: !tt.ptr) { // CHECK-NEXT: scratch offset = 0, size = 1152 %5 = triton_gpu.convert_layout %cst_2 : (tensor<16x32xf16, #AL>) -> tensor<16x32xf16, #AL> %6 = triton_gpu.convert_layout %cst_3 : (tensor<2x32xf16, #A_SHARED>) -> tensor<2x32xf16, #AL> - // CHECK-NEXT: size = 12096 + // CHECK-NEXT: size = 10240 tt.return } @@ -241,6 +241,27 @@ tt.func @alloc(%A : !tt.ptr) { // CHECK-NEXT: size = 512 } +// mbarrier's shared memory cannot be reused +// CHECK-LABEL: alloc_m_barrier +tt.func @alloc_m_barrier() { + // CHECK: offset = 0, size = 16 + %mbar0 = triton_nvidia_gpu.alloc_mbarrier { count = 128 : i32 } : tensor<2xi64, #A_SHARED> + // CHECK-NEXT: offset = 16, size = 16 + %mbar1 = triton_nvidia_gpu.alloc_mbarrier { count = 128 : i32 } : tensor<2xi64, #A_SHARED> + // CHECK-NEXT: size = 32 + tt.return +} + +// CHECK-LABEL: alloc_m_barrier_scalar +tt.func @alloc_m_barrier_scalar() { + // CHECK: offset = 0, size = 8 + %mbar0 = triton_nvidia_gpu.alloc_mbarrier { count = 128 : i32 } : !tt.ptr + // CHECK-NEXT: offset = 8, size = 8 + %mbar1 = triton_nvidia_gpu.alloc_mbarrier { count = 128 : i32 } : !tt.ptr + // CHECK-NEXT: size = 16 + tt.return +} + // CHECK-LABEL: scratch tt.func @scratch() { %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> @@ -291,22 +312,22 @@ tt.func @extract_slice(%A : !tt.ptr) { tt.func @if(%i1 : i1) { // CHECK: offset = 0, size = 512 %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> - // CHECK-NEXT: offset = 512, size = 512 + // CHECK-NEXT: offset = 1024, size = 512 %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> scf.if %i1 { - // CHECK-NEXT: offset = 1024, size = 1024 + // CHECK-NEXT: offset = 2048, size = 1024 %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> - // CHECK-NEXT: offset = 1024, size = 1024 + // CHECK-NEXT: offset = 2048, size = 1024 %b = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> } // CHECK-NEXT: offset = 0, size = 512 %cst2 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> - // CHECK-NEXT: offset = 512, size = 512 + // CHECK-NEXT: offset = 1024, size = 512 %cst3 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> - // CHECK-NEXT: offset = 1024, size = 1024 + // CHECK-NEXT: offset = 2048, size = 1024 %a = tt.cat %cst2, %cst3 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> tt.return - // CHECK-NEXT: size = 2048 + // CHECK-NEXT: size = 3072 } // B0 -> (B1) -> (B2) -> B0 @@ -315,25 +336,25 @@ tt.func @if(%i1 : i1) { tt.func @if_else(%i1 : i1) { // CHECK: offset = 0, size = 512 %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> - // CHECK-NEXT: offset = 512, size = 512 + // CHECK-NEXT: offset = 1024, size = 512 %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> scf.if %i1 { - // CHECK-NEXT: offset = 1024, size = 1024 + // CHECK-NEXT: offset = 2048, size = 1024 %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> - // CHECK-NEXT: offset = 1024, size = 1024 + // CHECK-NEXT: offset = 2048, size = 1024 %b = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> } else { - // CHECK-NEXT: offset = 1024, size = 512 + // CHECK-NEXT: offset = 2048, size = 512 %cst2 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> - // CHECK-NEXT: offset = 1536, size = 512 + // CHECK-NEXT: offset = 3072, size = 512 %cst3 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> - // CHECK-NEXT: offset = 2048, size = 1024 + // CHECK-NEXT: offset = 4096, size = 1024 %a = tt.cat %cst2, %cst3 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> } - // CHECK-NEXT: offset = 1024, size = 1024 + // CHECK-NEXT: offset = 2048, size = 1024 %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> tt.return - // CHECK-NEXT: size = 3072 + // CHECK-NEXT: size = 5120 } // Block arguments and yields are memory aliases that do not trigger a new diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir index 621dd10e2cdd..63b4ef5d2ca1 100644 --- a/test/Analysis/test-membar.mlir +++ b/test/Analysis/test-membar.mlir @@ -1,14 +1,14 @@ // RUN: triton-opt %s -split-input-file --mlir-disable-threading --convert-scf-to-cf -test-print-membar 2>&1 | FileCheck %s -#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #sliceAd0 = #triton_gpu.slice<{dim = 0, parent = #AL}> -#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#A_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> -#A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}> -#B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> -#C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> -#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> -#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> +#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#A_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C}> +#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C}> module attributes {"triton_gpu.num-warps" = 4 : i32} { @@ -447,58 +447,58 @@ tt.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, // CHECK-LABEL: cf_if tt.func @cf_if(%i1 : i1) { - %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>> - %cst_0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>> + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> cf.cond_br %i1, ^bb1, ^bb2 ^bb1: // pred: ^bb0 // CHECK: gpu.barrier // CHECK-NEXT: tt.cat - %0 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>, tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>> + %0 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> cf.br ^bb2 ^bb2: // 2 preds: ^bb0, ^bb1 // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.convert_layout - %1 = triton_gpu.convert_layout %cst : (tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<16x16xf16, #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>> + %1 = triton_gpu.convert_layout %cst : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL> tt.return } tt.func @cf_if_else(%i1 : i1) { - %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>> - %cst_0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>> + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> cf.cond_br %i1, ^bb1, ^bb2 ^bb1: // pred: ^bb0 // CHECK: gpu.barrier // CHECK-NEXT: tt.cat - %0 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>, tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>> - cf.br ^bb3(%0 : tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) + %0 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> + cf.br ^bb3(%0 : tensor<32x16xf16, #A_SHARED>) ^bb2: // pred: ^bb0 // CHECK: gpu.barrier // CHECK-NEXT: tt.cat - %1 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>, tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>> - cf.br ^bb3(%1 : tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -^bb3(%2: tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>): // 2 preds: ^bb1, ^bb2 + %1 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> + cf.br ^bb3(%1 : tensor<32x16xf16, #A_SHARED>) +^bb3(%2: tensor<32x16xf16, #A_SHARED>): // 2 preds: ^bb1, ^bb2 cf.br ^bb4 ^bb4: // pred: ^bb3 - %3 = triton_gpu.convert_layout %cst : (tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<16x16xf16, #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>> + %3 = triton_gpu.convert_layout %cst : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL> // CHECK: gpu.barrier // CHECK-NEXT: tt.cat - %4 = tt.cat %2, %2 {axis = 0 : i64} : (tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>, tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<64x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>> + %4 = tt.cat %2, %2 {axis = 0 : i64} : (tensor<32x16xf16, #A_SHARED>, tensor<32x16xf16, #A_SHARED>) -> tensor<64x16xf16, #A_SHARED> tt.return } tt.func @cf_if_else_return(%i1 : i1) { - %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>> - %cst_0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>> + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> cf.cond_br %i1, ^bb1, ^bb2 ^bb1: // pred: ^bb0 // CHECK: gpu.barrier // CHECK-NEXT: tt.cat - %0 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>, tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>> + %0 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> tt.return ^bb2: // pred: ^bb0 // CHECK: gpu.barrier // CHECK-NEXT: tt.cat - %1 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>, tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>> + %1 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> tt.return } diff --git a/test/Conversion/AMDGPU/load_store.mlir b/test/Conversion/AMDGPU/load_store.mlir index f6c3ccfa0712..cc42cff24d97 100644 --- a/test/Conversion/AMDGPU/load_store.mlir +++ b/test/Conversion/AMDGPU/load_store.mlir @@ -1,7 +1,7 @@ -// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm | FileCheck --check-prefixes=CHECK,GCN %s +// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm=target=rocdl | FileCheck --check-prefixes=CHECK,GCN %s // Check load instruction doesn't generate incorrect bitcast. -module attributes {"triton_gpu.num-warps" = 4 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: @test_float16_bitcast tt.func public @test_float16_bitcast(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { %true = arith.constant true diff --git a/test/Conversion/invalid.mlir b/test/Conversion/invalid.mlir index e7e356c8b45d..81b86650291f 100644 --- a/test/Conversion/invalid.mlir +++ b/test/Conversion/invalid.mlir @@ -5,7 +5,7 @@ #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, kWidth=2}> module attributes {"triton_gpu.num-warps" = 1 : i32} { tt.func @convert_dot(%A: tensor<16x16xf32, #dot_operand_a>, %B: tensor<16x16xf16, #dot_operand_b>, %C: tensor<16x16xf32, #mma0>) { - // expected-error@+1 {{element types of operands A and B must match}} + // expected-error@+1 {{element types of operands A and B must have same bit width}} %D = tt.dot %A, %B, %C {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf32, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0> tt.return diff --git a/test/Conversion/minimize_alloc.mlir b/test/Conversion/minimize_alloc.mlir new file mode 100644 index 000000000000..8a33aa8bfa6c --- /dev/null +++ b/test/Conversion/minimize_alloc.mlir @@ -0,0 +1,117 @@ +// RUN: triton-opt --convert-triton-gpu-to-llvm=target=rocdl %s | FileCheck %s + +// CHECK: module attributes {{.*}}, triton_gpu.shared = 9216 : i32 +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mfma = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @matmul_kernel_0d1d2d3d4d5d6d7c8d9c10d11c(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mfma> + %cst_0 = arith.constant dense<32> : tensor<64x32xi32, #blocked> + %c31_i32 = arith.constant 31 : i32 + %c63_i32 = arith.constant 63 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c32_i32 = arith.constant 32 : i32 + %c64_i32 = arith.constant 64 : i32 + %c4_i32 = arith.constant 4 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg3, %c63_i32 : i32 + %2 = arith.divsi %1, %c64_i32 : i32 + %3 = arith.addi %arg4, %c63_i32 : i32 + %4 = arith.divsi %3, %c64_i32 : i32 + %5 = arith.muli %4, %c4_i32 : i32 + %6 = arith.divsi %0, %5 : i32 + %7 = arith.muli %6, %c4_i32 : i32 + %8 = arith.subi %2, %7 : i32 + %9 = "triton_gpu.cmpi"(%8, %c4_i32) <{predicate = 2 : i64}> : (i32, i32) -> i1 + %10 = arith.select %9, %8, %c4_i32 : i32 + %11 = arith.remsi %0, %10 : i32 + %12 = arith.addi %7, %11 : i32 + %13 = arith.remsi %0, %5 : i32 + %14 = arith.divsi %13, %10 : i32 + %15 = arith.muli %12, %c64_i32 : i32 + %16 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %17 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %18 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %19 = tt.splat %15 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %20 = tt.splat %15 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %21 = arith.addi %19, %16 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %22 = arith.addi %20, %18 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %23 = arith.muli %14, %c64_i32 : i32 + %24 = tt.splat %23 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %25 = arith.addi %24, %17 : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %26 = tt.expand_dims %21 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<64x1xi32, #blocked> + %27 = tt.expand_dims %22 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xi32, #blocked1> + %28 = tt.splat %arg6 : (i32) -> tensor<64x1xi32, #blocked> + %29 = arith.muli %26, %28 : tensor<64x1xi32, #blocked> + %30 = tt.splat %arg0 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked> + %31 = tt.addptr %30, %29 : tensor<64x1x!tt.ptr, #blocked>, tensor<64x1xi32, #blocked> + %32 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %33 = tt.expand_dims %32 {axis = 0 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x32xi32, #blocked> + %34 = tt.broadcast %31 : (tensor<64x1x!tt.ptr, #blocked>) -> tensor<64x32x!tt.ptr, #blocked> + %35 = tt.broadcast %33 : (tensor<1x32xi32, #blocked>) -> tensor<64x32xi32, #blocked> + %36 = tt.addptr %34, %35 : tensor<64x32x!tt.ptr, #blocked>, tensor<64x32xi32, #blocked> + %37 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %38 = tt.expand_dims %37 {axis = 1 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<32x1xi32, #blocked1> + %39 = tt.splat %arg7 : (i32) -> tensor<32x1xi32, #blocked1> + %40 = arith.muli %38, %39 : tensor<32x1xi32, #blocked1> + %41 = tt.splat %arg1 : (!tt.ptr) -> tensor<32x1x!tt.ptr, #blocked1> + %42 = tt.addptr %41, %40 : tensor<32x1x!tt.ptr, #blocked1>, tensor<32x1xi32, #blocked1> + %43 = tt.expand_dims %25 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1> + %44 = tt.broadcast %42 : (tensor<32x1x!tt.ptr, #blocked1>) -> tensor<32x64x!tt.ptr, #blocked1> + %45 = tt.broadcast %43 : (tensor<1x64xi32, #blocked1>) -> tensor<32x64xi32, #blocked1> + %46 = tt.addptr %44, %45 : tensor<32x64x!tt.ptr, #blocked1>, tensor<32x64xi32, #blocked1> + %47 = arith.addi %arg5, %c31_i32 : i32 + %48 = arith.divsi %47, %c32_i32 : i32 + %49 = arith.muli %arg7, %c32_i32 : i32 + %50 = tt.splat %49 : (i32) -> tensor<32x64xi32, #blocked1> + %51 = tt.load %36 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf16, #blocked> + %52 = triton_gpu.convert_layout %51 : (tensor<64x32xf16, #blocked>) -> tensor<64x32xf16, #shared> + %53 = tt.load %46 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16, #blocked1> + %54 = triton_gpu.convert_layout %53 : (tensor<32x64xf16, #blocked1>) -> tensor<32x64xf16, #shared1> + %55 = tt.addptr %36, %cst_0 : tensor<64x32x!tt.ptr, #blocked>, tensor<64x32xi32, #blocked> + %56 = tt.addptr %46, %50 : tensor<32x64x!tt.ptr, #blocked1>, tensor<32x64xi32, #blocked1> + %57 = arith.subi %48, %c1_i32 : i32 + cf.br ^bb1(%c0_i32, %cst, %52, %54, %55, %56 : i32, tensor<64x64xf32, #mfma>, tensor<64x32xf16, #shared>, tensor<32x64xf16, #shared1>, tensor<64x32x!tt.ptr, #blocked>, tensor<32x64x!tt.ptr, #blocked1>) + ^bb1(%58: i32, %59: tensor<64x64xf32, #mfma>, %60: tensor<64x32xf16, #shared>, %61: tensor<32x64xf16, #shared1>, %62: tensor<64x32x!tt.ptr, #blocked>, %63: tensor<32x64x!tt.ptr, #blocked1>): // 2 preds: ^bb0, ^bb2 + %64 = arith.cmpi slt, %58, %57 : i32 + cf.cond_br %64, ^bb2, ^bb3 + ^bb2: // pred: ^bb1 + %65 = tt.load %62 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf16, #blocked> + %66 = tt.load %63 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16, #blocked1> + %67 = triton_gpu.convert_layout %60 : (tensor<64x32xf16, #shared>) -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 8}>> + %68 = triton_gpu.convert_layout %61 : (tensor<32x64xf16, #shared1>) -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 8}>> + %69 = tt.dot %67, %68, %59 {allowTF32 = true} : tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 8}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 8}>> -> tensor<64x64xf32, #mfma> + %70 = tt.addptr %62, %cst_0 : tensor<64x32x!tt.ptr, #blocked>, tensor<64x32xi32, #blocked> + %71 = tt.addptr %63, %50 : tensor<32x64x!tt.ptr, #blocked1>, tensor<32x64xi32, #blocked1> + %72 = triton_gpu.convert_layout %65 : (tensor<64x32xf16, #blocked>) -> tensor<64x32xf16, #shared> + %73 = triton_gpu.convert_layout %66 : (tensor<32x64xf16, #blocked1>) -> tensor<32x64xf16, #shared1> + %74 = arith.addi %58, %c1_i32 : i32 + cf.br ^bb1(%74, %69, %72, %73, %70, %71 : i32, tensor<64x64xf32, #mfma>, tensor<64x32xf16, #shared>, tensor<32x64xf16, #shared1>, tensor<64x32x!tt.ptr, #blocked>, tensor<32x64x!tt.ptr, #blocked1>) + ^bb3: // pred: ^bb1 + %75 = triton_gpu.convert_layout %60 : (tensor<64x32xf16, #shared>) -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 8}>> + %76 = triton_gpu.convert_layout %61 : (tensor<32x64xf16, #shared1>) -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 8}>> + %77 = tt.dot %75, %76, %59 {allowTF32 = true} : tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 8}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 8}>> -> tensor<64x64xf32, #mfma> + %78 = arith.truncf %77 : tensor<64x64xf32, #mfma> to tensor<64x64xf16, #mfma> + %79 = tt.splat %arg8 : (i32) -> tensor<64x1xi32, #blocked1> + %80 = arith.muli %79, %27 : tensor<64x1xi32, #blocked1> + %81 = tt.splat %arg2 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked1> + %82 = tt.addptr %81, %80 : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> + %83 = tt.broadcast %82 : (tensor<64x1x!tt.ptr, #blocked1>) -> tensor<64x64x!tt.ptr, #blocked1> + %84 = tt.broadcast %43 : (tensor<1x64xi32, #blocked1>) -> tensor<64x64xi32, #blocked1> + %85 = tt.addptr %83, %84 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> + %86 = tt.splat %arg3 : (i32) -> tensor<64x1xi32, #blocked1> + %87 = "triton_gpu.cmpi"(%27, %86) <{predicate = 2 : i64}> : (tensor<64x1xi32, #blocked1>, tensor<64x1xi32, #blocked1>) -> tensor<64x1xi1, #blocked1> + %88 = tt.splat %arg4 : (i32) -> tensor<1x64xi32, #blocked1> + %89 = "triton_gpu.cmpi"(%43, %88) <{predicate = 2 : i64}> : (tensor<1x64xi32, #blocked1>, tensor<1x64xi32, #blocked1>) -> tensor<1x64xi1, #blocked1> + %90 = tt.broadcast %87 : (tensor<64x1xi1, #blocked1>) -> tensor<64x64xi1, #blocked1> + %91 = tt.broadcast %89 : (tensor<1x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked1> + %92 = arith.andi %90, %91 : tensor<64x64xi1, #blocked1> + %93 = triton_gpu.convert_layout %78 : (tensor<64x64xf16, #mfma>) -> tensor<64x64xf16, #blocked1> + tt.store %85, %93, %92 {cache = 1 : i32, evict = 1 : i32} : tensor<64x64xf16, #blocked1> + tt.return + } +} diff --git a/test/Conversion/triton_ops.mlir b/test/Conversion/triton_ops.mlir index 4acfbb1e8166..bb4cba09645d 100644 --- a/test/Conversion/triton_ops.mlir +++ b/test/Conversion/triton_ops.mlir @@ -2,9 +2,9 @@ tt.func @cast_ops(%scalar_ptr: !tt.ptr, %scalar_f32: f32, %scalar_i64: i64) { // scalar -> scalar - // CHECK: i64 -> !tt.ptr + // CHECK: i64 -> !tt.ptr %0 = tt.int_to_ptr %scalar_i64 : i64 -> !tt.ptr - // CHECK: !tt.ptr -> i64 + // CHECK: !tt.ptr -> i64 %1 = tt.ptr_to_int %scalar_ptr : !tt.ptr -> i64 // CHECK: f32 to f16 %2 = arith.truncf %scalar_f32 : f32 to f16 @@ -14,9 +14,9 @@ tt.func @cast_ops(%scalar_ptr: !tt.ptr, %scalar_f32: f32, %scalar_i64: i64) %tensor_f32_0d = tt.splat %scalar_f32 : (f32) -> tensor %tensor_i64_0d = tt.splat %scalar_i64 : (i64) -> tensor - // CHECK: tensor -> tensor> + // CHECK: tensor -> tensor> %3 = tt.int_to_ptr %tensor_i64_0d : tensor -> tensor> - // CHECK: tensor> -> tensor + // CHECK: tensor> -> tensor %4 = tt.ptr_to_int %tensor_ptr_0d : tensor> -> tensor // CHECK: tensor to tensor %5 = arith.truncf %tensor_f32_0d : tensor to tensor @@ -26,9 +26,9 @@ tt.func @cast_ops(%scalar_ptr: !tt.ptr, %scalar_f32: f32, %scalar_i64: i64) %tensor_f32_1d = tt.splat %scalar_f32 : (f32) -> tensor<16xf32> %tensor_i64_1d = tt.splat %scalar_i64 : (i64) -> tensor<16xi64> - // CHECK: tensor<16xi64> -> tensor<16x!tt.ptr> + // CHECK: tensor<16xi64> -> tensor<16x!tt.ptr> %6 = tt.int_to_ptr %tensor_i64_1d : tensor<16xi64> -> tensor<16x!tt.ptr> - // CHECK: tensor<16x!tt.ptr> -> tensor<16xi64> + // CHECK: tensor<16x!tt.ptr> -> tensor<16xi64> %7 = tt.ptr_to_int %tensor_ptr_1d : tensor<16x!tt.ptr> -> tensor<16xi64> // CHECK: tensor<16xf32> to tensor<16xf16> %8 = arith.truncf %tensor_f32_1d : tensor<16xf32> to tensor<16xf16> @@ -37,19 +37,19 @@ tt.func @cast_ops(%scalar_ptr: !tt.ptr, %scalar_f32: f32, %scalar_i64: i64) tt.func @addptr_ops(%scalar_ptr: !tt.ptr, %scalar_i32: i32) { // scalar -> scalar - // CHECK: !tt.ptr + // CHECK: !tt.ptr %0 = tt.addptr %scalar_ptr, %scalar_i32 : !tt.ptr, i32 // 0D tensor -> 0D tensor %tensor_ptr_0d = tt.splat %scalar_ptr : (!tt.ptr) -> tensor> %tensor_i32_0d = tt.splat %scalar_i32 : (i32) -> tensor - // CHECK: tensor> + // CHECK: tensor> %1 = tt.addptr %tensor_ptr_0d, %tensor_i32_0d : tensor>, tensor // 1D tensor -> 1D tensor %tensor_ptr_1d = tt.splat %scalar_ptr : (!tt.ptr) -> tensor<16x!tt.ptr> %tensor_i32_1d = tt.splat %scalar_i32 : (i32) -> tensor<16xi32> - // CHECK: tensor<16x!tt.ptr> + // CHECK: tensor<16x!tt.ptr> %2 = tt.addptr %tensor_ptr_1d, %tensor_i32_1d : tensor<16x!tt.ptr>, tensor<16xi32> tt.return } @@ -201,5 +201,12 @@ tt.func @scan_op(%ptr: tensor<1x2x4x!tt.ptr>, %v : tensor<1x2x4xf32>) { }) : (tensor<1x2x4xf32>) -> tensor<1x2x4xf32> tt.store %ptr, %a : tensor<1x2x4xf32> tt.return +} +// CHECK-LABEL: inline_asm +// CHECK: tt.elementwise_inline_asm "shl.b32 $0, $0, 3;" +tt.func @inline_asm(%0: tensor<512xi8>) { + %1 = tt.elementwise_inline_asm "shl.b32 $0, $0, 3;" + {constraints = "=r,r", packed_element = 4 : i32, pure = true} %0 : tensor<512xi8> -> tensor<512xi8> + tt.return } diff --git a/test/Conversion/triton_to_tritongpu.mlir b/test/Conversion/triton_to_tritongpu.mlir index d5e6b32c48f3..cc3b4dd5bedd 100644 --- a/test/Conversion/triton_to_tritongpu.mlir +++ b/test/Conversion/triton_to_tritongpu.mlir @@ -1,7 +1,7 @@ // RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu=num-warps=2 | FileCheck %s tt.func @ops() { - // CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {{.*}} + // CHECK: module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {{.*}} %a = arith.constant dense<1.00e+00> : tensor<128x32xf16> %b = arith.constant dense<2.00e+00> : tensor<32x128xf16> %c = arith.constant dense<3.00e+00> : tensor<128x128xf32> @@ -33,10 +33,10 @@ tt.func @load_ops(%ptr: !tt.ptr {tt.divisibility = 16 : i32}) { tt.func @reduce_ops(%ptr: !tt.ptr {tt.divisibility = 16 : i32}) { // Test if the total number of threadsPerWarp is 64 // Test if the total number of warps is 2 - // CHECK: #[[blocked0:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 16], warpsPerCTA = [1, 2], order = [0, 1]}> - // CHECK: #[[blocked1:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 2], order = [0, 1]}> - // CHECK: #[[blocked2:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 2], order = [0, 1]}> - // CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {{.*}} + // CHECK: #[[blocked0:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 16], warpsPerCTA = [1, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> + // CHECK: #[[blocked1:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> + // CHECK: #[[blocked2:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> + // CHECK: module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {{.*}} %c0 = arith.constant dense<1.00e+00> : tensor<4x4xf32> %c1 = arith.constant dense<2.00e+00> : tensor<8x2xf32> %c2 = arith.constant dense<3.00e+00> : tensor<16x16xf32> diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 7460b1b7987f..f937d5b78b7a 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -1,9 +1,9 @@ -// RUN: not triton-opt %s -split-input-file --convert-triton-gpu-to-llvm --mlir-pass-pipeline-crash-reproducer=%t 2>/dev/null | FileCheck --check-prefixes=CHECK,GCN %s +// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm="target=rocdl" 2>/dev/null | FileCheck --check-prefixes=CHECK,GCN %s -module attributes {"triton_gpu.num-warps" = 4 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK: llvm.func @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr) // Here the 128 comes from the 4 in module attribute multiples 32 - // PTX: attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = [128 : i32]} {{.*}} + // PTX: nvvm.kernel = 1 : ui1, nvvm.maxntid = [128 : i32] tt.func @test_empty_kernel(%lb : index, %A : !tt.ptr) { // CHECK: llvm.return tt.return @@ -12,8 +12,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: basic_load tt.func @basic_load(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) { // PTX: llvm.inline_asm @@ -33,8 +33,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: vectorized_load tt.func @vectorized_load(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) { // GCN-NOT: llvm.inline_asm @@ -58,8 +58,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -module attributes {"triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { // CHECK-LABEL: vectorized_load_f16 tt.func @vectorized_load_f16(%a_ptr_init: tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf16, #blocked0>) { // GCN-NOT: llvm.inline_asm @@ -107,8 +107,8 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { // ----- // TODO: masked load with vectorization is pending on TODO -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: masked_load_const_other tt.func @masked_load_const_other(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>) { %cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0> @@ -120,8 +120,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // ----- // TODO: masked load with vectorization is pending on TODO -#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: masked_load_const_other_vec tt.func @masked_load_const_other_vec(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>) { %cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0> @@ -132,8 +132,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: store_with_cache_attr tt.func @store_with_cache_attr(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) { // PTX: llvm.inline_asm @@ -147,8 +147,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}> -module attributes {"triton_gpu.num-warps" = 2 : i32} { +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32} { // CHECK-LABEL: global_load_store_no_vec tt.func @global_load_store_no_vec(%arg0: !tt.ptr {tt.divisibility = 4 : i32}, %arg1: !tt.ptr {tt.divisibility = 4 : i32}, %arg2: !tt.ptr {tt.divisibility = 4 : i32}, %arg3: i32) { %c256_i32 = arith.constant 256 : i32 @@ -239,8 +239,8 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} { // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}> -module attributes {"triton_gpu.num-warps" = 2 : i32} { +#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32} { // CHECK-LABEL: global_load_store_vec4 tt.func @global_load_store_vec4(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) { %c256_i32 = arith.constant 256 : i32 @@ -316,9 +316,9 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} { // ----- // This test verifies the vectorization of Load and Store Ops. -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}> +#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> // Note, the %n_elements doesn't have a "tt.divisibility" hint, so Triton assumes it's divisibility is 1, this should effect the mask's alignment and further restrict the load/store ops' vector width to be 1. -module attributes {"triton_gpu.num-warps" = 2 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32} { tt.func @vecadd_masked_vec1(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %n_elements: i32) { %c64_i32 = arith.constant 64 : i32 %0 = tt.get_program_id x : i32 @@ -350,8 +350,8 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} { // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -module attributes {"triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { // CHECK-LABEL: global_load_store_vec2 tt.func @global_load_store_vec2(%arg0: !tt.ptr {tt.divisibility = 8 : i32}, %arg1: !tt.ptr {tt.divisibility = 8 : i32}, %arg2: !tt.ptr {tt.divisibility = 8 : i32}, %arg3: i32) { %c256_i32 = arith.constant 256 : i32 @@ -471,8 +471,129 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -module attributes {"triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: global_load_store_vec2 + tt.func @global_load_store_vec2(%arg0: !tt.ptr {tt.divisibility = 8 : i32}, %arg1: !tt.ptr {tt.divisibility = 8 : i32}, %arg2: !tt.ptr {tt.divisibility = 8 : i32}, %arg3: i32) { + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c256_i32 : i32 + %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0> + %3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0> + %4 = arith.addi %3, %2 : tensor<256xi32, #blocked0> + %5 = tt.splat %arg0 : (!tt.ptr) -> tensor<256x!tt.ptr, #blocked0> + %6 = tt.addptr %5, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> + %7 = tt.splat %arg1 : (!tt.ptr) -> tensor<256x!tt.ptr, #blocked0> + %8 = tt.addptr %7, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> + + // Load 8 elements from A with four vectorized load instruction + // GCN-NOT: llvm.inline_asm + // GCN: llvm.addrspacecast {{.*}} : !llvm.ptr to !llvm.ptr + // GCN: llvm.load {{.*}} : !llvm.ptr + // GCN: llvm.bitcast {{.*}} : i32 to vector<1xf32> + // GCN: llvm.addrspacecast {{.*}} : !llvm.ptr to !llvm.ptr + // GCN: llvm.load {{.*}} : !llvm.ptr + // GCN: llvm.bitcast {{.*}} : i32 to vector<1xf32> + // GCN: llvm.addrspacecast {{.*}} : !llvm.ptr to !llvm.ptr + // GCN: llvm.load {{.*}} : !llvm.ptr + // GCN: llvm.bitcast {{.*}} : i32 to vector<1xf32> + // GCN: llvm.addrspacecast {{.*}} : !llvm.ptr to !llvm.ptr + // GCN: llvm.load {{.*}} : !llvm.ptr + // GCN: llvm.bitcast {{.*}} : i32 to vector<1xf32> + // GCN: llvm.addrspacecast {{.*}} : !llvm.ptr to !llvm.ptr + // GCN: llvm.load {{.*}} : !llvm.ptr + // GCN: llvm.bitcast {{.*}} : i32 to vector<1xf32> + // GCN: llvm.addrspacecast {{.*}} : !llvm.ptr to !llvm.ptr + // GCN: llvm.load {{.*}} : !llvm.ptr + // GCN: llvm.bitcast {{.*}} : i32 to vector<1xf32> + // GCN: llvm.addrspacecast {{.*}} : !llvm.ptr to !llvm.ptr + // GCN: llvm.load {{.*}} : !llvm.ptr + // GCN: llvm.bitcast {{.*}} : i32 to vector<1xf32> + // GCN: llvm.addrspacecast {{.*}} : !llvm.ptr to !llvm.ptr + // GCN: llvm.load {{.*}} : !llvm.ptr + // GCN: llvm.bitcast {{.*}} : i32 to vector<1xf32> + // GCN: llvm.insertvalue {{.*}}[0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + // GCN: llvm.insertvalue {{.*}}[1] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + // GCN: llvm.insertvalue {{.*}}[2] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + // GCN: llvm.insertvalue {{.*}}[3] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + // GCN: llvm.insertvalue {{.*}}[4] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + // GCN: llvm.insertvalue {{.*}}[5] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + // GCN: llvm.insertvalue {{.*}}[6] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + // GCN: llvm.insertvalue {{.*}}[7] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + // PTX: @${{.*}} ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ]; + // PTX: @${{.*}} ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ]; + // PTX: @${{.*}} ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ]; + // PTX: @${{.*}} ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ]; + + // Load 8 elements from B with four vectorized load instruction + // GCN-NOT: llvm.inline_asm + // GCN: llvm.addrspacecast {{.*}} : !llvm.ptr to !llvm.ptr + // GCN: llvm.load {{.*}} : !llvm.ptr + // GCN: llvm.bitcast {{.*}} : i32 to vector<1xf32> + // GCN: llvm.addrspacecast {{.*}} : !llvm.ptr to !llvm.ptr + // GCN: llvm.load {{.*}} : !llvm.ptr + // GCN: llvm.bitcast {{.*}} : i32 to vector<1xf32> + // GCN: llvm.addrspacecast {{.*}} : !llvm.ptr to !llvm.ptr + // GCN: llvm.load {{.*}} : !llvm.ptr + // GCN: llvm.bitcast {{.*}} : i32 to vector<1xf32> + // GCN: llvm.addrspacecast {{.*}} : !llvm.ptr to !llvm.ptr + // GCN: llvm.load {{.*}} : !llvm.ptr + // GCN: llvm.bitcast {{.*}} : i32 to vector<1xf32> + // GCN: llvm.addrspacecast {{.*}} : !llvm.ptr to !llvm.ptr + // GCN: llvm.load {{.*}} : !llvm.ptr + // GCN: llvm.bitcast {{.*}} : i32 to vector<1xf32> + // GCN: llvm.addrspacecast {{.*}} : !llvm.ptr to !llvm.ptr + // GCN: llvm.load {{.*}} : !llvm.ptr + // GCN: llvm.bitcast {{.*}} : i32 to vector<1xf32> + // GCN: llvm.addrspacecast {{.*}} : !llvm.ptr to !llvm.ptr + // GCN: llvm.load {{.*}} : !llvm.ptr + // GCN: llvm.bitcast {{.*}} : i32 to vector<1xf32> + // GCN: llvm.addrspacecast {{.*}} : !llvm.ptr to !llvm.ptr + // GCN: llvm.load {{.*}} : !llvm.ptr + // GCN: llvm.bitcast {{.*}} : i32 to vector<1xf32> + // GCN: llvm.insertvalue {{.*}}[0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + // GCN: llvm.insertvalue {{.*}}[1] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + // GCN: llvm.insertvalue {{.*}}[2] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + // GCN: llvm.insertvalue {{.*}}[3] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + // GCN: llvm.insertvalue {{.*}}[4] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + // GCN: llvm.insertvalue {{.*}}[5] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + // GCN: llvm.insertvalue {{.*}}[6] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + // GCN: llvm.insertvalue {{.*}}[7] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + // PTX: @${{.*}} ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ]; + // PTX: @${{.*}} ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ]; + // PTX: @${{.*}} ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ]; + // PTX: @${{.*}} ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ]; + + %9 = tt.load %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0> + %10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0> + %11 = arith.addf %9, %10 : tensor<256xf32, #blocked0> + %12 = tt.splat %arg2 : (!tt.ptr) -> tensor<256x!tt.ptr, #blocked0> + %13 = tt.addptr %12, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> + + // Store 8 elements to global with four vectorized store instruction + // GCN-NOT: llvm.inline_asm + // GCN: llvm.extractvalue {{.*}}[0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + // GCN: llvm.extractvalue {{.*}}[1] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + // GCN: llvm.extractvalue {{.*}}[2] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + // GCN: llvm.extractvalue {{.*}}[3] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + // GCN: llvm.extractvalue {{.*}}[4] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + // GCN: llvm.extractvalue {{.*}}[5] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + // GCN: llvm.extractvalue {{.*}}[6] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + // GCN: llvm.extractvalue {{.*}}[7] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + // GCN-COUNT-8: llvm.store {{.*}} : !llvm.ptr + // PTX: @${{.*}} st.global.v2.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}} }; + // PTX: @${{.*}} st.global.v2.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}} }; + // PTX: @${{.*}} st.global.v2.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}} }; + // PTX: @${{.*}} st.global.v2.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}} }; + tt.store %13, %11 : tensor<256xf32, #blocked0> + tt.return + } +} + +// ----- + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { // CHECK-LABEL: global_load_store_vec8 tt.func @global_load_store_vec8(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) { %c256_i32 = arith.constant 256 : i32 @@ -589,9 +710,9 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> -module attributes {"triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: basic_view_broadcast tt.func @basic_view_broadcast(%arg : tensor<256xf32,#blocked0>) { // CHECK: llvm.mlir.undef @@ -614,8 +735,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: basic_make_range tt.func @basic_make_range() { // PTX: nvvm.read.ptx.sreg.tid.x @@ -629,8 +750,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: basic_addf tt.func @basic_addf(%arg0 : tensor<256xf32,#blocked0>, %arg1 : tensor<256xf32,#blocked0>) { // CHECK: llvm.fadd @@ -642,8 +763,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: basic_addi tt.func @basic_addi(%arg0 : tensor<256xi32,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) { // CHECK: llvm.add @@ -655,10 +776,10 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // ----- -module attributes {"triton_gpu.num-warps" = 4 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: basic_program_id tt.func @basic_program_id() { - // PTX: nvvm.read.ptx.sreg.ctaid.x : i32 + // PTX: llvm.inline_asm asm_dialect = att operand_attrs = [] "mov.u32 $0, %ctaid.x;", "=r" : () -> i32 %0 = tt.get_program_id x : i32 tt.return } @@ -666,8 +787,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: basic_addptr tt.func @basic_addptr(%arg0 : tensor<256x!tt.ptr,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) { // CHECK: llvm.getelementptr @@ -679,8 +800,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // ----- -#shared0 = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32} { +#shared0 = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK: llvm.mlir.global external @global_smem // CHECK-LABEL: basic_alloc_tensor tt.func @basic_alloc_tensor() { @@ -696,8 +817,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // ----- -#shared0 = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32} { +#shared0 = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK: llvm.mlir.global external @global_smem // CHECK-LABEL: basic_extract_slice tt.func @basic_extract_slice() { @@ -731,12 +852,11 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // ----- -module attributes {"triton_gpu.num-warps" = 4 : i32} { - // PTX-LABEL: basic_async_wait - // This test is disabled for GCN target, because it is PTX specific - // GCN-NOT: basic_async_wait +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: basic_async_wait tt.func @basic_async_wait() { // PTX: cp.async.wait_group 0x4 + // GCN-NOT: cp.async.wait_group triton_gpu.async_wait {num = 4: i32} tt.return } @@ -744,15 +864,15 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // ----- -#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [4], warpsPerCTA = [4], order = [0]}> -#block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}> -#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0]}> -#block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0]}> +#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [4], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #slice2d1 = #triton_gpu.slice<{dim = 1, parent=#block2}> #slice3d0 = #triton_gpu.slice<{dim = 0, parent=#block3}> -#AL = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#A = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32} { +#AL = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#A = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: basic_insert_slice_async_fallback tt.func @basic_insert_slice_async_fallback(%arg0: !tt.ptr {tt.divisibility = 1 : i32}) { %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1> @@ -783,15 +903,15 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // ----- -#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [4], warpsPerCTA = [4], order = [0]}> -#block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}> -#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0]}> -#block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0]}> +#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [4], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #slice2d1 = #triton_gpu.slice<{dim = 1, parent=#block2}> #slice3d0 = #triton_gpu.slice<{dim = 0, parent=#block3}> -#AL = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#A = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32} { +#AL = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#A = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: basic_insert_slice_async_v4 tt.func @basic_insert_slice_async_v4(%arg0: !tt.ptr {tt.divisibility = 32 : i32}) { %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1> @@ -857,15 +977,15 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // ----- -#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [4], warpsPerCTA = [4], order = [0]}> -#block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}> -#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0]}> -#block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0]}> +#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [4], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #slice2d1 = #triton_gpu.slice<{dim = 1, parent=#block2}> #slice3d0 = #triton_gpu.slice<{dim = 0, parent=#block3}> -#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#A = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32} { +#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#A = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: basic_insert_slice_async_v1 tt.func @basic_insert_slice_async_v1(%arg0: !tt.ptr {tt.divisibility = 4 : i32}) { %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1> @@ -891,7 +1011,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // It is left to catch changes in AMD compilation pipeline. // PTX: llvm.inline_asm - // PTX-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 + // PTX: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 // PTX: llvm.inline_asm // PTX-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 // PTX: llvm.inline_asm @@ -923,14 +1043,14 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // ----- -#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}> -#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 1], warpsPerCTA = [4, 1], order = [1, 0]}> -#block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0]}> +#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 1], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #slice2d1 = #triton_gpu.slice<{dim = 1, parent=#block2}> #slice3d0 = #triton_gpu.slice<{dim = 0, parent=#block3}> -#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#A = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32} { +#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#A = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: basic_insert_slice_async_v1_multictas tt.func @basic_insert_slice_async_v1_multictas(%arg0: !tt.ptr {tt.divisibility = 4 : i32}) { %off0_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice2d1> @@ -1012,8 +1132,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK: basic_splat tt.func @basic_splat(%ptr: !tt.ptr) { // CHECK: llvm.mlir.undef @@ -1026,8 +1146,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: basic_store tt.func @basic_store(%ptrs: tensor<256x!tt.ptr, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) { // GCN-NOT: llvm.inline_asm @@ -1045,9 +1165,9 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [0, 1]}> -module attributes {"triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8> // CHECK-LABEL: convert_layout_blocked_blocked tt.func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) { @@ -1092,9 +1212,9 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 2], warpsPerCTA = [1, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 2], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8> // CHECK-LABEL: convert_layout_blocked_blocked_vec tt.func @convert_layout_blocked_blocked_vec(%arg0: tensor<16x16xf32, #blocked0>) { @@ -1115,9 +1235,9 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8> // CHECK-LABEL: convert_layout_blocked_blocked_multi_rep tt.func @convert_layout_blocked_blocked_multi_rep(%arg0: tensor<16x16xf32, #blocked0>) { @@ -1144,19 +1264,19 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}> -#shared0 = #triton_gpu.shared<{vec = 1, perPhase=2, maxPhase=8 ,order = [1, 0]}> -#mma0 = #triton_gpu.mma<{versionMajor=2, warpsPerCTA=[1,1]}> +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared0 = #triton_gpu.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mma0 = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, kWidth=2}> #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, kWidth=2}> -module attributes {"triton_gpu.num-warps" = 1 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { // PTX-LABEL: convert_dot // This test is not relevant to GCN target, because it is PTX specific tt.func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) { %AA = triton_gpu.convert_layout %A : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #shared0> %BB = triton_gpu.convert_layout %B : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #shared0> // PTX: llvm.inline_asm - // PTX-SAME: ldmatrix.sync.aligned.m8n8.x4 + // PTX: ldmatrix.sync.aligned.m8n8.x4 // PTX: llvm.inline_asm // PTX-SAME: ldmatrix.sync.aligned.m8n8.x4 %AA_DOT = triton_gpu.convert_layout %AA : (tensor<16x16xf16, #shared0>) -> tensor<16x16xf16, #dot_operand_a> @@ -1174,22 +1294,22 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { } // TODO: problems in MLIR's parser on slice layout -// #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}> -// module attributes {"triton_gpu.num-warps" = 1 : i32} { +// #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +// module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { // tt.func @make_range_sliced_layout() { // %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>> -// return +// tt.return // } // } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 4], warpsPerCTA = [1, 1], order = [1, 0]}> -#shared0 = #triton_gpu.shared<{vec = 1, perPhase=1, maxPhase=1, order = [1, 0]}> -#mfma0 = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA=[1,1], isTranspose=false}> -#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma0, kWidth = 8}> -#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma0, kWidth = 8}> -module attributes {"triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared0 = #triton_gpu.shared<{vec = 1, perPhase=1, maxPhase=1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mfma0 = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma0, kWidth = 4}> +#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma0, kWidth = 4}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { // CHECK-LABEL: convert_dot_mfma tt.func @convert_dot_mfma(%A: tensor<32x32xf16, #blocked0>, %B: tensor<32x32xf16, #blocked0>) { %AA = triton_gpu.convert_layout %A : (tensor<32x32xf16, #blocked0>) -> tensor<32x32xf16, #shared0> @@ -1209,9 +1329,9 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [64, 1], warpsPerCTA = [1, 4], order = [1, 0]}> -#mfma = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTranspose=false}> -module attributes {"triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [64, 1], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mfma = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8> // CHECK-LABEL: convert_layout_mfma_block tt.func @convert_layout_mfma_blocked(%arg0: tensor<32x32xf32, #mfma>) { @@ -1224,9 +1344,9 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> -#mma = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [2, 2]}> -module attributes {"triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mma = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { // PTX: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8> // PTX-LABEL: convert_layout_mmav2_block // This test is not relevant to GCN target, because it is PTX specific @@ -1245,9 +1365,9 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}> -#mma = #triton_gpu.mma<{versionMajor = 1, versionMinor = 3, warpsPerCTA = [2, 2]}> -module attributes {"triton_gpu.num-warps" = 1 : i32} { +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mma = #triton_gpu.mma<{versionMajor = 1, versionMinor = 3, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 16]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { // PTX: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8> // PTX-LABEL: convert_layout_mmav1_block // This test is not relevant to GCN target, because it is PTX specific @@ -1269,9 +1389,9 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}> -#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}> -module attributes {"triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8> // CHECK-LABEL: convert_layout_blocked_shared tt.func @convert_layout_blocked_shared(%arg0: tensor<128x32xf32, #blocked0>) { @@ -1286,9 +1406,9 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { // CHECK-LABEL: convert_blocked1d_to_slice0 tt.func @convert_blocked1d_to_slice0(%src:tensor<32xi32, #blocked0>) { // CHECK-COUNT-4: llvm.load {{.*}} : !llvm.ptr, 3> @@ -1299,9 +1419,9 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { // CHECK-LABEL: convert_blocked1d_to_slice1 tt.func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) { // CHECK-COUNT-8: llvm.load {{.*}} : !llvm.ptr, 3> @@ -1312,9 +1432,9 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -module attributes {"triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { // CHECK-LABEL: convert_blocked_to_blocked_ptr tt.func @convert_blocked_to_blocked_ptr(%src:tensor<32x!tt.ptr, #blocked0>) { // CHECK: llvm.ptrtoint @@ -1330,12 +1450,12 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}> -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> -#mma = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [2, 2]}> +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mma = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=2}> #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, kWidth=2}> -module attributes {"triton_gpu.num-warps" = 4 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // PTX-LABEL: matmul_kernel_dot_operand_layout // This test is disabled for GCN target, because it is PTX specific // This test is not relevant to GCN target, because it is PTX specific @@ -1358,12 +1478,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 32], warpsPerCTA = [1, 4], order = [1, 0]}> -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> -#mfma = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed=false}> -#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = 8}> -#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = 8}> -module attributes {"triton_gpu.num-warps" = 4 : i32} { +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mfma = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = 4}> +#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = 4}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: matmul_kernel_dot_operand_layout_gcn tt.func @matmul_kernel_dot_operand_layout_gcn(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, %a:tensor<128x32xf16, #shared>, %b:tensor<32x256xf16, #shared>) { @@ -1386,13 +1506,13 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}> -#shared0 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 8, order = [1, 0]}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}> -#mma = #triton_gpu.mma<{versionMajor = 1, versionMinor = 3, warpsPerCTA = [2, 2]}> +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared0 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mma = #triton_gpu.mma<{versionMajor = 1, versionMinor = 3, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 16]}> #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, isMMAv1Row=true}> #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, isMMAv1Row=true}> -module attributes {"triton_gpu.num-warps" = 4 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // PTX-LABEL: matmul884_kernel_dot_operand_layout // This test is not relevant to GCN target, because it is PTX specific tt.func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, @@ -1413,11 +1533,11 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}> -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#blocked}> #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> -module attributes {"triton_gpu.num-warps" = 4 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { tt.func @matmul_fmadot(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, %a:tensor<32x16xf32, #shared>, %b:tensor<16x32xf32, #shared>) { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked> @@ -1435,12 +1555,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // ----- -#mma = #triton_gpu.mma<{versionMajor=2, warpsPerCTA=[2, 2]}> -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}> +#mma = #triton_gpu.mma<{versionMajor=2, warpsPerCTA=[2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=1}> #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, kWidth=1}> -module attributes {"triton_gpu.num-warps" = 4 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // PTX-LABEL: matmul_tf32dot // This test is not relevant to GCN target, because it is PTX specific tt.func @matmul_tf32dot(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, @@ -1475,8 +1595,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: atomic_add_f32 tt.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) { // GCN-NOT: llvm.inline_asm @@ -1492,15 +1612,14 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // ----- -module attributes {"triton_gpu.num-warps" = 4 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: atomic_add_f32_scalar tt.func @atomic_add_f32_scalar(%arg0 : !tt.ptr, %arg1 : i1, %arg2 : f32) { // GCN-NOT: llvm.inline_asm // GCN: llvm.atomicrmw fadd {{.*}} monotonic : !llvm.ptr, f32 // PTX: llvm.icmp "eq" // PTX: llvm.inline_asm - // PTX: llvm.inline_asm - // PTX-SAME: @$3 atom.global.gpu.add.f32 + // PTX-SAME: @$3 atom.global.gpu.relaxed.add.f32 %0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32, sem = 1: i32} : (!tt.ptr, f32, i1) -> f32 tt.return } @@ -1508,8 +1627,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: store_f32 tt.func @store_f32(%arg0 : tensor<256x!tt.ptr, #blocked0>, %arg1 : tensor<256xf32, #blocked0>) { // GCN-NOT: llvm.inline_asm @@ -1526,7 +1645,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // ----- -module attributes {"triton_gpu.num-warps" = 4 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: store_f32_scalar tt.func @store_f32_scalar(%arg0 : !tt.ptr, %arg1 : f32) { // GCN-NOT: llvm.inline_asm @@ -1541,16 +1660,16 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32} { - +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +// CHECK-LABEL: test_get_program_id tt.func @test_get_program_id(%a: tensor<32x!tt.ptr, #blocked0>) { - %blockidx = tt.get_program_id x : i32 - %blockidy = tt.get_program_id y : i32 - %blockidz = tt.get_program_id z : i32 - // PTX: nvvm.read.ptx.sreg.ctaid.x - // PTX: nvvm.read.ptx.sreg.ctaid.y - // PTX: nvvm.read.ptx.sreg.ctaid.z + %blockidx = tt.get_program_id x: i32 + %blockidy = tt.get_program_id y: i32 + %blockidz = tt.get_program_id z: i32 + // PTX: ctaid.x + // PTX: ctaid.y + // PTX: ctaid.z // GCN: rocdl.workgroup.id.x // GCN: rocdl.workgroup.id.y // GCN: rocdl.workgroup.id.z @@ -1565,18 +1684,62 @@ tt.func @test_get_program_id(%a: tensor<32x!tt.ptr, #blocked0>) { } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32} { + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 4 : i32, "triton_gpu.num-warps" = 4 : i32} { +// CHECK-LABEL: test_get_program_id +tt.func @test_get_program_id(%a: tensor<32x!tt.ptr, #blocked0>) { + %blockidx = tt.get_program_id x: i32 + %blockidy = tt.get_program_id y: i32 + %blockidz = tt.get_program_id z : i32 + // PTX: clusterid.x + // PTX: clusterid.y + // PTX: clusterid.z + %v0 = arith.addi %blockidx, %blockidy : i32 + %v1 = arith.addi %v0, %blockidz : i32 + %0 = tt.splat %v1 : (i32) -> tensor<32xi32, #blocked0> + tt.store %a, %0 : tensor<32xi32, #blocked0> + + tt.return +} + +} + +// ----- + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: test_get_num_program tt.func @test_get_num_program(%a: tensor<32x!tt.ptr, #blocked0>) { - // PTX: nvvm.read.ptx.sreg.nctaid.x - // PTX: nvvm.read.ptx.sreg.nctaid.y - // PTX: nvvm.read.ptx.sreg.nctaid.z + %blockdimx = tt.get_num_programs {axis=0:i32} : i32 + %blockdimy = tt.get_num_programs {axis=1:i32} : i32 + %blockdimz = tt.get_num_programs {axis=2:i32} : i32 + // PTX: nctaid.x + // PTX: nctaid.y + // PTX: nctaid.z // GCN: rocdl.grid.dim.x // GCN: rocdl.grid.dim.y // GCN: rocdl.grid.dim.z + %v0 = arith.addi %blockdimx, %blockdimy : i32 + %v1 = arith.addi %v0, %blockdimz : i32 + %0 = tt.splat %v1 : (i32) -> tensor<32xi32, #blocked0> + tt.store %a, %0 : tensor<32xi32, #blocked0> + + tt.return + } +} + +// ----- + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 4 : i32, "triton_gpu.num-warps" = 4 : i32} { + tt.func @test_get_num_program(%a: tensor<32x!tt.ptr, #blocked0>) { %blockdimx = tt.get_num_programs {axis=0:i32} : i32 %blockdimy = tt.get_num_programs {axis=1:i32} : i32 %blockdimz = tt.get_num_programs {axis=2:i32} : i32 + // PTX: nclusterid.x + // PTX: nclusterid.y + // PTX: nclusterid.z %v0 = arith.addi %blockdimx, %blockdimy : i32 %v1 = arith.addi %v0, %blockdimz : i32 %0 = tt.splat %v1 : (i32) -> tensor<32xi32, #blocked0> @@ -1587,8 +1750,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: test_index_cache tt.func @test_index_cache() { // PTX: nvvm.read.ptx.sreg.tid.x @@ -1602,9 +1765,9 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}> -#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}> -module attributes {"triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { // CHECK-LABEL: test_base_index_cache tt.func @test_base_index_cache(%arg0: tensor<128x32xf32, #blocked0>) { // PTX: nvvm.read.ptx.sreg.tid.x @@ -1618,9 +1781,9 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}> -#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}> -module attributes {"triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { // CHECK-LABEL: test_index_cache_different_block tt.func @test_index_cache_different_block(%arg0: tensor<128x32xf32, #blocked0>, %arg1: i1) { // PTX: nvvm.read.ptx.sreg.tid.x @@ -1639,12 +1802,12 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { // ----- -#mma = #triton_gpu.mma<{versionMajor=2, warpsPerCTA=[2, 2]}> -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}> +#mma = #triton_gpu.mma<{versionMajor=2, warpsPerCTA=[2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=1}> #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, kWidth=1}> -module attributes {"triton_gpu.num-warps" = 4 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: matmul_tf32_cst_b tt.func @matmul_tf32_cst_b(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, %a: tensor<32x16xf32, #dot_operand_a>, %c: tensor<32x32xf32, #mma>) { @@ -1664,9 +1827,9 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mma = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { // CHECK-LABEL: matmul_f16_cst_operands tt.func public @matmul_f16_cst_operands(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> @@ -1703,8 +1866,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -module attributes {"triton_gpu.num-warps" = 1 : i32} { +#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { // CHECK-LABEL: test_s8_to_bf16_conversion tt.func @test_s8_to_bf16_conversion(%in: tensor<32xi8, #blocked>) { // We can't vectorize if we only process @@ -1718,9 +1881,9 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { // ----- -#mma = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1]}> +#mma = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> #dot = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}> -module attributes {"triton_gpu.num-warps" = 1 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { // CHECK-LABEL: test_s8_to_bf16_vectorized_conversion tt.func @test_s8_to_bf16_vectorized_conversion(%in: tensor<16x16xi8, #mma>) { // CHECK-NOT: llvm.sitofp @@ -1741,8 +1904,116 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [2, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-warps" = 2 : i32} { +// CHECK-LABEL: sum_reduction +// PTX: %[[M:.+]] = llvm.mlir.constant(-1 : i32) : i32 +// PTX: nvvm.redux.sync add %{{.*}}, %[[M]] +// PTX: nvvm.barrier0 +// PTX: shfl.sync.bfly.b32 +// PTX: shfl.sync.bfly.b32 +// PTX: nvvm.barrier0 + +// GCN-COUNT-4: ds_swizzle_b32 +// GCN: llvm.store +// GCN: rocdl.barrier +// GCN: llvm.load +// GCN-COUNT-2: ds_swizzle_b32 +// GCN: llvm.store +// GCN: rocdl.barrier +// GCN: llvm.load +// GCN: rocdl.barrier +// GCN: llvm.store +// GCN: rocdl.barrier +// GCN: llvm.load +// GCN: llvm.store +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @sum_reduction(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<1024> : tensor<1x1xi32, #blocked> + %0 = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32, #blocked1> + %1 = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %2 = tt.expand_dims %1 {axis = 1 : i32} : (tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<1x1xi32, #blocked> + %3 = arith.muli %2, %cst : tensor<1x1xi32, #blocked> + %4 = tt.splat %arg0 : (!tt.ptr) -> tensor<1x1x!tt.ptr, #blocked> + %5 = tt.addptr %4, %3 : tensor<1x1x!tt.ptr, #blocked>, tensor<1x1xi32, #blocked> + %6 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %7 = tt.expand_dims %6 {axis = 0 : i32} : (tensor<1024xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x1024xi32, #blocked> + %8 = tt.broadcast %5 : (tensor<1x1x!tt.ptr, #blocked>) -> tensor<1x1024x!tt.ptr, #blocked> + %9 = tt.addptr %8, %7 : tensor<1x1024x!tt.ptr, #blocked>, tensor<1x1024xi32, #blocked> + %10 = tt.load %9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1x1024xi32, #blocked> + %11 = "tt.reduce"(%10) <{axis = 1 : i32}> ({ + ^bb0(%arg2: i32, %arg3: i32): + %15 = arith.addi %arg2, %arg3 : i32 + tt.reduce.return %15 : i32 + }) : (tensor<1x1024xi32, #blocked>) -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %12 = triton_gpu.convert_layout %11 : (tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<1xi32, #blocked1> + %13 = tt.splat %arg1 : (!tt.ptr) -> tensor<1x!tt.ptr, #blocked1> + %14 = tt.addptr %13, %0 : tensor<1x!tt.ptr, #blocked1>, tensor<1xi32, #blocked1> + tt.store %14, %12 {cache = 1 : i32, evict = 1 : i32} : tensor<1xi32, #blocked1> + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 2], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#slice = #triton_gpu.slice<{dim = 1, parent = #blocked}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32} { + // CHECK-LABEL: reduce_bools + tt.func public @reduce_bools(%arg: tensor<256x2xi1, #blocked>) { + // CHECK: llvm.mlir.addressof @global_smem + %24 = "tt.reduce"(%arg) <{axis = 1 : i32}> ({ + ^bb0(%arg4: i1, %arg5: i1): + %48 = arith.ori %arg4, %arg5 : i1 + tt.reduce.return %48 : i1 + }) : (tensor<256x2xi1, #blocked>) -> tensor<256xi1, #slice> + tt.return + } +} + + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: inline_asm + tt.func public @inline_asm(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %0 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32, #blocked> + %1 = tt.splat %arg0 : (!tt.ptr) -> tensor<512x!tt.ptr, #blocked> + %2 = tt.addptr %1, %0 : tensor<512x!tt.ptr, #blocked>, tensor<512xi32, #blocked> + %3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<512xi8, #blocked> +// CHECK: %{{.*}} = llvm.inline_asm asm_dialect = att "shl.b32 $0, $0, 3;", "=r,r" %{{.*}} : (vector<4xi8>) -> vector<4xi8> + %4 = tt.elementwise_inline_asm "shl.b32 $0, $0, 3;" {constraints = "=r,r", packed_element = 4 : i32, pure = true} %3 : tensor<512xi8, #blocked> -> tensor<512xi8, #blocked> + %5 = tt.splat %arg1 : (!tt.ptr) -> tensor<512x!tt.ptr, #blocked> + %6 = tt.addptr %5, %0 : tensor<512x!tt.ptr, #blocked>, tensor<512xi32, #blocked> + tt.store %6, %4 {cache = 1 : i32, evict = 1 : i32} : tensor<512xi8, #blocked> + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: inline_asm_pack_16bit + tt.func public @inline_asm_pack_16bit(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %0 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32, #blocked> + %1 = tt.splat %arg0 : (!tt.ptr) -> tensor<512x!tt.ptr, #blocked> + %2 = tt.addptr %1, %0 : tensor<512x!tt.ptr, #blocked>, tensor<512xi32, #blocked> + %3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<512xi8, #blocked> +// CHECK: %{{.*}} = llvm.inline_asm asm_dialect = att "shl.b16 $0, $0, 3;", "=h,h" %{{.*}} : (vector<2xi8>) -> vector<2xi8> + %4 = tt.elementwise_inline_asm "shl.b16 $0, $0, 3;" {constraints = "=h,h", packed_element = 2 : i32, pure = true} %3 : tensor<512xi8, #blocked> -> tensor<512xi8, #blocked> + %5 = tt.splat %arg1 : (!tt.ptr) -> tensor<512x!tt.ptr, #blocked> + %6 = tt.addptr %5, %0 : tensor<512x!tt.ptr, #blocked>, tensor<512xi32, #blocked> + tt.store %6, %4 {cache = 1 : i32, evict = 1 : i32} : tensor<512xi8, #blocked> + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [2, 1], order = [1, 0], CTAsPerCGA = [1,1], CTASplitNum = [1,1], CTAOrder = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32} { // CHECK-LABEL: atomic_add_f16 tt.func @atomic_add_f16(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: f16 {tt.difisibility = 16 : i32}) { %c1_i1 = arith.constant 1 : i1 @@ -1761,5 +2032,3 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} { tt.return } } - -// ----- diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir new file mode 100644 index 000000000000..053330d47a49 --- /dev/null +++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir @@ -0,0 +1,80 @@ +// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm 2>&1 | FileCheck %s + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 4], CTASplitNum = [1, 4], CTAOrder = [0, 1]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 4], CTASplitNum = [1, 4], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 4 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: @tma_multicast_no_broadcast + tt.func @tma_multicast_no_broadcast(%basePtr: !tt.ptr {tt.divisibility = 8 : i32}, + %dim0: i64, %dim1: i64, + %stride0: i64, %stride1: i64, + %coord0: i32, %coord1: i32) { + %mbar = triton_nvidia_gpu.alloc_mbarrier { count = 128 : i32 } : !tt.ptr + %dst = triton_gpu.alloc_tensor : tensor<1x64x64xf16, #shared> + %c0 = arith.constant 0 : i32 + %src = tt.make_tensor_ptr %basePtr, [%dim0, %dim1], [%stride0, %stride1], [%coord0, %coord1] {order = array} : !tt.ptr, 1> + // CHECK: nvgpu.tma_load_tiled %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {operand_segment_sizes = array} : !llvm.ptr, !llvm.ptr, !llvm.ptr, i64, i1, i32, i32 + %res = triton_nvidia_gpu.insert_slice_async_v2 %src, %dst, %c0, %mbar {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<1x64x64xf16, #shared>, i32, !tt.ptr -> tensor<1x64x64xf16, #shared> + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 4], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 4], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 4 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: @tma_multicast_const_mask + tt.func @tma_multicast_const_mask(%basePtr: !tt.ptr {tt.divisibility = 8 : i32}, + %dim0: i64, %dim1: i64, + %stride0: i64, %stride1: i64, + %coord0: i32, %coord1: i32) { + %mbar = triton_nvidia_gpu.alloc_mbarrier { count = 128 : i32 } : !tt.ptr + %dst = triton_gpu.alloc_tensor : tensor<1x64x64xf16, #shared> + %c0 = arith.constant 0 : i32 + %src = tt.make_tensor_ptr %basePtr, [%dim0, %dim1], [%stride0, %stride1], [%coord0, %coord1] {order = array} : !tt.ptr, 1> + // CHECK: %[[C15:.*]] = llvm.mlir.constant(15 : i16) : i16 + // CHECK: nvgpu.tma_load_tiled %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C15]] + %res = triton_nvidia_gpu.insert_slice_async_v2 %src, %dst, %c0, %mbar {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<1x64x64xf16, #shared>, i32, !tt.ptr -> tensor<1x64x64xf16, #shared> + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 4], CTASplitNum = [1, 2], CTAOrder = [0, 1]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 4], CTASplitNum = [1, 2], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 4 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: @tma_multicast_variable_mask + tt.func @tma_multicast_variable_mask(%basePtr: !tt.ptr {tt.divisibility = 8 : i32}, + %dim0: i64, %dim1: i64, + %stride0: i64, %stride1: i64, + %coord0: i32, %coord1: i32) { + %mbar = triton_nvidia_gpu.alloc_mbarrier { count = 128 : i32 } : !tt.ptr + %dst = triton_gpu.alloc_tensor : tensor<1x64x64xf16, #shared> + %c0 = arith.constant 0 : i32 + %src = tt.make_tensor_ptr %basePtr, [%dim0, %dim1], [%stride0, %stride1], [%coord0, %coord1] {order = array} : !tt.ptr, 1> + // CHECK: nvgpu.cluster_id + // CHECK: nvgpu.tma_load_tiled + %res = triton_nvidia_gpu.insert_slice_async_v2 %src, %dst, %c0, %mbar {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<1x64x64xf16, #shared>, i32, !tt.ptr -> tensor<1x64x64xf16, #shared> + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: @tma_store + tt.func @tma_store(%basePtr: !tt.ptr {tt.divisibility = 8 : i32}, + %dim0: i64, %dim1: i64, + %stride0: i64, %stride1: i64, + %coord0: i32, %coord1: i32) { + %src = triton_gpu.alloc_tensor : tensor<64x64xf32, #shared> + %c0 = arith.constant 0 : i32 + %dst = tt.make_tensor_ptr %basePtr, [%dim0, %dim1], [%stride0, %stride1], [%coord0, %coord1] {order = array} : !tt.ptr, 1> + // CHECK: nvgpu.tma_store_tiled %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.ptr, !llvm.ptr, i1, i32, i32 + triton_nvidia_gpu.store_async %dst, %src {cache = 1 : i32} : !tt.ptr, 1>, tensor<64x64xf32, #shared> + tt.return + } +} diff --git a/test/NVGPU/test_cga.mlir b/test/NVGPU/test_cga.mlir new file mode 100644 index 000000000000..0b72d92e1ee0 --- /dev/null +++ b/test/NVGPU/test_cga.mlir @@ -0,0 +1,33 @@ +// RUN: triton-opt %s -split-input-file --convert-nv-gpu-to-llvm | FileCheck %s +#SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 : i32} { + tt.func @test_mbarrier() { + %addr = arith.constant 32 : i32 + %data = arith.constant 123 : i32 + %pred = arith.constant 1 : i1 + %id0 = arith.constant 0 : i32 + %id1 = arith.constant 1 : i32 + // CHECK: llvm.inline_asm + // CHECK: llvm.inline_asm + // CHECK: llvm.inline_asm + nvgpu.cga_barrier_sync + nvgpu.cga_barrier_arrive + nvgpu.cga_barrier_wait + + %ptr = llvm.mlir.null : !llvm.ptr + + // CHECK: llvm.inline_asm + // CHECK: llvm.inline_asm + // CHECK: llvm.inline_asm + // CHECK: llvm.inline_asm + // CHECK: llvm.inline_asm + // CHECK: llvm.mul + // CHECK: llvm.add + // CHECK: llvm.mul + // CHECK: llvm.add + %v = nvgpu.cluster_id + llvm.store %v, %ptr : !llvm.ptr + + tt.return + } +} // end module diff --git a/test/NVGPU/test_mbarrier.mlir b/test/NVGPU/test_mbarrier.mlir new file mode 100644 index 000000000000..b12ea58647c7 --- /dev/null +++ b/test/NVGPU/test_mbarrier.mlir @@ -0,0 +1,19 @@ +// RUN: triton-opt %s -split-input-file --convert-nv-gpu-to-llvm | FileCheck %s +#SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 : i32} { + tt.func @test_mbarrier() { + %mbarrier = llvm.mlir.null : !llvm.ptr + %pred = arith.constant 1 : i1 + // CHECK: llvm.inline_asm + nvgpu.mbarrier_init %mbarrier, %pred { count = 32 : i32 } : !llvm.ptr + // CHECK: llvm.inline_asm + nvgpu.mbarrier_arrive %mbarrier, %pred {arriveType = 1 : i32}: !llvm.ptr + // CHECK: llvm.inline_asm + nvgpu.mbarrier_arrive %mbarrier, %pred {arriveType = 0 : i32}: !llvm.ptr + // CHECK: llvm.inline_asm + nvgpu.mbarrier_arrive %mbarrier, %pred {arriveType = 2 : i32, txCount = 128 : i32}: !llvm.ptr + // CHECK: llvm.inline_asm + nvgpu.mbarrier_wait %mbarrier, %pred : !llvm.ptr, i1 + tt.return + } +} // end module diff --git a/test/NVGPU/test_tma.mlir b/test/NVGPU/test_tma.mlir new file mode 100644 index 000000000000..4cf7f9b5e838 --- /dev/null +++ b/test/NVGPU/test_tma.mlir @@ -0,0 +1,29 @@ +// RUN: triton-opt %s -split-input-file --convert-nv-gpu-to-llvm | FileCheck %s +#SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 : i32} { + tt.func @test_tma(%im2colOffsets0 : !llvm.struct<(i16, i16)>, %im2colOffsets1 : !llvm.struct<(i16, i16, i16)>) { + %mbarrier = llvm.mlir.null : !llvm.ptr + %tmaDesc = llvm.mlir.null : !llvm.ptr + %dst = llvm.mlir.null : !llvm.ptr + %l2desc = arith.constant 0 : i64 + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %c2 = arith.constant 2 : i32 + %c3 = arith.constant 3 : i32 + %c4 = arith.constant 4 : i32 + %pred = arith.constant 1 : i1 + %mask = arith.constant 15 : i16 + + // CHECK: llvm.inline_asm {{.*}} cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint + // CHECK: llvm.inline_asm {{.*}} cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint + nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1 {operand_segment_sizes = array}: !llvm.ptr, !llvm.ptr, !llvm.ptr, i64, i1, i32, i32 + nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1, %c2, %c3 {operand_segment_sizes = array}: !llvm.ptr, !llvm.ptr, !llvm.ptr, i64, i1, i32, i32, i32, i32 + + // CHECK: llvm.inline_asm {{.*}} cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint + // CHECK: llvm.inline_asm {{.*}} cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint + nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1, %mask {operand_segment_sizes = array}: !llvm.ptr, !llvm.ptr, !llvm.ptr, i64, i1, i32, i32, i16 + nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1, %c2, %c3 {operand_segment_sizes = array}: !llvm.ptr, !llvm.ptr, !llvm.ptr, i64, i1, i32, i32, i32, i32 + + tt.return + } +} // end module diff --git a/test/NVGPU/test_wgmma.mlir b/test/NVGPU/test_wgmma.mlir new file mode 100644 index 000000000000..bb4844ab5d18 --- /dev/null +++ b/test/NVGPU/test_wgmma.mlir @@ -0,0 +1,45 @@ +// RUN: triton-opt %s -split-input-file --convert-nv-gpu-to-llvm | FileCheck %s +#SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 : i32} { + tt.func @test_tma(%opC : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>) { + %buffer = llvm.mlir.null : !llvm.ptr + %height = arith.constant 16 : i32 + // CHECK: llvm.ptrtoint + // CHECK: llvm.shl + // CHECK: llvm.lshr + // CHECK: llvm.zext + // CHECK: llvm.mul + // CHECK: llvm.lshr + // CHECK: llvm.shl + // CHECK: llvm.lshr + // CHECK: llvm.shl + // CHECK: llvm.or + // CHECK: llvm.shl + // CHECK: llvm.or + // CHECK: llvm.shl + // CHECK: llvm.or + // CHECK: llvm.or + %descA = nvgpu.wgmma_desc_create %buffer, %height {mode = 2 : i32}: (!llvm.ptr, i32) -> (i64) + // CHECK: llvm.ptrtoint + // CHECK: llvm.shl + // CHECK: llvm.lshr + // CHECK: llvm.zext + // CHECK: llvm.mul + // CHECK: llvm.lshr + // CHECK: llvm.shl + // CHECK: llvm.lshr + // CHECK: llvm.shl + // CHECK: llvm.or + // CHECK: llvm.shl + // CHECK: llvm.or + // CHECK: llvm.shl + // CHECK: llvm.or + // CHECK: llvm.or + %descB = nvgpu.wgmma_desc_create %buffer, %height {mode = 2 : i32}: (!llvm.ptr, i32) -> (i64) + + // CHECK-COUNT-32: llvm.extractvalue + // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, 1, 1, 1, 0, 1;" + %acc0 = nvgpu.wgmma %descA, %descB, %opC {m=64:i32, n=64:i32, k=16:i32, eltTypeC=7:i32, eltTypeA=4:i32, eltTypeB=4:i32, layoutA=0:i32, layoutB=0:i32} : (i64, i64, !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>) -> (!llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>) + tt.return + } +} // end module diff --git a/test/Triton/canonicalize.mlir b/test/Triton/canonicalize.mlir new file mode 100644 index 000000000000..c532b4686be4 --- /dev/null +++ b/test/Triton/canonicalize.mlir @@ -0,0 +1,27 @@ +// RUN: triton-opt %s -split-input-file -canonicalize | FileCheck %s + +// CHECK-LABEL: dead_load +tt.func @dead_load(%ptr: tensor<32x128x!tt.ptr>) { + %mask = arith.constant dense : tensor<32x128xi1> + %other = arith.constant dense<0.00e+00> : tensor<32x128xf16> + // CHECK-NOT: tt.load {{.*}} isVolatile = false + // CHECK: tt.load {{.*}} isVolatile = true + %a = tt.load %ptr, %mask, %other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16> + %b = tt.load %ptr, %mask, %other {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : tensor<32x128xf16> + tt.return +} + + +// CHECK-LABEL: make_range +tt.func @make_range() -> (tensor<128x1xi32>, tensor<1xi32>) { + // CHECK-DAG: %[[c:.*]] = arith.constant dense<0> : tensor<128x1xi32> + %a = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32> + %b = tt.expand_dims %a {axis = 1 : i32} : (tensor<1xi32>) -> tensor<1x1xi32> + %c = tt.broadcast %b : (tensor<1x1xi32>) -> tensor<128x1xi32> + + // CHECK-DAG: %[[d:.*]] = arith.constant dense<1> : tensor<1xi32> + %d = tt.make_range {end = 2 : i32, start = 1 : i32} : tensor<1xi32> + + // CHECK-DAG: tt.return %[[c]], %[[d]] : tensor<128x1xi32>, tensor<1xi32> + tt.return %c, %d : tensor<128x1xi32>, tensor<1xi32> +} diff --git a/test/Triton/reorder-broadcast.mlir b/test/Triton/reorder-broadcast.mlir index 0049793e7e30..fbe44dace289 100644 --- a/test/Triton/reorder-broadcast.mlir +++ b/test/Triton/reorder-broadcast.mlir @@ -13,8 +13,8 @@ tt.func @test_splat_elementwise_pattern(%arg0: f32) -> (tensor<128x128xf32>, ten %add = arith.addf %a, %b : tensor<128x128xf32> - // CHECK-NEXT: %[[ptr:.*]] = tt.int_to_ptr %[[c1]] : i64 -> !tt.ptr - // CHECK-NEXT: %{{.*}} = tt.splat %[[ptr]] : (!tt.ptr) -> tensor<128x128x!tt.ptr> + // CHECK-NEXT: %[[ptr:.*]] = tt.int_to_ptr %[[c1]] : i64 -> !tt.ptr + // CHECK-NEXT: %{{.*}} = tt.splat %[[ptr]] : (!tt.ptr) -> tensor<128x128x!tt.ptr> %c1_t = tt.splat %c1 : (i64) -> tensor<128x128xi64> %ptr = tt.int_to_ptr %c1_t : tensor<128x128xi64> -> tensor<128x128x!tt.ptr> diff --git a/test/Triton/rewrite-tensor-pointer.mlir b/test/Triton/rewrite-tensor-pointer.mlir deleted file mode 100644 index d5fe58a7c9a5..000000000000 --- a/test/Triton/rewrite-tensor-pointer.mlir +++ /dev/null @@ -1,83 +0,0 @@ -// RUN: triton-opt %s -triton-rewrite-tensor-pointer | FileCheck %s -tt.func public @matmul_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) { - %c31_i32 = arith.constant 31 : i32 - %c127_i32 = arith.constant 127 : i32 - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf32> - %c0_i32 = arith.constant 0 : i32 - %c1_i64 = arith.constant 1 : i64 - %c32_i32 = arith.constant 32 : i32 - %c128_i32 = arith.constant 128 : i32 - %c8_i32 = arith.constant 8 : i32 - %0 = tt.get_program_id x : i32 - %1 = tt.get_program_id y : i32 - %2 = arith.addi %arg3, %c127_i32 : i32 - %3 = arith.divsi %2, %c128_i32 : i32 - %4 = arith.addi %arg4, %c31_i32 : i32 - %5 = arith.divsi %4, %c32_i32 : i32 - %6 = arith.muli %5, %c8_i32 : i32 - %7 = arith.divsi %0, %6 : i32 - %8 = arith.muli %7, %c8_i32 : i32 - %9 = arith.subi %3, %8 : i32 - %10 = arith.cmpi slt, %9, %c8_i32 : i32 - %11 = arith.select %10, %9, %c8_i32 : i32 - %12 = arith.remsi %0, %11 : i32 - %13 = arith.addi %8, %12 : i32 - %14 = arith.remsi %0, %6 : i32 - %15 = arith.divsi %14, %11 : i32 - %16 = arith.muli %13, %c128_i32 : i32 - %17 = arith.muli %1, %c32_i32 : i32 - %18 = arith.extsi %arg3 : i32 to i64 - %19 = arith.extsi %arg5 : i32 to i64 - %20 = arith.extsi %arg6 : i32 to i64 - // CHECK-NOT: tt.make_tensor_ptr - %21 = tt.make_tensor_ptr %arg0, [%18, %19], [%20, %c1_i64], [%16, %17] {order = array} : !tt.ptr> - %22 = arith.muli %15, %c32_i32 : i32 - %23 = arith.extsi %arg4 : i32 to i64 - %24 = arith.extsi %arg7 : i32 to i64 - // CHECK-NOT: tt.make_tensor_ptr - %25 = tt.make_tensor_ptr %arg1, [%19, %23], [%24, %c1_i64], [%17, %22] {order = array} : !tt.ptr> - %26 = arith.addi %arg5, %c31_i32 : i32 - %27 = arith.divsi %26, %c32_i32 : i32 - %28 = arith.index_cast %27 : i32 to index - %29:3 = scf.for %arg9 = %c0 to %28 step %c1 iter_args(%arg10 = %cst, %arg11 = %21, %arg12 = %25) -> (tensor<128x32xf32>, !tt.ptr>, !tt.ptr>) { - // CHECK: tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16> - %55 = tt.load %arg11 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 2 : i32} : !tt.ptr> -> tensor<128x32xf16> - // CHECK: tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf16> - %56 = tt.load %arg12 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 2 : i32} : !tt.ptr> -> tensor<32x32xf16> - %57 = tt.dot %55, %56, %arg10 {allowTF32 = true} : tensor<128x32xf16> * tensor<32x32xf16> -> tensor<128x32xf32> - // CHECK-NOT: tt.advance - %58 = tt.advance %arg11, [%c0_i32, %c32_i32] : !tt.ptr> - // CHECK-NOT: tt.advance - %59 = tt.advance %arg12, [%c32_i32, %c0_i32] : !tt.ptr> - scf.yield %57, %58, %59 : tensor<128x32xf32>, !tt.ptr>, !tt.ptr> - } - %30 = arith.truncf %29#0 : tensor<128x32xf32> to tensor<128x32xf16> - %31 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - %32 = tt.splat %16 : (i32) -> tensor<128xi32> - %33 = arith.addi %32, %31 : tensor<128xi32> - %34 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %35 = tt.splat %22 : (i32) -> tensor<32xi32> - %36 = arith.addi %35, %34 : tensor<32xi32> - %37 = tt.expand_dims %33 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32> - %38 = tt.splat %arg8 : (i32) -> tensor<128x1xi32> - %39 = arith.muli %37, %38 : tensor<128x1xi32> - %40 = tt.expand_dims %36 {axis = 0 : i32} : (tensor<32xi32>) -> tensor<1x32xi32> - %41 = tt.broadcast %39 : (tensor<128x1xi32>) -> tensor<128x32xi32> - %42 = tt.broadcast %40 : (tensor<1x32xi32>) -> tensor<128x32xi32> - %43 = arith.addi %41, %42 : tensor<128x32xi32> - %44 = tt.splat %arg2 : (!tt.ptr) -> tensor<128x32x!tt.ptr> - %45 = tt.addptr %44, %43 : tensor<128x32x!tt.ptr>, tensor<128x32xi32> - %46 = tt.splat %arg3 : (i32) -> tensor<128xi32> - %47 = arith.cmpi slt, %33, %46 : tensor<128xi32> - %48 = tt.expand_dims %47 {axis = 1 : i32} : (tensor<128xi1>) -> tensor<128x1xi1> - %49 = tt.splat %arg4 : (i32) -> tensor<32xi32> - %50 = arith.cmpi slt, %36, %49 : tensor<32xi32> - %51 = tt.expand_dims %50 {axis = 0 : i32} : (tensor<32xi1>) -> tensor<1x32xi1> - %52 = tt.broadcast %48 : (tensor<128x1xi1>) -> tensor<128x32xi1> - %53 = tt.broadcast %51 : (tensor<1x32xi1>) -> tensor<128x32xi1> - %54 = arith.andi %52, %53 : tensor<128x32xi1> - tt.store %45, %30, %54 {cache = 1 : i32, evict = 1 : i32} : tensor<128x32xf16> - tt.return -} diff --git a/test/TritonGPU/coalesce.mlir b/test/TritonGPU/coalesce.mlir index 6ce2a17c4830..496fb848e925 100644 --- a/test/TritonGPU/coalesce.mlir +++ b/test/TritonGPU/coalesce.mlir @@ -1,21 +1,20 @@ // RUN: triton-opt %s -split-input-file -tritongpu-coalesce | FileCheck %s -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> #slice1dim1 = #triton_gpu.slice<{dim = 1, parent = #blocked1}> #slice2dim0 = #triton_gpu.slice<{dim = 0, parent = #blocked2}> -module attributes {"triton_gpu.num-warps" = 4 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { - -// CHECK: [[row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}> -// CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -// CHECK: [[load_ptr:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64x!tt.ptr, [[row_layout]]> +// CHECK: [[row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +// CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +// CHECK: [[load_ptr:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64x!tt.ptr, [[row_layout]]> // CHECK: [[load_mask:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xi1, [[row_layout]]> // CHECK: [[load_other:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xf32, [[row_layout]]> // CHECK: [[load_val:%.*]] = tt.load [[load_ptr]], [[load_mask]], [[load_other]] {{.*}} : tensor<64x64xf32, [[row_layout]]> -// CHECK: [[store_ptr:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64x!tt.ptr, [[col_layout]]> +// CHECK: [[store_ptr:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64x!tt.ptr, [[col_layout]]> // CHECK: [[store_val:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xf32, [[col_layout]]> // CHECK: [[store_mask:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xi1, [[col_layout]]> // CHECK: tt.store [[store_ptr]], [[store_val]], [[store_mask]] @@ -51,3 +50,22 @@ tt.func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, } } + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32} { + +// CHECK: [[NEW_LOADED_LAYOUT:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @load_tensor(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i64 + %0 = arith.extsi %arg1 : i32 to i64 + %1 = arith.extsi %arg2 : i32 to i64 + %2 = tt.make_tensor_ptr %arg0, [%0, %1], [%1, %c1], [%c0, %c0] { order = array } : !tt.ptr, 1> + // CHECK: !tt.ptr, 1> -> tensor<32x32xf32, [[NEW_LOADED_LAYOUT]]> + %3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<32x32xf32, #blocked> + tt.return +} + +} diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index e123841e42cc..ec6925b373cb 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -1,16 +1,12 @@ // RUN: triton-opt %s -split-input-file -tritongpu-remove-layout-conversions 2>&1 | FileCheck %s -#layout0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#layout1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#layout2 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1]}> +#layout0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#layout1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -// CHECK: [[$target_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -// CHECK: [[$row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}> -// CHECK: [[$col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}> -// CHECK: [[$col_layout_novec:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> -// CHECK-LABEL: cst module attributes {"triton_gpu.num-warps" = 4 : i32} { +// CHECK: [[$target_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +// CHECK-LABEL: cst tt.func @cst() -> tensor<1024xi32, #layout1> { %cst = arith.constant dense<0> : tensor<1024xi32, #layout0> %1 = triton_gpu.convert_layout %cst : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1> @@ -81,6 +77,36 @@ tt.func @remat_fast_load(%arg: !tt.ptr {tt.divisibility = 16 : i32}) { tt.return } +// Hoist the convert on top of ext to make it cheaper. +// CHECK-LABEL: hoist_above_ext +tt.func @hoist_above_ext(%arg0: tensor<1024xf16, #layout0>, %arg1: f32) -> tensor<1024xf32, #layout1> { +// CHECK: %[[CVT:.+]] = triton_gpu.convert_layout +// CHECK: arith.extf %[[CVT]] +// CHECK-NOT: triton_gpu.convert_layout +// CHECK: tt.return + %0 = arith.extf %arg0 : tensor<1024xf16, #layout0> to tensor<1024xf32, #layout0> + %1 = tt.splat %arg1 : (f32) -> tensor<1024xf32, #layout0> + %2 = arith.addf %0, %1 : tensor<1024xf32, #layout0> + %3 = triton_gpu.convert_layout %2 : (tensor<1024xf32, #layout0>) -> tensor<1024xf32, #layout1> + tt.return %3 : tensor<1024xf32, #layout1> +} + +// CHECK-LABEL: hoist_above_ext2 +tt.func @hoist_above_ext2(%arg0: tensor<1024xf16, #layout0>, %arg1: f16) -> tensor<1024xf32, #layout1> { +// CHECK: %[[CVT:.+]] = triton_gpu.convert_layout +// CHECK: arith.extf %[[CVT]] +// CHECK-NOT: triton_gpu.convert_layout +// CHECK: tt.return + %0 = arith.extf %arg0 : tensor<1024xf16, #layout0> to tensor<1024xf32, #layout0> + %1 = tt.splat %arg1 : (f16) -> tensor<1024xf16, #layout0> + %2 = arith.extf %1 : tensor<1024xf16, #layout0> to tensor<1024xf32, #layout0> + %3 = arith.addf %0, %2 : tensor<1024xf32, #layout0> + %4 = triton_gpu.convert_layout %3 : (tensor<1024xf32, #layout0>) -> tensor<1024xf32, #layout1> + tt.return %4 : tensor<1024xf32, #layout1> +} + + + // CHECK-LABEL: if tt.func @if(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { // CHECK-NOT: triton_gpu.convert_layout @@ -165,13 +191,20 @@ tt.func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = } -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +// ----- + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> #slice1dim1 = #triton_gpu.slice<{dim = 1, parent = #blocked1}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> #slice2dim0 = #triton_gpu.slice<{dim = 0, parent = #blocked2}> -#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}> -#blocked4 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked4 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked5 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> + +// CHECK: [[$row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +// CHECK: [[$col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +// CHECK: [[$col_layout_novec:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> // CHECK-LABEL: transpose module attributes {"triton_gpu.num-warps" = 4 : i32} { @@ -220,14 +253,16 @@ tt.func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 module attributes {"triton_gpu.num-warps" = 4 : i32} { tt.func @loop(%arg0: !tt.ptr, %arg1: i32, %arg2: !tt.ptr, %arg3: i32, %arg4: i32) { // CHECK-NOT: triton_gpu.convert_layout - // CHECK: [[loop_ret:%.*]]:2 = scf.for {{.*}} -> (tensor<64x64xf32, [[$row_layout]]>, tensor<64x64x!tt.ptr, [[$row_layout]]>) + // CHECK: [[loop_ret:%.*]]:2 = scf.for {{.*}} -> (tensor<64x64xf32, [[$row_layout]]>, tensor<64x64x!tt.ptr, [[$row_layout]]>) // CHECK-NEXT: {{.*}} = tt.load {{.*}} : tensor<64x64xf32, [[$row_layout]]> // CHECK-NEXT: {{.*}} = arith.addf {{.*}} : tensor<64x64xf32, [[$row_layout]]> - // CHECK-NEXT: {{.*}} = tt.addptr {{.*}} : tensor<64x64x!tt.ptr, [[$row_layout]]>, tensor<64x64xi32, [[$row_layout]]> - // CHECK-NEXT: scf.yield {{.*}} : tensor<64x64xf32, [[$row_layout]]>, tensor<64x64x!tt.ptr, [[$row_layout]]> + // CHECK-NEXT: {{.*}} = tt.addptr {{.*}} : tensor<64x64x!tt.ptr, [[$row_layout]]>, tensor<64x64xi32, [[$row_layout]]> + // CHECK-NEXT: scf.yield {{.*}} : tensor<64x64xf32, [[$row_layout]]>, tensor<64x64x!tt.ptr, [[$row_layout]]> // CHECK-NEXT: } - // CHECK-NEXT: {{.*}} = triton_gpu.convert_layout [[loop_ret]]#0 : (tensor<64x64xf32, [[$row_layout]]>) -> tensor<64x64xf32, [[$col_layout_novec]]> // CHECK-NOT: triton_gpu.convert_layout + // CHECK: {{.*}} = triton_gpu.convert_layout [[loop_ret]]#0 : (tensor<64x64xf32, [[$row_layout]]>) -> tensor<64x64xf32, [[$col_layout_novec]]> + // CHECK-NOT: triton_gpu.convert_layout + // CHECK: tt.return %cst = arith.constant dense : tensor<64x64xi1, #blocked1> %cst_0 = arith.constant dense<64> : tensor<64x64xi32, #blocked1> %c1 = arith.constant 1 : index @@ -273,6 +308,19 @@ tt.func @loop(%arg0: !tt.ptr, %arg1: i32, %arg2: !tt.ptr, %arg3: i32, } // CHECK-LABEL: loop_if +// CHECK-NOT: triton_gpu.convert_layout +// CHECK: scf.for +// CHECK-NOT: triton_gpu.convert_layout +// CHECK: scf.if +// CHECK-NOT: triton_gpu.convert_layout +// CHECK: scf.yield +// CHECK: else +// CHECK: scf.yield +// CHECK-NOT: triton_gpu.convert_layout +// CHECK: scf.yield +// CHECK: triton_gpu.convert_layout +// CHECK-NOT: triton_gpu.convert_layout +// CHECK: tt.store module attributes {"triton_gpu.num-warps" = 4 : i32} { tt.func @loop_if(%arg0: !tt.ptr, %arg1: i32, %arg2: !tt.ptr, %arg3: i32, %arg4: i32) { %cst = arith.constant dense : tensor<64x64xi1, #blocked1> @@ -333,28 +381,28 @@ tt.func @vecadd(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr %c256_i32 = arith.constant 256 : i32 %0 = tt.get_program_id x : i32 %1 = arith.muli %0, %c256_i32 : i32 - %2 = tt.splat %1 : (i32) -> tensor<256xi32, #layout1> - %3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #layout1> - %4 = tt.splat %1 : (i32) -> tensor<256xi32, #layout1> - %5 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #layout1> - %6 = tt.splat %1 : (i32) -> tensor<256xi32, #layout1> - %7 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #layout1> - %8 = tt.splat %arg0 : (!tt.ptr) -> tensor<256x!tt.ptr, #layout1> - %9 = arith.addi %6, %7 : tensor<256xi32, #layout1> - %10 = tt.splat %arg1 : (!tt.ptr) -> tensor<256x!tt.ptr, #layout1> - %11 = arith.addi %4, %5 : tensor<256xi32, #layout1> - %12 = tt.addptr %8, %9 : tensor<256x!tt.ptr, #layout1>, tensor<256xi32, #layout1> - %13 = tt.load %12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #layout1> - %14 = triton_gpu.convert_layout %13 : (tensor<256xf32, #layout1>) -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>> - %15 = tt.addptr %10, %11 : tensor<256x!tt.ptr, #layout1>, tensor<256xi32, #layout1> - %16 = tt.load %15 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #layout1> - %17 = triton_gpu.convert_layout %16 : (tensor<256xf32, #layout1>) -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>> - %18 = arith.addf %14, %17 : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>> - %19 = tt.splat %arg2 : (!tt.ptr) -> tensor<256x!tt.ptr, #layout1> - %20 = arith.addi %2, %3 : tensor<256xi32, #layout1> - %21 = tt.addptr %19, %20 : tensor<256x!tt.ptr, #layout1>, tensor<256xi32, #layout1> - %22 = triton_gpu.convert_layout %18 : (tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>) -> tensor<256xf32, #layout1> - tt.store %21, %22 : tensor<256xf32, #layout1> + %2 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked5> + %3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked5> + %4 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked5> + %5 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked5> + %6 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked5> + %7 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked5> + %8 = tt.splat %arg0 : (!tt.ptr) -> tensor<256x!tt.ptr, #blocked5> + %9 = arith.addi %6, %7 : tensor<256xi32, #blocked5> + %10 = tt.splat %arg1 : (!tt.ptr) -> tensor<256x!tt.ptr, #blocked5> + %11 = arith.addi %4, %5 : tensor<256xi32, #blocked5> + %12 = tt.addptr %8, %9 : tensor<256x!tt.ptr, #blocked5>, tensor<256xi32, #blocked5> + %13 = tt.load %12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked5> + %14 = triton_gpu.convert_layout %13 : (tensor<256xf32, #blocked5>) -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>> + %15 = tt.addptr %10, %11 : tensor<256x!tt.ptr, #blocked5>, tensor<256xi32, #blocked5> + %16 = tt.load %15 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked5> + %17 = triton_gpu.convert_layout %16 : (tensor<256xf32, #blocked5>) -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>> + %18 = arith.addf %14, %17 : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>> + %19 = tt.splat %arg2 : (!tt.ptr) -> tensor<256x!tt.ptr, #blocked5> + %20 = arith.addi %2, %3 : tensor<256xi32, #blocked5> + %21 = tt.addptr %19, %20 : tensor<256x!tt.ptr, #blocked5>, tensor<256xi32, #blocked5> + %22 = triton_gpu.convert_layout %18 : (tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>>) -> tensor<256xf32, #blocked5> + tt.store %21, %22 : tensor<256xf32, #blocked5> tt.return } } @@ -902,14 +950,14 @@ tt.func public @mnist(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: ! // ----- +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked4 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked5 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> // cmpf and cmpi have different operands and result types // CHECK-LABEL: cmp -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 4], order = [0, 1]}> -#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [0, 1]}> -#blocked4 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked5 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> module attributes {"triton_gpu.num-warps" = 4 : i32} { tt.func public @cmp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) { %c64 = arith.constant 64 : index @@ -1087,10 +1135,10 @@ tt.func public @if_no_tensor(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, % // Check if the SimplifyReduceCvt rewriter pattern doesn't hang. // CHECK-LABEL: reduce_cvt // CHECK-NOT: triton_gpu.convert_layout -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 2], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 1], order = [0, 1]}> -#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [2, 1], order = [1, 0]}> +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [2, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> module attributes {"triton_gpu.num-warps" = 2 : i32} { tt.func public @reduce_cvt1(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32, %arg2: i32) { %cst = arith.constant dense<0> : tensor<1x2xi32, #blocked> @@ -1122,18 +1170,18 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} { // ----- -// Check if SimplifyReduceCvt handles convert_layout lifted from the for loop. // CHECK-LABEL: reduce_cvt2 // Match the reduction // CHECK: tt.reduce // CHECK-SAME: axis = 1 // CHECK: (tensor<1x256xf32, #blocked>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -// CHECK-NEXT: triton_gpu.convert_layout +// CHECK: triton_gpu.convert_layout // CHECK-NOT: triton_gpu.convert_layout -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> -#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +// CHECK: tt.return +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> module attributes {"triton_gpu.num-warps" = 4 : i32} { tt.func public @reduce_cvt2(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}) { %cst = arith.constant dense<0.000000e+00> : tensor<1x256xf32, #blocked> @@ -1344,6 +1392,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // Check if MoveConvertOutOfLoop hangs because of adding additional conversions // CHECK-LABEL: loop_print // CHECK-NOT: triton_gpu.convert_layout +// CHECK: tt.return #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> #blocked2 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> @@ -1499,3 +1548,291 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war tt.return } } + + +// ----- + +// Check that we don't have extra convert for flash attention IR. +#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked4 = #triton_gpu.blocked<{sizePerThread = [1, 1, 8], threadsPerWarp = [4, 1, 8], warpsPerCTA = [4, 1, 1], order = [1, 2, 0], CTAsPerCGA = [1, 1, 1], CTASplitNum = [1, 1, 1], CTAOrder = [1, 0, 2]}> +#blocked5 = #triton_gpu.blocked<{sizePerThread = [1, 1, 8], threadsPerWarp = [1, 4, 8], warpsPerCTA = [1, 4, 1], order = [0, 2, 1], CTAsPerCGA = [1, 1, 1], CTASplitNum = [1, 1, 1], CTAOrder = [0, 1, 2]}> +#blocked6 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked7 = #triton_gpu.blocked<{sizePerThread = [8, 1, 1], threadsPerWarp = [8, 1, 4], warpsPerCTA = [1, 1, 4], order = [1, 0, 2], CTAsPerCGA = [1, 1, 1], CTASplitNum = [1, 1, 1], CTAOrder = [1, 0, 2]}> +#blocked8 = #triton_gpu.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 8, 4], warpsPerCTA = [1, 1, 4], order = [0, 1, 2], CTAsPerCGA = [1, 1, 1], CTASplitNum = [1, 1, 1], CTAOrder = [0, 1, 2]}> +#blocked9 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @attention_fw(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg7: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg8: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg9: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg10: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg11: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg12: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg13: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg14: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg15: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg16: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg17: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg18: i32, %arg19: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg20: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg21: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { + %c0_i64 = arith.constant 0 : i64 + %c64_i64 = arith.constant 64 : i64 + %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked> + %cst_0 = arith.constant dense<0xFF800000> : tensor<128xf32, #blocked1> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128xf32, #blocked1> + %c64_i32 = arith.constant 64 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #blocked2> + %cst_3 = arith.constant 1.44269502 : f32 + %c128_i32 = arith.constant 128 : i32 + %0 = tt.get_program_id x : i32 + %1 = tt.get_program_id y : i32 + %2 = arith.muli %1, %arg7 : i32 + %3 = arith.muli %1, %arg10 : i32 + %4 = tt.addptr %arg0, %2 : !tt.ptr, i32 + %5 = arith.muli %0, %c128_i32 : i32 + %6 = arith.extsi %arg8 : i32 to i64 + %7 = arith.extsi %5 : i32 to i64 + %8 = tt.addptr %arg1, %3 : !tt.ptr, i32 + %9 = arith.addi %arg20, %arg21 : i32 + %10 = arith.extsi %arg11 : i32 to i64 + %11 = tt.addptr %arg2, %3 : !tt.ptr, i32 + %12 = arith.extsi %arg14 : i32 to i64 + %13 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked1> + %14 = tt.splat %5 : (i32) -> tensor<128xi32, #blocked1> + %15 = arith.addi %14, %13 : tensor<128xi32, #blocked1> + %16 = arith.mulf %arg3, %cst_3 : f32 + %17 = tt.splat %4 : (!tt.ptr) -> tensor<128x64x!tt.ptr, #blocked3> + %18 = tt.splat %7 : (i64) -> tensor<128xi64, #blocked3> + %19 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked3> + %20 = arith.extsi %19 : tensor<128xi32, #blocked3> to tensor<128xi64, #blocked3> + %21 = arith.addi %18, %20 : tensor<128xi64, #blocked3> + %22 = triton_gpu.convert_layout %21 : (tensor<128xi64, #blocked3>) -> tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %23 = tt.expand_dims %22 {axis = 1 : i32} : (tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked4}>>) -> tensor<128x1xi64, #blocked4> + %24 = tt.splat %6 : (i64) -> tensor<128x1xi64, #blocked4> + %25 = arith.muli %23, %24 : tensor<128x1xi64, #blocked4> + %26 = tt.broadcast %25 : (tensor<128x1xi64, #blocked4>) -> tensor<128x64xi64, #blocked4> + %27 = triton_gpu.convert_layout %26 : (tensor<128x64xi64, #blocked4>) -> tensor<128x64xi64, #blocked3> + %28 = tt.addptr %17, %27 : tensor<128x64x!tt.ptr, #blocked3>, tensor<128x64xi64, #blocked3> + %29 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked3> + %30 = arith.extsi %29 : tensor<64xi32, #blocked3> to tensor<64xi64, #blocked3> + %31 = triton_gpu.convert_layout %30 : (tensor<64xi64, #blocked3>) -> tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked5}>> + %32 = tt.expand_dims %31 {axis = 0 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked5}>>) -> tensor<1x64xi64, #blocked5> + %33 = tt.broadcast %32 : (tensor<1x64xi64, #blocked5>) -> tensor<128x64xi64, #blocked5> + %34 = triton_gpu.convert_layout %33 : (tensor<128x64xi64, #blocked5>) -> tensor<128x64xi64, #blocked3> + %35 = tt.addptr %28, %34 : tensor<128x64x!tt.ptr, #blocked3>, tensor<128x64xi64, #blocked3> + %36 = tt.load %35 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked3> + %37 = triton_gpu.convert_layout %36 : (tensor<128x64xf16, #blocked3>) -> tensor<128x64xf16, #blocked2> + %38 = tt.splat %16 : (f32) -> tensor<128x64xf32, #blocked2> + %39 = arith.extf %37 : tensor<128x64xf16, #blocked2> to tensor<128x64xf32, #blocked2> + %40 = arith.mulf %39, %38 : tensor<128x64xf32, #blocked2> + %41 = arith.truncf %40 : tensor<128x64xf32, #blocked2> to tensor<128x64xf16, #blocked2> +// CHECK-NOT: triton_gpu.convert_layout +// CHECK: scf.for +// CHECK-NOT: triton_gpu.convert_layout +// CHECK: triton_gpu.convert_layout %{{.*}} #triton_gpu.dot_op +// CHECK: triton_gpu.convert_layout %{{.*}} #triton_gpu.dot_op +// CHECK-NOT: triton_gpu.convert_layout +// CHECK: tt.dot +// CHECK-NOT: triton_gpu.convert_layout +// CHECK: triton_gpu.convert_layout %{{.*}} #triton_gpu.dot_op +// CHECK: triton_gpu.convert_layout %{{.*}} #triton_gpu.dot_op +// CHECK-NOT: triton_gpu.convert_layout +// CHECK: tt.dot +// CHECK: scf.yield + %42:5 = scf.for %arg22 = %c0_i32 to %9 step %c64_i32 iter_args(%arg23 = %cst_2, %arg24 = %cst_1, %arg25 = %cst_0, %arg26 = %c0_i64, %arg27 = %c0_i64) -> (tensor<128x64xf32, #blocked2>, tensor<128xf32, #blocked1>, tensor<128xf32, #blocked1>, i64, i64) : i32 { + %78 = tt.splat %8 : (!tt.ptr) -> tensor<64x64x!tt.ptr, #blocked6> + %79 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked6> + %80 = arith.extsi %79 : tensor<64xi32, #blocked6> to tensor<64xi64, #blocked6> + %81 = triton_gpu.convert_layout %80 : (tensor<64xi64, #blocked6>) -> tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked7}>> + %82 = tt.expand_dims %81 {axis = 1 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked7}>>) -> tensor<64x1xi64, #blocked7> + %83 = tt.broadcast %82 : (tensor<64x1xi64, #blocked7>) -> tensor<64x64xi64, #blocked7> + %84 = triton_gpu.convert_layout %83 : (tensor<64x64xi64, #blocked7>) -> tensor<64x64xi64, #blocked6> + %85 = tt.addptr %78, %84 : tensor<64x64x!tt.ptr, #blocked6>, tensor<64x64xi64, #blocked6> + %86 = tt.splat %arg26 : (i64) -> tensor<64xi64, #blocked6> + %87 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked6> + %88 = arith.extsi %87 : tensor<64xi32, #blocked6> to tensor<64xi64, #blocked6> + %89 = arith.addi %86, %88 : tensor<64xi64, #blocked6> + %90 = triton_gpu.convert_layout %89 : (tensor<64xi64, #blocked6>) -> tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked8}>> + %91 = tt.expand_dims %90 {axis = 0 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked8}>>) -> tensor<1x64xi64, #blocked8> + %92 = tt.splat %10 : (i64) -> tensor<1x64xi64, #blocked8> + %93 = arith.muli %91, %92 : tensor<1x64xi64, #blocked8> + %94 = tt.broadcast %93 : (tensor<1x64xi64, #blocked8>) -> tensor<64x64xi64, #blocked8> + %95 = triton_gpu.convert_layout %94 : (tensor<64x64xi64, #blocked8>) -> tensor<64x64xi64, #blocked6> + %96 = tt.addptr %85, %95 : tensor<64x64x!tt.ptr, #blocked6>, tensor<64x64xi64, #blocked6> + %97 = tt.load %96 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf16, #blocked6> + %98 = tt.splat %11 : (!tt.ptr) -> tensor<64x64x!tt.ptr, #blocked3> + %99 = tt.splat %arg27 : (i64) -> tensor<64xi64, #blocked3> + %100 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked3> + %101 = arith.extsi %100 : tensor<64xi32, #blocked3> to tensor<64xi64, #blocked3> + %102 = arith.addi %99, %101 : tensor<64xi64, #blocked3> + %103 = triton_gpu.convert_layout %102 : (tensor<64xi64, #blocked3>) -> tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %104 = tt.expand_dims %103 {axis = 1 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked4}>>) -> tensor<64x1xi64, #blocked4> + %105 = tt.splat %12 : (i64) -> tensor<64x1xi64, #blocked4> + %106 = arith.muli %104, %105 : tensor<64x1xi64, #blocked4> + %107 = tt.broadcast %106 : (tensor<64x1xi64, #blocked4>) -> tensor<64x64xi64, #blocked4> + %108 = triton_gpu.convert_layout %107 : (tensor<64x64xi64, #blocked4>) -> tensor<64x64xi64, #blocked3> + %109 = tt.addptr %98, %108 : tensor<64x64x!tt.ptr, #blocked3>, tensor<64x64xi64, #blocked3> + %110 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked3> + %111 = arith.extsi %110 : tensor<64xi32, #blocked3> to tensor<64xi64, #blocked3> + %112 = triton_gpu.convert_layout %111 : (tensor<64xi64, #blocked3>) -> tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked5}>> + %113 = tt.expand_dims %112 {axis = 0 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked5}>>) -> tensor<1x64xi64, #blocked5> + %114 = tt.broadcast %113 : (tensor<1x64xi64, #blocked5>) -> tensor<64x64xi64, #blocked5> + %115 = triton_gpu.convert_layout %114 : (tensor<64x64xi64, #blocked5>) -> tensor<64x64xi64, #blocked3> + %116 = tt.addptr %109, %115 : tensor<64x64x!tt.ptr, #blocked3>, tensor<64x64xi64, #blocked3> + %117 = tt.load %116 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf16, #blocked3> + %118 = triton_gpu.convert_layout %41 : (tensor<128x64xf16, #blocked2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + %119 = triton_gpu.convert_layout %97 : (tensor<64x64xf16, #blocked6>) -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> + %120 = tt.dot %118, %119, %cst {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x64xf16, #blocked> + %121 = triton_gpu.convert_layout %120 : (tensor<128x64xf16, #blocked>) -> tensor<128x64xf16, #blocked2> + %122 = arith.extf %121 : tensor<128x64xf16, #blocked2> to tensor<128x64xf32, #blocked2> + %123 = "tt.reduce"(%122) <{axis = 1 : i32}> ({ + ^bb0(%arg28: f32, %arg29: f32): + %153 = arith.maxf %arg28, %arg29 : f32 + tt.reduce.return %153 : f32 + }) : (tensor<128x64xf32, #blocked2>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %124 = triton_gpu.convert_layout %123 : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128xf32, #blocked1> + %125 = arith.maxf %arg25, %124 : tensor<128xf32, #blocked1> + %126 = arith.subf %arg25, %125 : tensor<128xf32, #blocked1> + %127 = tt.extern_elementwise %126 {pure = true, libname = "libdevice", libpath = "/root/.pyenv/versions/3.9.9/lib/python3.9/site-packages/triton/language/../third_party/cuda/lib/libdevice.10.bc", symbol = "__nv_exp2f"} : (tensor<128xf32, #blocked1>) -> tensor<128xf32, #blocked1> + %128 = triton_gpu.convert_layout %125 : (tensor<128xf32, #blocked1>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> + %129 = tt.expand_dims %128 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>>) -> tensor<128x1xf32, #blocked9> + %130 = triton_gpu.convert_layout %129 : (tensor<128x1xf32, #blocked9>) -> tensor<128x1xf32, #blocked2> + %131 = tt.broadcast %130 : (tensor<128x1xf32, #blocked2>) -> tensor<128x64xf32, #blocked2> + %132 = arith.subf %122, %131 : tensor<128x64xf32, #blocked2> + %133 = tt.extern_elementwise %132 {pure = true, libname = "libdevice", libpath = "/root/.pyenv/versions/3.9.9/lib/python3.9/site-packages/triton/language/../third_party/cuda/lib/libdevice.10.bc", symbol = "__nv_exp2f"} : (tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #blocked2> + %134 = arith.mulf %arg24, %cst_1 : tensor<128xf32, #blocked1> + %135 = arith.addf %134, %127 : tensor<128xf32, #blocked1> + %136 = triton_gpu.convert_layout %135 : (tensor<128xf32, #blocked1>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> + %137 = tt.expand_dims %136 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>>) -> tensor<128x1xf32, #blocked9> + %138 = triton_gpu.convert_layout %137 : (tensor<128x1xf32, #blocked9>) -> tensor<128x1xf32, #blocked2> + %139 = tt.broadcast %138 : (tensor<128x1xf32, #blocked2>) -> tensor<128x64xf32, #blocked2> + %140 = arith.mulf %arg23, %139 : tensor<128x64xf32, #blocked2> + %141 = arith.truncf %133 : tensor<128x64xf32, #blocked2> to tensor<128x64xf16, #blocked2> + %142 = triton_gpu.convert_layout %141 : (tensor<128x64xf16, #blocked2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + %143 = triton_gpu.convert_layout %117 : (tensor<64x64xf16, #blocked3>) -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> + %144 = triton_gpu.convert_layout %140 : (tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #blocked> + %145 = tt.dot %142, %143, %144 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x64xf32, #blocked> + %146 = triton_gpu.convert_layout %145 : (tensor<128x64xf32, #blocked>) -> tensor<128x64xf32, #blocked2> + %147 = arith.mulf %arg24, %127 : tensor<128xf32, #blocked1> + %148 = "tt.reduce"(%133) <{axis = 1 : i32}> ({ + ^bb0(%arg28: f32, %arg29: f32): + %153 = arith.addf %arg28, %arg29 : f32 + tt.reduce.return %153 : f32 + }) : (tensor<128x64xf32, #blocked2>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %149 = triton_gpu.convert_layout %148 : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128xf32, #blocked1> + %150 = arith.addf %147, %149 : tensor<128xf32, #blocked1> + %151 = arith.addi %arg26, %c64_i64 : i64 + %152 = arith.addi %arg27, %c64_i64 : i64 + scf.yield %146, %150, %125, %151, %152 : tensor<128x64xf32, #blocked2>, tensor<128xf32, #blocked1>, tensor<128xf32, #blocked1>, i64, i64 + } + %43 = triton_gpu.convert_layout %42#1 : (tensor<128xf32, #blocked1>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> + %44 = tt.expand_dims %43 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>>) -> tensor<128x1xf32, #blocked9> + %45 = triton_gpu.convert_layout %44 : (tensor<128x1xf32, #blocked9>) -> tensor<128x1xf32, #blocked2> + %46 = tt.broadcast %45 : (tensor<128x1xf32, #blocked2>) -> tensor<128x64xf32, #blocked2> + %47 = arith.divf %42#0, %46 : tensor<128x64xf32, #blocked2> + %48 = arith.muli %1, %arg20 : i32 + %49 = tt.addptr %arg4, %48 : !tt.ptr, i32 + %50 = tt.splat %49 : (!tt.ptr) -> tensor<128x!tt.ptr, #blocked1> + %51 = tt.addptr %50, %15 : tensor<128x!tt.ptr, #blocked1>, tensor<128xi32, #blocked1> + %52 = tt.extern_elementwise %42#1 {pure = true, libname = "libdevice", libpath = "/root/.pyenv/versions/3.9.9/lib/python3.9/site-packages/triton/language/../third_party/cuda/lib/libdevice.10.bc", symbol = "__nv_log2f"} : (tensor<128xf32, #blocked1>) -> tensor<128xf32, #blocked1> + %53 = arith.addf %42#2, %52 : tensor<128xf32, #blocked1> + tt.store %51, %53 {cache = 1 : i32, evict = 1 : i32} : tensor<128xf32, #blocked1> + %54 = tt.addptr %arg5, %2 : !tt.ptr, i32 + %55 = arith.extsi %arg17 : i32 to i64 + %56 = arith.extsi %5 : i32 to i64 + %57 = arith.truncf %47 : tensor<128x64xf32, #blocked2> to tensor<128x64xf16, #blocked2> + %58 = triton_gpu.convert_layout %57 : (tensor<128x64xf16, #blocked2>) -> tensor<128x64xf16, #blocked3> + %59 = tt.splat %54 : (!tt.ptr) -> tensor<128x64x!tt.ptr, #blocked3> + %60 = tt.splat %56 : (i64) -> tensor<128xi64, #blocked3> + %61 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked3> + %62 = arith.extsi %61 : tensor<128xi32, #blocked3> to tensor<128xi64, #blocked3> + %63 = arith.addi %60, %62 : tensor<128xi64, #blocked3> + %64 = triton_gpu.convert_layout %63 : (tensor<128xi64, #blocked3>) -> tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %65 = tt.expand_dims %64 {axis = 1 : i32} : (tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked4}>>) -> tensor<128x1xi64, #blocked4> + %66 = tt.splat %55 : (i64) -> tensor<128x1xi64, #blocked4> + %67 = arith.muli %65, %66 : tensor<128x1xi64, #blocked4> + %68 = tt.broadcast %67 : (tensor<128x1xi64, #blocked4>) -> tensor<128x64xi64, #blocked4> + %69 = triton_gpu.convert_layout %68 : (tensor<128x64xi64, #blocked4>) -> tensor<128x64xi64, #blocked3> + %70 = tt.addptr %59, %69 : tensor<128x64x!tt.ptr, #blocked3>, tensor<128x64xi64, #blocked3> + %71 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked3> + %72 = arith.extsi %71 : tensor<64xi32, #blocked3> to tensor<64xi64, #blocked3> + %73 = triton_gpu.convert_layout %72 : (tensor<64xi64, #blocked3>) -> tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked5}>> + %74 = tt.expand_dims %73 {axis = 0 : i32} : (tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked5}>>) -> tensor<1x64xi64, #blocked5> + %75 = tt.broadcast %74 : (tensor<1x64xi64, #blocked5>) -> tensor<128x64xi64, #blocked5> + %76 = triton_gpu.convert_layout %75 : (tensor<128x64xi64, #blocked5>) -> tensor<128x64xi64, #blocked3> + %77 = tt.addptr %70, %76 : tensor<128x64x!tt.ptr, #blocked3>, tensor<128x64xi64, #blocked3> + tt.store %77, %58 {cache = 1 : i32, evict = 1 : i32} : tensor<128x64xf16, #blocked3> + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +// CHECK-LABEL: axis_mismatch +tt.func @axis_mismatch(%arg0: f32) -> tensor<1xf32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> { +// CHECK: %[[R:.+]] = "tt.reduce"(%0) <{axis = 1 : i32}> +// CHECK: %[[C:.+]] = triton_gpu.convert_layout %[[R]] +// CHECK: tt.return %[[C]] + %0 = tt.splat %arg0 : (f32) -> tensor<1x16xf32, #blocked> + %1 = "tt.reduce"(%0) <{axis = 1 : i32}> ({ + ^bb0(%arg9: f32, %arg10: f32): + %60 = arith.addf %arg9, %arg10 : f32 + tt.reduce.return %60 : f32 + }) : (tensor<1x16xf32, #blocked>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %2 = triton_gpu.convert_layout %1 : (tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<1xf32, #blocked1> + %3 = triton_gpu.convert_layout %2 : (tensor<1xf32, #blocked1>) -> tensor<1xf32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + tt.return %3: tensor<1xf32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-warps" = 4 : i32} { +// CHECK-LABEL: reduce_to_scalar +// CHECK-NOT: triton_gpu.convert_layout +// CHECK: tt.return +tt.func @reduce_to_scalar(%ptr: tensor<1024x!tt.ptr, #blocked>) -> (f32, i32) { + %0 = tt.load %ptr {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked> + %1 = triton_gpu.convert_layout %0 : (tensor<1024xf32, #blocked>) -> tensor<1024xf32, #blocked1> + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked1> + %3:2 = "tt.reduce"(%1, %2) <{axis = 0 : i32}> ({ + ^bb0(%arg7: f32, %arg8: i32, %arg9: f32, %arg10: i32): + %51 = "triton_gpu.cmpf"(%arg7, %arg9) <{predicate = 1 : i64}> : (f32, f32) -> i1 + %52 = "triton_gpu.cmpi"(%arg8, %arg10) <{predicate = 2 : i64}> : (i32, i32) -> i1 + %53 = arith.andi %51, %52 : i1 + %54 = "triton_gpu.cmpf"(%arg7, %arg9) <{predicate = 2 : i64}> : (f32, f32) -> i1 + %55 = arith.ori %54, %53 : i1 + %56 = arith.select %55, %arg7, %arg9 : f32 + %57 = arith.select %55, %arg8, %arg10 : i32 + tt.reduce.return %56, %57 : f32, i32 + }) : (tensor<1024xf32, #blocked1>, tensor<1024xi32, #blocked1>) -> (f32, i32) + tt.return %3#0, %3#1: f32, i32 +} +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-warps" = 4 : i32} { +// CHECK-LABEL: whileop +// CHECK: %[[L:.+]] = tt.load %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked> +// CHECK: %[[W:.+]] = scf.while (%[[I:.+]] = %[[L]], %{{.*}} = %{{.*}}) : (tensor<1024xf32, #blocked>, i1) -> tensor<1024xf32, #blocked> { +// CHECK: scf.condition(%{{.*}}) %[[I]] : tensor<1024xf32, #blocked> +// CHECK: } do { +// CHECK: ^bb0(%[[ARG1:.+]]: tensor<1024xf32, #blocked>): +// CHECK: %[[ADD:.+]] = arith.addf %[[ARG1]], %[[ARG1]] : tensor<1024xf32, #blocked> +// CHECK: scf.yield %[[ADD]], %{{.*}} : tensor<1024xf32, #blocked>, i1 +// CHECK: } +// CHECK: tt.store %{{.*}}, %[[W]] {cache = 1 : i32, evict = 1 : i32} : tensor<1024xf32, #blocked> +tt.func @whileop(%ptr: tensor<1024x!tt.ptr, #blocked>, %cond: i1) { + %0 = tt.load %ptr {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked> + %1 = triton_gpu.convert_layout %0 : (tensor<1024xf32, #blocked>) -> tensor<1024xf32, #blocked1> + %2 = scf.while (%arg0 = %1, %arg1 = %cond) : (tensor<1024xf32, #blocked1>, i1) -> (tensor<1024xf32, #blocked1>) { + scf.condition(%arg1) %arg0 : tensor<1024xf32, #blocked1> + } do { + ^bb0(%arg0: tensor<1024xf32, #blocked1>): + %4 = triton_gpu.convert_layout %arg0 : (tensor<1024xf32, #blocked1>) -> tensor<1024xf32, #blocked> + %5 = arith.addf %4, %4 : tensor<1024xf32, #blocked> + %6 = triton_gpu.convert_layout %5 : (tensor<1024xf32, #blocked>) -> tensor<1024xf32, #blocked1> + scf.yield %6, %cond : tensor<1024xf32, #blocked1>, i1 + } + %3 = triton_gpu.convert_layout %2 : (tensor<1024xf32, #blocked1>) -> tensor<1024xf32, #blocked> + tt.store %ptr, %3 {cache = 1 : i32, evict = 1 : i32} : tensor<1024xf32, #blocked> + tt.return +} +} diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index 9c6e07ff7cba..1fbfaa9d4d07 100644 --- a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file -tritongpu-optimize-dot-operands -tritongpu-remove-layout-conversions -canonicalize | FileCheck %s +// RUN: triton-opt %s -split-input-file -tritongpu-optimize-dot-operands -canonicalize | FileCheck %s #Cv2 = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> #Av2k1 = #triton_gpu.dot_op<{opIdx = 0, parent = #Cv2, kWidth=1}> @@ -71,9 +71,9 @@ tt.func @succeeds_if_arg_is_not_convert_layout( #mma = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4]}> module attributes {"triton_gpu.num-warps" = 4 : i32} { -// CHECK: #[[BA:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -// CHECK: #[[BB:.*]] = #triton_gpu.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -// CHECK: #[[MMA:.*]] = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4]}> +// CHECK: #[[BA:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [], CTASplitNum = [], CTAOrder = []}> +// CHECK: #[[BB:.*]] = #triton_gpu.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [], CTASplitNum = [], CTAOrder = []}> +// CHECK: #[[MMA:.*]] = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], CTAsPerCGA = [], CTASplitNum = [], CTAOrder = [], instrShape = []}> // CHECK: tt.func @push_convert_both_operands // CHECK: %[[ALOAD:.*]] = tt.load %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #[[BA]]> @@ -106,9 +106,9 @@ tt.func @push_convert_both_operands( #mma = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4]}> module attributes {"triton_gpu.num-warps" = 4 : i32} { -// CHECK: #[[BA:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -// CHECK: #[[BB:.*]] = #triton_gpu.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -// CHECK: #[[MMA:.*]] = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4]}> +// CHECK: #[[BA:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [], CTASplitNum = [], CTAOrder = []}> +// CHECK: #[[BB:.*]] = #triton_gpu.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [], CTASplitNum = [], CTAOrder = []}> +// CHECK: #[[MMA:.*]] = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], CTAsPerCGA = [], CTASplitNum = [], CTAOrder = [], instrShape = []}> // CHECK: tt.func @update_kwidth_slice // CHECK: %[[CST:.+]] = arith.constant dense<1.000000e+00> : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> diff --git a/test/TritonGPU/loop-pipeline-hopper.mlir b/test/TritonGPU/loop-pipeline-hopper.mlir new file mode 100644 index 000000000000..7c54ce39b839 --- /dev/null +++ b/test/TritonGPU/loop-pipeline-hopper.mlir @@ -0,0 +1,305 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=compute-capability=90 -canonicalize | FileCheck %s + +// 4 warps +// matmul: 128x32 @ 32x128 -> 128x128 +#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#ALs0 = #triton_gpu.slice<{parent=#AL, dim=0}> +#BLs0 = #triton_gpu.slice<{parent=#BL, dim=0}> +#C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> + +// CHECK: tt.func @matmul_loop +// CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 +// CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 +// CHECK-DAG: %[[CONSTANT_3:.*]] = arith.constant 3 : i32 +// CHECK-DAG: %[[LOOP_COND_0:.*]] = arith.cmpi slt, %[[LB:.*]], %[[UB:.*]] +// CHECK: %[[ABUFFER:.*]] = triton_gpu.alloc_tensor +// CHECK-DAG: %[[LOOP_COND_0_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_0]] +// CHECK: %[[A0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]], %[[LOOP_COND_0_SPLAT_A]] +// CHECK: %[[BBUFFER:.*]] = triton_gpu.alloc_tensor +// CHECK-DAG: %[[LOOP_COND_0_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_0]] +// CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]], %[[LOOP_COND_0_SPLAT_B]] +// CHECK-DAG: %[[IV_1:.*]] = arith.addi %[[LB]], %[[STEP:.*]] +// CHECK-DAG: %[[LOOP_COND_1:.*]] = arith.cmpi slt, %[[IV_1]], %[[UB]] +// CHECK-DAG: %[[LOOP_COND_1_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_1]] +// CHECK: %[[A1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]], %[[LOOP_COND_1_SPLAT_A]] +// CHECK-DAG: %[[LOOP_COND_1_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_1]] +// CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]], %[[LOOP_COND_1_SPLAT_B]] +// CHECK: triton_gpu.async_wait {num = 2 : i32} +// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]][0, 0, 0] +// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0] +// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]] +// CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_a0]] +// CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]] +// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op]], {{.*}} +// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, {{.*}} +// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, {{.*}} +// CHECK: triton_gpu.async_wait {num = 2 : i32} +// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]][{{.*}}, 0, 0] +// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][{{.*}}, 0, 0] +// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +tt.func @matmul_loop(%lb : index, %ub : index, %step : index, + %A : !tt.ptr {tt.divisibility = 16 : i32}, + %B : !tt.ptr {tt.divisibility = 16 : i32}) { + // A ptrs + %a_ptr_splat = tt.splat %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> + %a_tmp0 = tt.make_range {end = 32: i32, start = 0: i32} : tensor<32xi32, #ALs0> + %a_tmp1 = tt.expand_dims %a_tmp0 {axis = 0 : i32} : (tensor<32xi32, #ALs0>) -> tensor<1x32xi32, #AL> + %a_offs = tt.broadcast %a_tmp1 : (tensor<1x32xi32, #AL>) -> tensor<128x32xi32, #AL> + %a_ptr_init = tt.addptr %a_ptr_splat, %a_offs : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + // B ptrs + %b_ptr_splat = tt.splat %B : (!tt.ptr) -> tensor<32x128x!tt.ptr, #BL> + %b_tmp0 = tt.make_range {end = 128: i32, start = 0: i32} : tensor<128xi32, #BLs0> + %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : (tensor<128xi32, #BLs0>) -> tensor<1x128xi32, #BL> + %b_offs = tt.broadcast %b_tmp1 : (tensor<1x128xi32, #BL>) -> tensor<32x128xi32, #BL> + %b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + + + %a_mask = arith.constant dense : tensor<128x32xi1, #AL> + %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL> + %b_mask = arith.constant dense : tensor<32x128xi1, #BL> + %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16, #BL> + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + %a_ = tt.load %a_ptr {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> + %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> + %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> + %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> + + %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return +} +} + +// ----- + +#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#ALs0 = #triton_gpu.slice<{parent=#AL, dim=0}> +#BLs0 = #triton_gpu.slice<{parent=#BL, dim=0}> +#C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> +// CHECK: tt.func @matmul_loop_nested +// CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 +// CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 +// CHECK-DAG: %[[CONSTANT_3:.*]] = arith.constant 3 : i32 +// CHECK: scf.for +// CHECK: %[[ABUFFER:.*]] = triton_gpu.alloc_tensor +// CHECK: %[[A0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]] +// CHECK: %[[BBUFFER:.*]] = triton_gpu.alloc_tensor +// CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]] +// CHECK: %[[A1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]] +// CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]] +// CHECK: triton_gpu.async_wait {num = 2 : i32} +// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]][0, 0, 0] +// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0] +// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]] +// CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_a0]] +// CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]] +// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op]], {{.*}} +// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, {{.*}} +// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, {{.*}} +// CHECK: triton_gpu.async_wait {num = 2 : i32} +// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]][{{.*}}, 0, 0] +// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][{{.*}}, 0, 0] +// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, + %A : !tt.ptr {tt.divisibility = 16 : i32}, + %B : !tt.ptr {tt.divisibility = 16 : i32}) { + scf.for %iv0 = %lb to %ub step %step { + // A ptrs + %a_ptr_splat = tt.splat %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> + %a_tmp0 = tt.make_range {end = 32: i32, start = 0: i32} : tensor<32xi32, #ALs0> + %a_tmp1 = tt.expand_dims %a_tmp0 {axis = 0 : i32} : (tensor<32xi32, #ALs0>) -> tensor<1x32xi32, #AL> + %a_offs = tt.broadcast %a_tmp1 : (tensor<1x32xi32, #AL>) -> tensor<128x32xi32, #AL> + %a_ptr_init = tt.addptr %a_ptr_splat, %a_offs : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + // B ptrs + %b_ptr_splat = tt.splat %B : (!tt.ptr) -> tensor<32x128x!tt.ptr, #BL> + %b_tmp0 = tt.make_range {end = 128: i32, start = 0: i32} : tensor<128xi32, #BLs0> + %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : (tensor<128xi32, #BLs0>) -> tensor<1x128xi32, #BL> + %b_offs = tt.broadcast %b_tmp1 : (tensor<1x128xi32, #BL>) -> tensor<32x128xi32, #BL> + %b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + + %a_mask = arith.constant dense : tensor<128x32xi1, #AL> + %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL> + %b_mask = arith.constant dense : tensor<32x128xi1, #BL> + %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16, #BL> + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + %a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> + %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> + %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> + %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> + + %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + } + tt.return +} +} + +// ----- + +#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#ALs0 = #triton_gpu.slice<{parent=#AL, dim=0}> +#BLs0 = #triton_gpu.slice<{parent=#BL, dim=0}> +#C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> +// CHECK: tt.func @matmul_loop_single_pipeline +// CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 +// CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 +// CHECK-DAG: %[[CONSTANT_3:.*]] = arith.constant 3 : i32 +// CHECK: %[[BBUFFER:.*]] = triton_gpu.alloc_tensor +// CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]] +// CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]] +// CHECK: triton_gpu.async_wait {num = 1 : i32} +// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0] +// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]] +// CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]] +// CHECK: tt.dot {{.*}}, %[[arg_b0_dot_op]], {{.*}} +// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, {{.*}} +// CHECK: triton_gpu.async_wait {num = 1 : i32} +// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][{{.*}}, 0, 0] +// CHECK: scf.yield {{.*}}, {{.*}}, %[[NEXT_B_BUFFER]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, {{.*}} +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, + %A : !tt.ptr {tt.divisibility = 16 : i32}, + %B : !tt.ptr {tt.divisibility = 16 : i32}) { + // A ptrs + %a_ptr_splat = tt.splat %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> + %a_tmp0 = tt.make_range {end = 32: i32, start = 0: i32} : tensor<32xi32, #ALs0> + %a_tmp1 = tt.expand_dims %a_tmp0 {axis = 0 : i32} : (tensor<32xi32, #ALs0>) -> tensor<1x32xi32, #AL> + %a_offs = tt.broadcast %a_tmp1 : (tensor<1x32xi32, #AL>) -> tensor<128x32xi32, #AL> + %a_ptr_init = tt.addptr %a_ptr_splat, %a_offs : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + // B ptrs + %b_ptr_splat = tt.splat %B : (!tt.ptr) -> tensor<32x128x!tt.ptr, #BL> + %b_tmp0 = tt.make_range {end = 128: i32, start = 0: i32} : tensor<128xi32, #BLs0> + %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : (tensor<128xi32, #BLs0>) -> tensor<1x128xi32, #BL> + %b_offs = tt.broadcast %b_tmp1 : (tensor<1x128xi32, #BL>) -> tensor<32x128xi32, #BL> + %b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + + %a_mask = arith.constant dense : tensor<128x32xi1, #AL> + %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL> + + %a_ = tt.load %a_ptr_init, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> + %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> + + %b_mask = arith.constant dense : tensor<32x128xi1, #BL> + %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16, #BL> + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + scf.for %iv = %lb to %ub step %step iter_args(%b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> + %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> + %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_b_ptr, %c : tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return +} +} + +// ----- + +// TODO: MCast is not supported yet +//// 4 warps, TMA Load +//// matmul: 128x32 @ 32x128 -> 128x128 +//#C = #triton_gpu.mma<{versionMajor = 3, warpsPerCTA = [4, 1]}> +//#SA = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0], hasLeadingOffset=true}> +//#SB = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [0, 1], CTAsPerCGA = [2, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}> +//#BA = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> +//#BB = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [2, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +//// C-HECK: func @matmul_loop +//// C-HECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 +//// C-HECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 +//// C-HECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 +//// C-HECK-DAG: %[[CONSTANT_3:.*]] = arith.constant 3 : i32 +//// C-HECK: %[[MBARRIER_AB:.*]] = triton_nvidia_gpu.alloc_mbarrier {count = 1 : i32} +//// C-HECK: %[[EMPTY_BARRIER_B:.*]] = triton_nvidia_gpu.alloc_mbarrier {count = 2 : i32} +//// C-HECK: %[[ABUFFER:.*]] = triton_gpu.alloc_tensor +//// C-HECK: %[[MBARRIER_AB0:.*]] = triton_nvidia_gpu.extract_mbarrier %[[MBARRIER_AB]][%c0_i32] +//// C-HECK: triton_nvidia_gpu.mbarrier_arrive %[[MBARRIER_AB0]] +//// C-HECK: %[[A0BUFFER:.*]] = triton_nvidia_gpu.insert_slice_async_v2 {{.*}}, {{.*}}, %[[CONSTANT_0]], %[[MBARRIER_AB0]] +//// C-HECK: %[[BBUFFER:.*]] = triton_gpu.alloc_tensor +//// C-HECK: %[[EMPTY_BARRIER_B0:.*]] = triton_nvidia_gpu.extract_mbarrier %[[EMPTY_BARRIER_B]][%c0_i32] +//// C-HECK: triton_nvidia_gpu.mbarrier_wait %[[EMPTY_BARRIER_B0]], %true +//// C-HECK: %[[B0BUFFER:.*]] = triton_nvidia_gpu.insert_slice_async_v2 {{.*}}, {{.*}}, %[[CONSTANT_0]], %[[MBARRIER_AB0]] +//// C-HECK: %[[MBARRIER_AB1:.*]] = triton_nvidia_gpu.extract_mbarrier %[[MBARRIER_AB]][%c1_i32] +//// C-HECK: triton_nvidia_gpu.mbarrier_arrive %[[MBARRIER_AB1]] +//// C-HECK: %[[A1BUFFER:.*]] = triton_nvidia_gpu.insert_slice_async_v2 {{.*}}, {{.*}}, %[[CONSTANT_1]], %[[MBARRIER_AB1]] +//// C-HECK: %[[B1BUFFER:.*]] = triton_nvidia_gpu.insert_slice_async_v2 {{.*}}, {{.*}}, %[[CONSTANT_1]], %[[MBARRIER_AB1]] +//// C-HECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]][0, 0, 0] +//// C-HECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0] +//// C-HECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]] +// // C-HECK: %[[MBARRIER_AB_ITER:.*]] = triton_nvidia_gpu.extract_mbarrier %[[MBARRIER_AB]][{{.*}}] +// // C-HECK: triton_nvidia_gpu.mbarrier_wait %[[MBARRIER_AB_ITER]], {{.*}} +// // C-HECK: triton_nvidia_gpu.dot_async %[[arg_a0]], %[[arg_b0]], {{.*}} +// // C-HECK: triton_nvidia_gpu.dot_wait {{.*}} +// // C-HECK: %[[EMPTY_BARRIER_B_ITER_ARRIVE:.*]] = triton_nvidia_gpu.extract_mbarrier %[[EMPTY_BARRIER_B]][{{.*}}] +// // C-HECK: triton_nvidia_gpu.mbarrier_arrive %[[EMPTY_BARRIER_B_ITER_ARRIVE]] +// // C-HECK: %[[MBARRIER_AB_NEXT_ITER:.*]] = triton_nvidia_gpu.extract_mbarrier %[[MBARRIER_AB]][{{.*}}] +// // C-HECK: %[[NEXT_A_BUFFER:.*]] = triton_nvidia_gpu.insert_slice_async_v2 {{.*}}, {{.*}}, {{.*}}, %[[MBARRIER_AB_NEXT_ITER]] +// // C-HECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]][{{.*}}, 0, 0] +// // C-HECK: %[[EMPTY_BARRIER_B_ITER_WAIT:.*]] = triton_nvidia_gpu.extract_mbarrier %[[EMPTY_BARRIER_B]][{{.*}}] +// // C-HECK: triton_nvidia_gpu.mbarrier_wait %[[EMPTY_BARRIER_B_ITER_WAIT]], {{.*}} +// // C-HECK: %[[NEXT_B_BUFFER:.*]] = triton_nvidia_gpu.insert_slice_async_v2 {{.*}}, {{.*}}, {{.*}}, %[[MBARRIER_AB_NEXT_ITER]] +// // C-HECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][{{.*}}, 0, 0] +// // C-HECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} +//module attributes {"triton_gpu.num-ctas" = 2 : i32, "triton_gpu.num-warps" = 4 : i32} { +// tt.func @matmul_loop(%lb : index, %ub : index, %step : index, +// %A : !tt.ptr {tt.divisibility = 16 : i32}, +// %B : !tt.ptr {tt.divisibility = 16 : i32}) -> (!tt.ptr, 1>, !tt.ptr, 1>, tensor<128x128xf32, #C>) { +// %c0 = arith.constant 0 : i32 +// %c32_i32 = arith.constant 32 : i32 +// %c1 = arith.constant 1 : i64 +// %c32 = arith.constant 32 : i64 +// %c128 = arith.constant 128 : i64 +// %a_tileptr_init = tt.make_tensor_ptr %A, [%c128, %c32], [%c32, %c1], [%c0, %c0] { order = array } : !tt.ptr, 1> +// %b_tileptr_init = tt.make_tensor_ptr %B, [%c32, %c128], [%c1, %c32], [%c0, %c0] { order = array } : !tt.ptr, 1> +// %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> +// +// %res:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_tileptr = %a_tileptr_init, %b_tileptr = %b_tileptr_init, %prev_c = %c_init) -> (!tt.ptr, 1>, !tt.ptr, 1>, tensor<128x128xf32, #C>) { +// %a = tt.load %a_tileptr {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<128x32xf16, #BA> +// %b = tt.load %b_tileptr {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<32x128xf16, #BB> +// +// %sa = triton_gpu.convert_layout %a : (tensor<128x32xf16, #BA>) -> tensor<128x32xf16, #SA> +// %sb = triton_gpu.convert_layout %b : (tensor<32x128xf16, #BB>) -> tensor<32x128xf16, #SB> +// %c = tt.dot %sa, %sb, %prev_c {allowTF32 = true} : tensor<128x32xf16, #SA> * tensor<32x128xf16, #SB> -> tensor<128x128xf32, #C> +// +// %a_tileptr_next = tt.advance %a_tileptr, [%c0, %c32_i32] : !tt.ptr, 1> +// %b_tileptr_next = tt.advance %b_tileptr, [%c32_i32, %c0] : !tt.ptr, 1> +// +// scf.yield %a_tileptr_next, %b_tileptr_next, %c : !tt.ptr, 1>, !tt.ptr, 1>, tensor<128x128xf32, #C> +// } +// tt.return %res#0, %res#1, %res#2 : !tt.ptr, 1>, !tt.ptr, 1>, tensor<128x128xf32, #C> +// } +//} diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index 7e49a2f2a462..626e7bdb1c85 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -2,12 +2,12 @@ // 4 warps // matmul: 128x32 @ 32x128 -> 128x128 -#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 1]}> +#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 1]}> #ALs0 = #triton_gpu.slice<{parent=#AL, dim=0}> #BLs0 = #triton_gpu.slice<{parent=#BL, dim=0}> #BLs1 = #triton_gpu.slice<{parent=#BL, dim=1}> -#C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> #B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> @@ -32,21 +32,23 @@ // CHECK: triton_gpu.async_wait {num = 2 : i32} // CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]][0, 0, 0] // CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0] -// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_1]] +// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]] // CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_a0]] // CHECK: %[[arg_b0_dot_op_0:.*]] = triton_gpu.convert_layout %[[arg_b0]] // CHECK: %[[arg_b0_dot_op_1:.*]] = arith.mulf %[[arg_b0_dot_op_0]] // CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op_1]], {{.*}} -// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]] -// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]] -// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]] -// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]] +// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_LOOP:.*]] = arith.cmpi uge, %[[NEXT_LOOP_IDX]], %[[CONSTANT_3]] +// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.select %[[CMP_LOOP]], %[[CONSTANT_0]], %[[NEXT_LOOP_IDX]] +// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[PIPELINE_IDX]] +// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[PIPELINE_IDX]] // CHECK: triton_gpu.async_wait {num = 2 : i32} // CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]][%[[EXTRACT_IDX]], 0, 0] // CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0] -// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]] -// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]] -// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]] +// CHECK-DAG: %[[PIPELINE_IDX_PLUS_ONE:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]] +// CHECK-DAG: %[[CMP_PIPELINE:.*]] = arith.cmpi uge, %[[PIPELINE_IDX_PLUS_ONE]], %[[CONSTANT_3]] +// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.select %[[CMP_PIPELINE]], %[[CONSTANT_0]], %[[PIPELINE_IDX_PLUS_ONE]] +// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[EXTRACT_IDX]] tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, %B : !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> { @@ -106,20 +108,22 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, // CHECK: triton_gpu.async_wait {num = 2 : i32} // CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]][0, 0, 0] // CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0] -// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_1]] +// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]] // CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_a0]] // CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]] // CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op]], {{.*}} -// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]] -// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]] -// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]] -// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]] +// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_LOOP:.*]] = arith.cmpi uge, %[[NEXT_LOOP_IDX]], %[[CONSTANT_3]] +// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.select %[[CMP_LOOP]], %[[CONSTANT_0]], %[[NEXT_LOOP_IDX]] +// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[PIPELINE_IDX]] +// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[PIPELINE_IDX]] // CHECK: triton_gpu.async_wait {num = 2 : i32} // CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]][%[[EXTRACT_IDX]], 0, 0] // CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0] -// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]] -// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]] -// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]] +// CHECK-DAG: %[[PIPELINE_IDX_PLUS_ONE:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]] +// CHECK-DAG: %[[CMP_PIPELINE:.*]] = arith.cmpi uge, %[[PIPELINE_IDX_PLUS_ONE]], %[[CONSTANT_3]] +// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.select %[[CMP_PIPELINE]], %[[CONSTANT_0]], %[[PIPELINE_IDX_PLUS_ONE]] +// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[EXTRACT_IDX]] tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, %B : !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C>{ @@ -176,17 +180,19 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, // CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]] // CHECK: triton_gpu.async_wait {num = 1 : i32} // CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0] -// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_1]] +// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]] // CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]] // CHECK: tt.dot {{.*}}, %[[arg_b0_dot_op]], {{.*}} -// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]] -// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]] -// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]] +// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_LOOP:.*]] = arith.cmpi uge, %[[NEXT_LOOP_IDX]], %[[CONSTANT_3]] +// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.select %[[CMP_LOOP]], %[[CONSTANT_0]], %[[NEXT_LOOP_IDX]] +// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[PIPELINE_IDX]] // CHECK: triton_gpu.async_wait {num = 1 : i32} // CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0] -// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]] -// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]] -// CHECK: scf.yield {{.*}}, {{.*}}, %[[NEXT_B_BUFFER]], %[[NEXT_B]], {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]] +// CHECK-DAG: %[[PIPELINE_IDX_PLUS_ONE:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]] +// CHECK-DAG: %[[CMP_PIPELINE:.*]] = arith.cmpi uge, %[[PIPELINE_IDX_PLUS_ONE]], %[[CONSTANT_3]] +// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.select %[[CMP_PIPELINE]], %[[CONSTANT_0]], %[[PIPELINE_IDX_PLUS_ONE]] +// CHECK: scf.yield {{.*}}, {{.*}}, %[[NEXT_B_BUFFER]], %[[NEXT_B]], {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[EXTRACT_IDX]] tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, %B : !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> { diff --git a/test/TritonGPU/materialize-load-store.mlir b/test/TritonGPU/materialize-load-store.mlir new file mode 100644 index 000000000000..65ca0e6c65a7 --- /dev/null +++ b/test/TritonGPU/materialize-load-store.mlir @@ -0,0 +1,63 @@ +// RUN: ENABLE_TMA=1 triton-opt %s -split-input-file -triton-nvidia-gpu-materialize-load-store=compute-capability=90 -canonicalize | FileCheck %s + +// CHECK-LABEL: @matmul_loop +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + tt.func @matmul_loop(%A : !tt.ptr {tt.divisibility = 16 : i32}) -> (tensor<64x16xf16, #blocked>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i64 + %c16 = arith.constant 16 : i64 + %c64 = arith.constant 64 : i64 + // CHECK: %[[TENSOR_PTR:.*]] = tt.make_tensor_ptr + %a_tileptr_init = tt.make_tensor_ptr %A, [%c64, %c16], [%c16, %c1], [%c0, %c0] { order = array } : !tt.ptr, 1> + // CHECK: %[[BUFFER:.*]] = triton_gpu.alloc_tensor : tensor<1x64x16xf16, #shared> + // CHECK: %[[MBAR:.*]] = triton_nvidia_gpu.alloc_mbarrier {count = 1 : i32} : !tt.ptr + // CHECK: triton_nvidia_gpu.mbarrier_arrive %[[MBAR]], %{{.*}} {operand_segment_sizes = array, trackAsyncOp = false, txCount = 2048 : i32} : !tt.ptr, i1 + // CHECK: %[[INSERT:.*]] = triton_nvidia_gpu.insert_slice_async_v2 %[[TENSOR_PTR]], %[[BUFFER]], %{{.*}}, %[[MBAR]] + // CHECK: %[[EXT:.*]] = triton_gpu.extract_slice %[[INSERT]][0, 0, 0] [1, 64, 16] [1, 1, 1] : tensor<1x64x16xf16, #shared> to tensor<64x16xf16, #shared> + // CHECK: triton_nvidia_gpu.mbarrier_wait %[[MBAR]], %false : + // CHECK: %[[CVT:.*]] = triton_gpu.convert_layout %[[EXT]] : (tensor<64x16xf16, #shared>) -> tensor<64x16xf16, #blocked> + // CHECK: tt.return %[[CVT]] : tensor<64x16xf16, #blocked> + %res = tt.load %a_tileptr_init {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<64x16xf16, #blocked> + tt.return %res : tensor<64x16xf16, #blocked> + } +} + +// CHECK-LABEL: matmul_no_scf + +#blockedA0 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blockedB0 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blockedA1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blockedB1 = #triton_gpu.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 16, 16]}> +#sharedA = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#sharedB = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + tt.func public @matmul_no_scf(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg4: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg7: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg8: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<64x16xf32, #mma> + %c0_i32 = arith.constant 0 : i32 + %c1_i64 = arith.constant 1 : i64 + %0 = arith.extsi %arg3 : i32 to i64 + %1 = arith.extsi %arg5 : i32 to i64 + %2 = arith.extsi %arg6 : i32 to i64 + %3 = tt.make_tensor_ptr %arg0, [%0, %1], [%2, %c1_i64], [%c0_i32, %c0_i32] {order = array} : , 1> + %4 = arith.extsi %arg4 : i32 to i64 + %5 = arith.extsi %arg7 : i32 to i64 + %6 = tt.make_tensor_ptr %arg1, [%1, %4], [%c1_i64, %5], [%c0_i32, %c0_i32] {order = array} : , 1> + // CHECK: %[[LOADED_A:.*]] = triton_gpu.extract_slice + // CHECK: %[[LOADED_B:.*]] = triton_gpu.extract_slice + // CHECK-NOT: triton_gpu.convert_layout {{.*}}#shared{{.*}}->{{.*}}#blocked + // CHECK: tt.dot %[[LOADED_A]], %[[LOADED_B]] + %7 = tt.load %3 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<64x16xf16, #blockedA1> + %8 = tt.load %6 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<16x16xf16, #blockedB1> + %9 = triton_gpu.convert_layout %7 : (tensor<64x16xf16, #blockedA1>) -> tensor<64x16xf16, #sharedA> + %10 = triton_gpu.convert_layout %8 : (tensor<16x16xf16, #blockedB1>) -> tensor<16x16xf16, #sharedB> + %11 = tt.dot %9, %10, %cst {allowTF32 = true} : tensor<64x16xf16, #sharedA> * tensor<16x16xf16, #sharedB> -> tensor<64x16xf32, #mma> + %12 = triton_gpu.convert_layout %11 : (tensor<64x16xf32, #mma>) -> tensor<64x16xf32, #blockedA1> + %13 = arith.truncf %12 : tensor<64x16xf32, #blockedA1> to tensor<64x16xf16, #blockedA1> + %14 = arith.extsi %arg8 : i32 to i64 + %15 = tt.make_tensor_ptr %arg2, [%0, %4], [%14, %c1_i64], [%c0_i32, %c0_i32] {order = array} : , 1> + tt.store %15, %13 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32} : !tt.ptr, 1>, tensor<64x16xf16, #blockedA1> + tt.return + } +} diff --git a/test/TritonGPU/rewrite-tensor-pointer.mlir b/test/TritonGPU/rewrite-tensor-pointer.mlir new file mode 100644 index 000000000000..23eddb24b536 --- /dev/null +++ b/test/TritonGPU/rewrite-tensor-pointer.mlir @@ -0,0 +1,121 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-rewrite-tensor-pointer | FileCheck %s + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + tt.func public @matmul_kernel_0d1d2d3d456d7d8c9c10d11c121314c(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: i32, %arg5: i32, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: i32) { + %c127_i32 = arith.constant 127 : i32 + %c64_i32 = arith.constant 64 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %c1_i64 = arith.constant 1 : i64 + %c128_i32 = arith.constant 128 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg5, %c127_i32 : i32 + %2 = arith.divsi %1, %c128_i32 : i32 + %3 = arith.addi %arg4, %c127_i32 : i32 + %4 = arith.divsi %3, %c128_i32 : i32 + %5 = arith.muli %2, %c8_i32 : i32 + %6 = arith.divsi %0, %5 : i32 + %7 = arith.muli %6, %c8_i32 : i32 + %8 = arith.subi %4, %7 : i32 + %9 = "triton_gpu.cmpi"(%8, %c8_i32) {predicate = 2 : i64} : (i32, i32) -> i1 + %10 = arith.select %9, %8, %c8_i32 : i32 + %11 = arith.remsi %0, %10 : i32 + %12 = arith.addi %7, %11 : i32 + %13 = arith.remsi %0, %5 : i32 + %14 = arith.divsi %13, %10 : i32 + %15 = arith.muli %12, %c128_i32 : i32 + %16 = arith.muli %14, %c128_i32 : i32 + %17 = arith.extsi %arg4 : i32 to i64 + %18 = arith.extsi %arg6 : i32 to i64 + %19 = arith.extsi %arg7 : i32 to i64 + // CHECK-NOT: tt.make_tensor_ptr + %20 = tt.make_tensor_ptr %arg0, [%17, %18], [%19, %c1_i64], [%15, %c0_i32] {order = array} : , 1> + %21 = arith.extsi %arg5 : i32 to i64 + %22 = arith.extsi %arg8 : i32 to i64 + // CHECK-NOT: tt.make_tensor_ptr + %23 = tt.make_tensor_ptr %arg1, [%18, %21], [%c1_i64, %22], [%c0_i32, %16] {order = array} : , 1> + %24:3 = scf.for %arg11 = %c0_i32 to %arg6 step %c64_i32 iter_args(%arg12 = %cst, %arg13 = %20, %arg14 = %23) -> (tensor<128x128xf32, #blocked>, !tt.ptr, 1>, !tt.ptr, 1>) : i32 { + // CHECK: tt.load %{{.*}}, %{{.*}}, %{{.*}} {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 2 : i32} : tensor<128x64xf16, + %28 = tt.load %arg13 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 2 : i32} : !tt.ptr, 1> -> tensor<128x64xf16, #blocked> + // CHECK: tt.load %{{.*}}, %{{.*}}, %{{.*}} {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 2 : i32} : tensor<64x128xf16, + %29 = tt.load %arg14 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 2 : i32} : !tt.ptr, 1> -> tensor<64x128xf16, #blocked1> + %30 = triton_gpu.convert_layout %28 : (tensor<128x64xf16, #blocked>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>> + %31 = triton_gpu.convert_layout %29 : (tensor<64x128xf16, #blocked1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>> + %32 = triton_gpu.convert_layout %arg12 : (tensor<128x128xf32, #blocked>) -> tensor<128x128xf32, #blocked2> + %33 = tt.dot %30, %31, %32 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>> -> tensor<128x128xf32, #blocked2> + %34 = triton_gpu.convert_layout %33 : (tensor<128x128xf32, #blocked2>) -> tensor<128x128xf32, #blocked> + // CHECK-NOT: tt.advance + %35 = tt.advance %arg13, [%c0_i32, %c64_i32] : , 1> + // CHECK-NOT: tt.advance + %36 = tt.advance %arg14, [%c64_i32, %c0_i32] : , 1> + scf.yield %34, %35, %36 : tensor<128x128xf32, #blocked>, !tt.ptr, 1>, !tt.ptr, 1> + } + %25 = arith.truncf %24#0 : tensor<128x128xf32, #blocked> to tensor<128x128xf16, #blocked> + %26 = arith.extsi %arg10 : i32 to i64 + %27 = tt.make_tensor_ptr %arg3, [%17, %21], [%26, %c1_i64], [%15, %16] {order = array} : , 1> + // CHECK: tt.store %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32} : tensor<128x128xf16, #blocked> + tt.store %27, %25 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32} : !tt.ptr, 1>, tensor<128x128xf16, #blocked> + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @if_for_if(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg4: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg7: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<64x1xf32, #blocked> + %c63_i32 = arith.constant 63 : i32 + %c-16_i32 = arith.constant -16 : i32 + %c132_i32 = arith.constant 132 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i64 = arith.constant 1 : i64 + %c64_i32 = arith.constant 64 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg3, %c63_i32 : i32 + %2 = arith.divsi %1, %c64_i32 : i32 + %3 = arith.muli %0, %c64_i32 : i32 + %4 = arith.extsi %arg3 : i32 to i64 + %5 = arith.extsi %arg4 : i32 to i64 + %6 = arith.extsi %arg5 : i32 to i64 + // CHECK-NOT: tt.make_tensor_ptr + %7 = tt.make_tensor_ptr %arg0, [%4, %5], [%6, %c1_i64], [%3, %c0_i32] {order = array} : , 1> + %8 = "triton_gpu.cmpi"(%2, %c132_i32) <{predicate = 5 : i64}> : (i32, i32) -> i1 + scf.if %8 { + %9 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked1> + %10 = tt.splat %arg7 : (i32) -> tensor<64x1xi32, #blocked> + %11 = tt.splat %arg2 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked> + %12 = scf.for %arg8 = %0 to %2 step %c132_i32 iter_args(%arg9 = %7) -> (!tt.ptr, 1>) : i32 { + %13 = "triton_gpu.cmpi"(%arg8, %c132_i32) <{predicate = 5 : i64}> : (i32, i32) -> i1 + %14 = scf.if %13 -> (!tt.ptr, 1>) { + %25 = arith.subi %arg8, %0 : i32 + %26 = arith.muli %25, %c64_i32 : i32 + // CHECK-NOT: tt.advance + %27 = tt.advance %arg9, [%26, %c-16_i32] : , 1> + scf.yield %27 : !tt.ptr, 1> + } else { + scf.yield %arg9 : !tt.ptr, 1> + } + %15 = arith.muli %arg8, %c64_i32 : i32 + %16 = tt.splat %15 : (i32) -> tensor<64xi32, #blocked1> + %17 = arith.addi %9, %16 : tensor<64xi32, #blocked1> + %18 = triton_gpu.convert_layout %17 : (tensor<64xi32, #blocked1>) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %19 = tt.expand_dims %18 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<64x1xi32, #blocked2> + %20 = triton_gpu.convert_layout %19 : (tensor<64x1xi32, #blocked2>) -> tensor<64x1xi32, #blocked> + %21 = arith.muli %20, %10 : tensor<64x1xi32, #blocked> + %22 = tt.addptr %11, %21 : tensor<64x1x!tt.ptr, #blocked>, tensor<64x1xi32, #blocked> + %23 = triton_gpu.convert_layout %22 : (tensor<64x1x!tt.ptr, #blocked>) -> tensor<64x1x!tt.ptr, #blocked> + %24 = triton_gpu.convert_layout %cst : (tensor<64x1xf32, #blocked>) -> tensor<64x1xf32, #blocked> + tt.store %23, %24 {cache = 1 : i32, evict = 1 : i32} : tensor<64x1xf32, #blocked> + scf.yield %14 : !tt.ptr, 1> + } + } + tt.return + } +} diff --git a/test/TritonGPU/stream-pipeline.mlir b/test/TritonGPU/stream-pipeline.mlir new file mode 100644 index 000000000000..e6ab4e5df012 --- /dev/null +++ b/test/TritonGPU/stream-pipeline.mlir @@ -0,0 +1,558 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-stream-pipeline -canonicalize | FileCheck %s + +// 4 warps +// matmul: 128x32 @ 32x128 -> 128x128 +#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#ALs0 = #triton_gpu.slice<{parent=#AL, dim=0}> +#BLs0 = #triton_gpu.slice<{parent=#BL, dim=0}> +#BLs1 = #triton_gpu.slice<{parent=#BL, dim=1}> +#C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> + +// CHECK: tt.func @matmul_loop +// Prologue +// CHECK: %[[A0_LOAD:.*]] = tt.load +// CHECK: %[[A0_SHARED:.*]] = triton_gpu.convert_layout %[[A0_LOAD]] +// CHECK: %[[B0_LOAD:.*]] = tt.load +// CHECK: %[[B0_SHARED:.*]] = triton_gpu.convert_layout %[[B0_LOAD]] +// Restructured for-loop +// CHECK: %[[FOR_OUTPUT:.*]]:{{.*}} = scf.for {{.*}} iter_args({{.*}}, %[[AC_ARG:.*]] = %[[A0_SHARED]], %[[BC_ARG:.*]] = %[[B0_SHARED]], %[[AN_ARG:.*]] = %{{.*}}, %[[BN_ARG:.*]] = %{{.*}}) +// CHECK: %[[AN_LOAD:.*]] = tt.load %[[AN_ARG]] +// CHECK: %[[BN_LOAD:.*]] = tt.load %[[BN_ARG]] +// CHECK: %[[AC_CVT:.*]] = triton_gpu.convert_layout %[[AC_ARG]] +// CHECK: %[[BC_CVT:.*]] = triton_gpu.convert_layout %[[BC_ARG]] +// CHECK: %[[BC_CVT_0:.*]] = arith.mulf %[[BC_CVT]], %{{.*}} +// CHECK: tt.dot %[[AC_CVT]], %[[BC_CVT_0]], {{.*}} +// CHECK: %[[AN_SHARED:.*]] = triton_gpu.convert_layout %[[AN_LOAD]] +// CHECK: %[[BN_SHARED:.*]] = triton_gpu.convert_layout %[[BN_LOAD]] +// CHECK: scf.yield {{.*}}, %[[AN_SHARED]], %[[BN_SHARED]], +// Epilogue +// CHECK: %[[AO_SHARED:.*]] = triton_gpu.convert_layout %[[FOR_OUTPUT]]#1 +// CHECK: %[[BO_SHARED:.*]] = triton_gpu.convert_layout %[[FOR_OUTPUT]]#2 +// CHECK: %[[BO_SHARED_0:.*]] = arith.mulf %[[BO_SHARED]], %{{.*}} +// CHECK-NEXT: tt.dot %[[AO_SHARED]], %[[BO_SHARED_0]], {{.*}} + +tt.func @matmul_loop(%lb : index, %ub : index, %step : index, + %A : !tt.ptr {tt.divisibility = 16 : i32}, + %B : !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> { + // A ptrs + %a_ptr_splat = tt.splat %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> + %a_tmp0 = tt.make_range {end = 32: i32, start = 0: i32} : tensor<32xi32, #ALs0> + %a_tmp1 = tt.expand_dims %a_tmp0 {axis = 0 : i32} : (tensor<32xi32, #ALs0>) -> tensor<1x32xi32, #AL> + %a_offs = tt.broadcast %a_tmp1 : (tensor<1x32xi32, #AL>) -> tensor<128x32xi32, #AL> + %a_ptr_init = tt.addptr %a_ptr_splat, %a_offs : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + // B ptrs + %b_ptr_splat = tt.splat %B : (!tt.ptr) -> tensor<32x128x!tt.ptr, #BL> + %b_tmp0 = tt.make_range {end = 128: i32, start = 0: i32} : tensor<128xi32, #BLs0> + %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : (tensor<128xi32, #BLs0>) -> tensor<1x128xi32, #BL> + %b_offs = tt.broadcast %b_tmp1 : (tensor<1x128xi32, #BL>) -> tensor<32x128xi32, #BL> + %b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + + + %a_mask = arith.constant dense : tensor<128x32xi1, #AL> + %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL> + %b_mask = arith.constant dense : tensor<32x128xi1, #BL> + %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16, #BL> + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %b_scale = arith.constant dense<4.> : tensor<32x128xf16, #B> + + %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + %a_ = tt.load %a_ptr {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> + %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> + %b__ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> + %b_ = triton_gpu.convert_layout %b__ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> + %b = arith.mulf %b_, %b_scale: tensor<32x128xf16, #B> + + %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return %loop#2: tensor<128x128xf32, #C> +} + +// CHECK: tt.func @matmul_loop_nested +// CHECK: scf.for +// Prologue +// CHECK: %[[A0_LOAD:.*]] = tt.load +// CHECK: %[[A0_SHARED:.*]] = triton_gpu.convert_layout %[[A0_LOAD]] +// CHECK: %[[B0_LOAD:.*]] = tt.load +// CHECK: %[[B0_SHARED:.*]] = triton_gpu.convert_layout %[[B0_LOAD]] +// Restructured for-loop +// CHECK: %[[FOR_OUTPUT:.*]]:{{.*}} = scf.for {{.*}} iter_args({{.*}}, %[[AC_ARG:.*]] = %[[A0_SHARED]], %[[BC_ARG:.*]] = %[[B0_SHARED]], %[[AN_ARG:.*]] = %{{.*}}, %[[BN_ARG:.*]] = %{{.*}}) +// CHECK: %[[AN_LOAD:.*]] = tt.load %[[AN_ARG]] +// CHECK: %[[BN_LOAD:.*]] = tt.load %[[BN_ARG]] +// CHECK: %[[AC_CVT:.*]] = triton_gpu.convert_layout %[[AC_ARG]] +// CHECK: %[[BC_CVT:.*]] = triton_gpu.convert_layout %[[BC_ARG]] +// CHECK: tt.dot %[[AC_CVT]], %[[BC_CVT]], {{.*}} +// CHECK: %[[AN_SHARED:.*]] = triton_gpu.convert_layout %[[AN_LOAD]] +// CHECK: %[[BN_SHARED:.*]] = triton_gpu.convert_layout %[[BN_LOAD]] +// CHECK: scf.yield {{.*}}, %[[AN_SHARED]], %[[BN_SHARED]], +// Epilogue +// CHECK: %[[AO_SHARED:.*]] = triton_gpu.convert_layout %[[FOR_OUTPUT]]#1 +// CHECK: %[[BO_SHARED:.*]] = triton_gpu.convert_layout %[[FOR_OUTPUT]]#2 +// CHECK-NEXT: tt.dot %[[AO_SHARED]], %[[BO_SHARED]], {{.*}} + +tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, + %A : !tt.ptr {tt.divisibility = 16 : i32}, + %B : !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C>{ + + %c_start = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %loop1:1 = scf.for %iv0 = %lb to %ub step %step iter_args(%c_init = %c_start) -> (tensor<128x128xf32, #C>) { + // A ptrs + %a_ptr_splat = tt.splat %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> + %a_tmp0 = tt.make_range {end = 32: i32, start = 0: i32} : tensor<32xi32, #ALs0> + %a_tmp1 = tt.expand_dims %a_tmp0 {axis = 0 : i32} : (tensor<32xi32, #ALs0>) -> tensor<1x32xi32, #AL> + %a_offs = tt.broadcast %a_tmp1 : (tensor<1x32xi32, #AL>) -> tensor<128x32xi32, #AL> + %a_ptr_init = tt.addptr %a_ptr_splat, %a_offs : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + // B ptrs + %b_ptr_splat = tt.splat %B : (!tt.ptr) -> tensor<32x128x!tt.ptr, #BL> + %b_tmp0 = tt.make_range {end = 128: i32, start = 0: i32} : tensor<128xi32, #BLs0> + %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : (tensor<128xi32, #BLs0>) -> tensor<1x128xi32, #BL> + %b_offs = tt.broadcast %b_tmp1 : (tensor<1x128xi32, #BL>) -> tensor<32x128xi32, #BL> + %b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + + %a_mask = arith.constant dense : tensor<128x32xi1, #AL> + %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL> + %b_mask = arith.constant dense : tensor<32x128xi1, #BL> + %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16, #BL> + + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop2:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + %a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> + %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> + %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> + %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> + + %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + + scf.yield %loop2#2 : tensor<128x128xf32, #C> + } + tt.return %loop1#0 : tensor<128x128xf32, #C> +} + + +// CHECK: tt.func @matmul_loop_single_pipeline +// Prologue +// CHECK: %[[A0_LOAD:.*]] = tt.load +// CHECK: %[[A0_SHARED:.*]] = triton_gpu.convert_layout %[[A0_LOAD]] +// CHECK: %[[B0_LOAD:.*]] = tt.load +// CHECK: %[[B0_SHARED:.*]] = triton_gpu.convert_layout %[[B0_LOAD]] +// Restructured for-loop +// CHECK: %[[FOR_OUTPUT:.*]]:{{.*}} = scf.for {{.*}} iter_args({{.*}}, %[[BC_ARG:.*]] = %[[B0_SHARED]], %[[BN_ARG:.*]] = %{{.*}}) +// CHECK: %[[BN_LOAD:.*]] = tt.load %[[BN_ARG]] +// CHECK: %[[BC_DOT:.*]] = triton_gpu.convert_layout %[[BC_ARG]] +// CHECK: tt.dot %[[A0_SHARED]], %[[BC_DOT]], {{.*}} +// CHECK: %[[BN_SHARED:.*]] = triton_gpu.convert_layout %[[BN_LOAD]] +// CHECK: scf.yield {{.*}}, %[[BN_SHARED]], +// Epilogue +// CHECK: %[[BO_SHARED:.*]] = triton_gpu.convert_layout %[[FOR_OUTPUT]]#1 +// CHECK-NEXT: tt.dot %[[A0_SHARED]], %[[BO_SHARED]], {{.*}} + +tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, + %A : !tt.ptr {tt.divisibility = 16 : i32}, + %B : !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> { + // A ptrs + %a_ptr_splat = tt.splat %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> + %a_tmp0 = tt.make_range {end = 32: i32, start = 0: i32} : tensor<32xi32, #ALs0> + %a_tmp1 = tt.expand_dims %a_tmp0 {axis = 0 : i32} : (tensor<32xi32, #ALs0>) -> tensor<1x32xi32, #AL> + %a_offs = tt.broadcast %a_tmp1 : (tensor<1x32xi32, #AL>) -> tensor<128x32xi32, #AL> + %a_ptr_init = tt.addptr %a_ptr_splat, %a_offs : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + // B ptrs + %b_ptr_splat = tt.splat %B : (!tt.ptr) -> tensor<32x128x!tt.ptr, #BL> + %b_tmp0 = tt.make_range {end = 128: i32, start = 0: i32} : tensor<128xi32, #BLs0> + %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : (tensor<128xi32, #BLs0>) -> tensor<1x128xi32, #BL> + %b_offs = tt.broadcast %b_tmp1 : (tensor<1x128xi32, #BL>) -> tensor<32x128xi32, #BL> + %b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + + %a_mask = arith.constant dense : tensor<128x32xi1, #AL> + %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL> + + %a_ = tt.load %a_ptr_init, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> + %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> + + %b_mask = arith.constant dense : tensor<32x128xi1, #BL> + %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16, #BL> + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> + %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> + %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_b_ptr, %c : tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return %loop#1 : tensor<128x128xf32, #C> +} + +// CHECK: tt.func @lut_bmm_scalar +// Prologue +// CHECK: %[[A0_LOAD:.*]] = tt.load +// CHECK: %[[A0_SHARED:.*]] = triton_gpu.convert_layout %[[A0_LOAD]] +// Restructured for-loop +// CHECK: scf.for +// CHECK: %[[AN_LOAD:.*]] = tt.load +// CHECK: tt.load +// CHECK: tt.load +// CHECK: triton_gpu.convert_layout +// CHECK: triton_gpu.convert_layout +// CHECK: tt.dot +// CHECK: %[[AN_SHARED:.*]] = triton_gpu.convert_layout %[[AN_LOAD]] +// CHECK: scf.yield {{.*}}, %[[AN_SHARED]] +// Epilogue +// CHECK: tt.load +// CHECK: tt.load +// CHECK: triton_gpu.convert_layout +// CHECK: triton_gpu.convert_layout +// CHECK: tt.dot + +tt.func @lut_bmm_scalar(%77: i64 {tt.divisibility=16: i32}, + %76: index, + %49: tensor<16x16x!tt.ptr, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %75: !tt.ptr, + %78: tensor<16x16xi32, #AL> {tt.constancy=16: i32, tt.divisibility=16: i32}, + %60: tensor<16x16x!tt.ptr, #BL> {tt.divisibility=16: i32, tt.contiguity=16 : i32}) -> tensor<16x16xf32, #C>{ + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #C> + %c4_i32 = arith.constant 4 : i32 + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i32 = arith.constant 1 : i32 + %79:3 = scf.for %arg18 = %c0 to %76 step %c1 iter_args(%arg19 = %cst, %arg20 = %49, %arg21 = %75) -> (tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, !tt.ptr) { + %82 = tt.load %arg20 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #AL> + %83 = tt.load %arg21 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : i64 + %84 = arith.muli %77, %83 : i64 + %85 = tt.splat %84 : (i64) -> tensor<16x16xi64, #BL> + %86 = tt.addptr %60, %85 : tensor<16x16x!tt.ptr, #BL>, tensor<16x16xi64, #BL> + %87 = tt.load %86 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BL> + %88 = triton_gpu.convert_layout %82 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #A> + %89 = triton_gpu.convert_layout %87 : (tensor<16x16xf16, #BL>) -> tensor<16x16xf16, #B> + %90 = tt.dot %88, %89, %arg19 {allowTF32 = true} : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> + %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> + %92 = tt.addptr %arg21, %c1_i32 : !tt.ptr, i32 + scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, !tt.ptr + } + tt.return %79#0 : tensor<16x16xf32, #C> +} + +// CHECK: tt.func @lut_bmm_vector +// Prologue +// CHECK: %[[A0_LOAD:.*]] = tt.load +// CHECK: %[[A0_SHARED:.*]] = triton_gpu.convert_layout %[[A0_LOAD]] +// Restructured for-loop +// CHECK: scf.for +// CHECK: %[[AN_LOAD:.*]] = tt.load +// CHECK: tt.load +// CHECK: tt.load +// CHECK: triton_gpu.convert_layout +// CHECK: triton_gpu.convert_layout +// CHECK: tt.dot +// CHECK: %[[AN_SHARED:.*]] = triton_gpu.convert_layout %[[AN_LOAD]] +// CHECK: scf.yield {{.*}}, %[[AN_SHARED]] +// Epilogue +// CHECK: tt.load +// CHECK: tt.load +// CHECK: triton_gpu.convert_layout +// CHECK: triton_gpu.convert_layout +// CHECK: tt.dot +tt.func @lut_bmm_vector(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt.constancy=16: i32}, + %76: index, + %49: tensor<16x16x!tt.ptr, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %75: tensor<16x!tt.ptr, #BLs1>, + %78: tensor<16x16xi32, #AL> {tt.constancy=16: i32, tt.divisibility=16: i32}, + %60: tensor<16x16x!tt.ptr, #BL> {tt.divisibility=16: i32, tt.contiguity=16 : i32}) -> tensor<16x16xf32, #C>{ + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #C> + %c4_i32 = arith.constant 4 : i32 + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i32 = arith.constant 1 : i32 + %c1_i32_splat = tt.splat %c1_i32 : (i32) -> tensor<16xi32, #BLs1> + %79:3 = scf.for %arg18 = %c0 to %76 step %c1 iter_args(%arg19 = %cst, %arg20 = %49, %arg21 = %75) -> (tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, tensor<16x!tt.ptr, #BLs1>) { + %82 = tt.load %arg20 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #AL> + %83 = tt.load %arg21 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16xi64, #BLs1> + %84 = tt.expand_dims %83 {axis=1: i32}: (tensor<16xi64, #BLs1>) -> tensor<16x1xi64, #BL> + %850 = tt.broadcast %84 : (tensor<16x1xi64, #BL>) -> tensor<16x16xi64, #BL> + %85 = arith.muli %77, %850 : tensor<16x16xi64, #BL> + %86 = tt.addptr %60, %85 : tensor<16x16x!tt.ptr, #BL>, tensor<16x16xi64, #BL> + %87 = tt.load %86 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BL> + %88 = triton_gpu.convert_layout %82 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #A> + %89 = triton_gpu.convert_layout %87 : (tensor<16x16xf16, #BL>) -> tensor<16x16xf16, #B> + %90 = tt.dot %88, %89, %arg19 {allowTF32 = true} : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> + %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> + %92 = tt.addptr %arg21, %c1_i32_splat : tensor<16x!tt.ptr, #BLs1>, tensor<16xi32, #BLs1> + scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, tensor<16x!tt.ptr, #BLs1> + } + tt.return %79#0 : tensor<16x16xf32, #C> +} + +// CHECK: tt.func @post_load_inv +// CHECK: scf.for +// CHECK: scf.yield +// CHECK-NEXT: } +// CHECK-NEXT: triton_gpu.convert_layout +// CHECK-NEXT: triton_gpu.convert_layout +// CHECK-NEXT: tt.dot + +tt.func @post_load_inv(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: !tt.ptr {tt.divisibility = 16 : i32}, + %arg3: i32 {tt.divisibility = 16 : i32}, + %arg4: i32 {tt.divisibility = 16 : i32}, + %arg5: i32 {tt.divisibility = 16 : i32}, + %arg6: i32 {tt.divisibility = 16 : i32}, + %arg7: i32 {tt.divisibility = 16 : i32}, + %arg8: i32 {tt.divisibility = 16 : i32}) -> tensor<32x32xf32, #C> { + %c0_index = arith.constant 0 : index + %c1_index = arith.constant 1 : index + %c1_i32 = arith.constant 1 : i32 + %c32_i32 = arith.constant 32 : i32 + %84 = arith.constant 900 : index + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #C> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #AL> + %50 = tt.splat %arg3 : (i32) -> tensor<1x32xi32, #AL> + %59 = tt.splat %arg0 : (!tt.ptr) -> tensor<32x32x!tt.ptr, #AL> + %81 = tt.splat %arg1 : (!tt.ptr) -> tensor<32x32x!tt.ptr, #AL> + %66 = tt.splat %arg4 : (i32) -> tensor<32x1xi32, #AL> + %60 = tt.splat %arg2 : (!tt.ptr) -> tensor<32x32x!tt.ptr, #AL> + %82 = tt.splat %arg2 : (!tt.ptr) -> tensor<32x32x!tt.ptr, #AL> + %85:3 = scf.for %arg9 = %c0_index to %84 step %c1_index iter_args(%arg10 = %cst, %arg11 = %59, %arg12 = %81) -> (tensor<32x32xf32, #C>, tensor<32x32x!tt.ptr, #AL>, tensor<32x32x!tt.ptr, #AL>) { + %130 = arith.index_cast %arg9 : index to i32 + %107 = arith.muli %130, %c32_i32 : i32 + %108 = arith.subi %arg5, %107 : i32 + %109 = tt.splat %108 : (i32) -> tensor<1x32xi32, #AL> + %110 = "triton_gpu.cmpi"(%50, %109) <{predicate = 2 : i64}> : (tensor<1x32xi32, #AL>, tensor<1x32xi32, #AL>) -> tensor<1x32xi1, #AL> + %111 = tt.broadcast %110 : (tensor<1x32xi1, #AL>) -> tensor<32x32xi1, #AL> + %112 = tt.load %arg11, %111, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #AL> + %113 = tt.splat %108 : (i32) -> tensor<32x1xi32, #AL> + %114 = "triton_gpu.cmpi"(%66, %113) <{predicate = 2 : i64}> : (tensor<32x1xi32, #AL>, tensor<32x1xi32, #AL>) -> tensor<32x1xi1, #AL> + %115 = tt.broadcast %114 : (tensor<32x1xi1, #AL>) -> tensor<32x32xi1, #AL> + %116 = tt.load %arg12, %115, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #AL> + %117 = triton_gpu.convert_layout %112 : (tensor<32x32xf32, #AL>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> + %118 = triton_gpu.convert_layout %116 : (tensor<32x32xf32, #AL>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> + %119 = tt.dot %117, %118, %arg10 {allowTF32 = true} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C> + %131 = arith.index_cast %arg9 : index to i32 + %120 = arith.addi %131, %c1_i32 : i32 + %121 = arith.muli %120, %c32_i32 : i32 + %122 = tt.splat %121 : (i32) -> tensor<32x32xi32, #AL> + %123 = tt.addptr %60, %122 : tensor<32x32x!tt.ptr, #AL>, tensor<32x32xi32, #AL> + %124 = arith.muli %121, %arg7 : i32 + %125 = tt.splat %124 : (i32) -> tensor<32x32xi32, #AL> + %126 = tt.addptr %82, %125 : tensor<32x32x!tt.ptr, #AL>, tensor<32x32xi32, #AL> + scf.yield %119, %123, %126 : tensor<32x32xf32, #C>, tensor<32x32x!tt.ptr, #AL>, tensor<32x32x!tt.ptr, #AL> + } + tt.return %85#0 : tensor<32x32xf32, #C> +} + +// No stream pipeline +// CHECK: tt.func @cross_iter_dep +// CHECK: scf.for +// CHECK: scf.yield +// CHECK-NEXT: } +// CHECK-NEXT: tt.return +tt.func @cross_iter_dep(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: !tt.ptr {tt.divisibility = 16 : i32}, + %arg3: i32 {tt.divisibility = 16 : i32}, + %arg4: i32 {tt.divisibility = 16 : i32}, + %arg5: i32 {tt.divisibility = 16 : i32}, + %arg6: i32 {tt.divisibility = 16 : i32}, + %arg7: i32 {tt.divisibility = 16 : i32}, + %arg8: i32 {tt.divisibility = 16 : i32}) -> tensor<32x32xf32, #C> { + %c0_i32 = arith.constant 0 : index + %118 = arith.constant 32 : index + %c1_i32 = arith.constant 1 : index + %c2_i32 = arith.constant 2 : i32 + %c32_i32 = arith.constant 32 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #C> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #AL> + %78 = tt.splat %arg0 : (!tt.ptr) -> tensor<32x32x!tt.ptr, #AL> + %110 = tt.splat %arg0 : (!tt.ptr) -> tensor<32x32x!tt.ptr, #AL> + %112 = tt.splat %arg1 : (!tt.ptr) -> tensor<32x32x!tt.ptr, #AL> + %113 = tt.splat %arg1 : (!tt.ptr) -> tensor<32x32x!tt.ptr, #AL> + %116 = tt.splat %arg2 : (!tt.ptr) -> tensor<32x32x!tt.ptr, #AL> + %65 = tt.splat %arg3 : (i32) -> tensor<1x32xi32, #AL> + %88 = tt.splat %arg4 : (i32) -> tensor<32x1xi32, #AL> + %80 = tt.splat %arg2 : (!tt.ptr) -> tensor<32x32x!tt.ptr, #AL> + %119:5 = scf.for %arg9 = %c0_i32 to %118 step %c1_i32 iter_args(%arg10 = %cst, %arg11 = %78, %arg12 = %110, %arg13 = %113, %arg14 = %116) -> (tensor<32x32xf32, #C>, tensor<32x32x!tt.ptr, #AL>, tensor<32x32x!tt.ptr, #AL>, tensor<32x32x!tt.ptr, #AL>, tensor<32x32x!tt.ptr, #AL>) { + %161 = arith.index_cast %arg9 : index to i32 + %141 = arith.muli %161, %c32_i32 : i32 + %142 = arith.subi %arg5, %141 : i32 + %143 = tt.splat %142 : (i32) -> tensor<1x32xi32, #AL> + %144 = "triton_gpu.cmpi"(%65, %143) <{predicate = 2 : i64}> : (tensor<1x32xi32, #AL>, tensor<1x32xi32, #AL>) -> tensor<1x32xi1, #AL> + %145 = tt.broadcast %144 : (tensor<1x32xi1, #AL>) -> tensor<32x32xi1, #AL> + %146 = tt.load %arg11, %145, %cst_1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #AL> + %147 = tt.splat %142 : (i32) -> tensor<32x1xi32, #AL> + %148 = "triton_gpu.cmpi"(%88, %147) <{predicate = 2 : i64}> : (tensor<32x1xi32, #AL>, tensor<32x1xi32, #AL>) -> tensor<32x1xi1, #AL> + %149 = tt.broadcast %148 : (tensor<32x1xi1, #AL>) -> tensor<32x32xi1, #AL> + %150 = tt.load %arg12, %149, %cst_1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #AL> + %151 = triton_gpu.convert_layout %146 : (tensor<32x32xf32, #AL>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> + %152 = triton_gpu.convert_layout %150 : (tensor<32x32xf32, #AL>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> + %153 = tt.dot %151, %152, %arg10 {allowTF32 = true} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C> + %162 = arith.index_cast %arg9 : index to i32 + %154 = arith.addi %162, %c2_i32 : i32 + %155 = arith.muli %154, %c32_i32 : i32 + %156 = tt.splat %155 : (i32) -> tensor<32x32xi32, #AL> + %157 = tt.addptr %80, %156 : tensor<32x32x!tt.ptr, #AL>, tensor<32x32xi32, #AL> + %158 = arith.muli %155, %arg7 : i32 + %159 = tt.splat %158 : (i32) -> tensor<32x32xi32, #AL> + %160 = tt.addptr %112, %159 : tensor<32x32x!tt.ptr, #AL>, tensor<32x32xi32, #AL> + scf.yield %153, %arg13, %arg14, %157, %160 : tensor<32x32xf32, #C>, tensor<32x32x!tt.ptr, #AL>, tensor<32x32x!tt.ptr, #AL>, tensor<32x32x!tt.ptr, #AL>, tensor<32x32x!tt.ptr, #AL> + } + tt.return %119#0 : tensor<32x32xf32, #C> +} + +// CHECK: tt.func @matmul_mixed_kernel +// Prologue +// CHECK: %[[A0_LOAD:.*]] = tt.load +// CHECK: %[[A0_SHARED:.*]] = triton_gpu.convert_layout %[[A0_LOAD]] +// Restructured for-loop +// CHECK: scf.for +// CHECK: %[[AN_LOAD:.*]] = tt.load +// CHECK: tt.load +// CHECK: triton_gpu.convert_layout +// CHECK: tt.dot +// CHECK: %[[AN_SHARED:.*]] = triton_gpu.convert_layout %[[AN_LOAD]] +// CHECK: scf.yield {{.*}}, %[[AN_SHARED]] +// Epilogue +// CHECK: tt.load +// CHECK: triton_gpu.convert_layout +// CHECK: tt.dot + +#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 8], warpsPerCTA = [2, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 2], warpsPerCTA = [2, 1], order = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [2, 1], order = [1, 0]}> +tt.func @matmul_mixed_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<64x32xf16, #blocked> + %cst_0 = arith.constant dense<32> : tensor<64x32xi32, #blocked1> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<64x32xf32, #blocked1> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<32x32xf16, #blocked2> + %c31_i32 = arith.constant 31 : i32 + %c63_i32 = arith.constant 63 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c32_i32 = arith.constant 32 : i32 + %c64_i32 = arith.constant 64 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg3, %c63_i32 : i32 + %2 = arith.divsi %1, %c64_i32 : i32 + %3 = arith.addi %arg4, %c31_i32 : i32 + %4 = arith.divsi %3, %c32_i32 : i32 + %5 = arith.divsi %0, %4 : i32 + %6 = arith.subi %2, %5 : i32 + %7 = "triton_gpu.cmpi"(%6, %c1_i32) <{predicate = 2 : i64}> : (i32, i32) -> i1 + %8 = arith.select %7, %6, %c1_i32 : i32 + %9 = arith.remsi %0, %8 : i32 + %10 = arith.addi %5, %9 : i32 + %11 = arith.remsi %0, %4 : i32 + %12 = arith.divsi %11, %8 : i32 + %13 = arith.muli %10, %c64_i32 : i32 + %14 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %16 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %17 = tt.splat %13 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %18 = tt.splat %13 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %19 = tt.splat %13 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %20 = arith.addi %17, %14 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %21 = arith.addi %18, %15 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %22 = arith.addi %19, %16 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %23 = tt.splat %arg3 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %24 = arith.remsi %20, %23 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %25 = arith.muli %12, %c32_i32 : i32 + %26 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %27 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %28 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %29 = tt.splat %25 : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %30 = tt.splat %25 : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %31 = tt.splat %25 : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %32 = arith.addi %29, %26 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %33 = arith.addi %30, %27 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %34 = arith.addi %31, %28 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %35 = tt.splat %arg4 : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %36 = arith.remsi %32, %35 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %37 = tt.expand_dims %24 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xi32, #blocked1> + %38 = tt.splat %arg6 : (i32) -> tensor<64x1xi32, #blocked1> + %39 = arith.muli %37, %38 : tensor<64x1xi32, #blocked1> + %40 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %41 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %42 = tt.expand_dims %40 {axis = 0 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x32xi32, #blocked1> + %43 = tt.expand_dims %41 {axis = 0 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x32xi32, #blocked1> + %44 = tt.broadcast %39 : (tensor<64x1xi32, #blocked1>) -> tensor<64x32xi32, #blocked1> + %45 = tt.broadcast %42 : (tensor<1x32xi32, #blocked1>) -> tensor<64x32xi32, #blocked1> + %46 = arith.addi %44, %45 : tensor<64x32xi32, #blocked1> + %47 = tt.splat %arg0 : (!tt.ptr) -> tensor<64x32x!tt.ptr, #blocked1> + %48 = tt.addptr %47, %46 : tensor<64x32x!tt.ptr, #blocked1>, tensor<64x32xi32, #blocked1> + %49 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %50 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %51 = tt.expand_dims %49 {axis = 1 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<32x1xi32, #blocked2> + %52 = tt.expand_dims %50 {axis = 1 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<32x1xi32, #blocked2> + %53 = tt.splat %arg7 : (i32) -> tensor<32x1xi32, #blocked2> + %54 = arith.muli %51, %53 : tensor<32x1xi32, #blocked2> + %55 = tt.expand_dims %36 {axis = 0 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x32xi32, #blocked2> + %56 = tt.broadcast %54 : (tensor<32x1xi32, #blocked2>) -> tensor<32x32xi32, #blocked2> + %57 = tt.broadcast %55 : (tensor<1x32xi32, #blocked2>) -> tensor<32x32xi32, #blocked2> + %58 = arith.addi %56, %57 : tensor<32x32xi32, #blocked2> + %59 = tt.splat %arg1 : (!tt.ptr) -> tensor<32x32x!tt.ptr, #blocked2> + %60 = tt.addptr %59, %58 : tensor<32x32x!tt.ptr, #blocked2>, tensor<32x32xi32, #blocked2> + %61 = arith.addi %arg5, %c31_i32 : i32 + %62 = arith.divsi %61, %c32_i32 : i32 + %63 = tt.fp_to_fp %cst_1 : tensor<64x32xf32, #blocked1> -> tensor<64x32xf8E4M3FNUZ, #blocked1> + %64 = arith.muli %arg7, %c32_i32 : i32 + %65 = tt.splat %64 : (i32) -> tensor<32x32xi32, #blocked2> + %66:3 = scf.for %arg9 = %c0_i32 to %62 step %c1_i32 iter_args(%arg10 = %cst, %arg11 = %48, %arg12 = %60) -> (tensor<64x32xf16, #blocked>, tensor<64x32x!tt.ptr, #blocked1>, tensor<32x32x!tt.ptr, #blocked2>) : i32 { + %86 = arith.muli %arg9, %c32_i32 : i32 + %87 = arith.subi %arg5, %86 : i32 + %88 = tt.splat %87 : (i32) -> tensor<1x32xi32, #blocked1> + %89 = "triton_gpu.cmpi"(%43, %88) <{predicate = 2 : i64}> : (tensor<1x32xi32, #blocked1>, tensor<1x32xi32, #blocked1>) -> tensor<1x32xi1, #blocked1> + %90 = tt.broadcast %89 : (tensor<1x32xi1, #blocked1>) -> tensor<64x32xi1, #blocked1> + %91 = tt.load %arg11, %90, %63 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf8E4M3FNUZ, #blocked1> + %92 = tt.splat %87 : (i32) -> tensor<32x1xi32, #blocked2> + %93 = "triton_gpu.cmpi"(%52, %92) <{predicate = 2 : i64}> : (tensor<32x1xi32, #blocked2>, tensor<32x1xi32, #blocked2>) -> tensor<32x1xi1, #blocked2> + %94 = tt.broadcast %93 : (tensor<32x1xi1, #blocked2>) -> tensor<32x32xi1, #blocked2> + %95 = tt.load %arg12, %94, %cst_2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf16, #blocked2> + %96 = tt.fp_to_fp %91 : tensor<64x32xf8E4M3FNUZ, #blocked1> -> tensor<64x32xf16, #blocked1> + %97 = triton_gpu.convert_layout %96 : (tensor<64x32xf16, #blocked1>) -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + %98 = triton_gpu.convert_layout %95 : (tensor<32x32xf16, #blocked2>) -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> + %99 = tt.dot %97, %98, %arg10 {allowTF32 = true} : tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x32xf16, #blocked> + %100 = tt.addptr %arg11, %cst_0 : tensor<64x32x!tt.ptr, #blocked1>, tensor<64x32xi32, #blocked1> + %101 = tt.addptr %arg12, %65 : tensor<32x32x!tt.ptr, #blocked2>, tensor<32x32xi32, #blocked2> + scf.yield %99, %100, %101 : tensor<64x32xf16, #blocked>, tensor<64x32x!tt.ptr, #blocked1>, tensor<32x32x!tt.ptr, #blocked2> + } + %67 = tt.expand_dims %21 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<64x1xi32, #blocked2> + %68 = tt.expand_dims %22 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<64x1xi32, #blocked2> + %69 = tt.splat %arg8 : (i32) -> tensor<64x1xi32, #blocked2> + %70 = arith.muli %69, %67 : tensor<64x1xi32, #blocked2> + %71 = tt.splat %arg2 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked2> + %72 = tt.addptr %71, %70 : tensor<64x1x!tt.ptr, #blocked2>, tensor<64x1xi32, #blocked2> + %73 = tt.expand_dims %33 {axis = 0 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x32xi32, #blocked2> + %74 = tt.expand_dims %34 {axis = 0 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x32xi32, #blocked2> + %75 = tt.broadcast %72 : (tensor<64x1x!tt.ptr, #blocked2>) -> tensor<64x32x!tt.ptr, #blocked2> + %76 = tt.broadcast %73 : (tensor<1x32xi32, #blocked2>) -> tensor<64x32xi32, #blocked2> + %77 = tt.addptr %75, %76 : tensor<64x32x!tt.ptr, #blocked2>, tensor<64x32xi32, #blocked2> + %78 = tt.splat %arg3 : (i32) -> tensor<64x1xi32, #blocked2> + %79 = "triton_gpu.cmpi"(%68, %78) <{predicate = 2 : i64}> : (tensor<64x1xi32, #blocked2>, tensor<64x1xi32, #blocked2>) -> tensor<64x1xi1, #blocked2> + %80 = tt.splat %arg4 : (i32) -> tensor<1x32xi32, #blocked2> + %81 = "triton_gpu.cmpi"(%74, %80) <{predicate = 2 : i64}> : (tensor<1x32xi32, #blocked2>, tensor<1x32xi32, #blocked2>) -> tensor<1x32xi1, #blocked2> + %82 = tt.broadcast %79 : (tensor<64x1xi1, #blocked2>) -> tensor<64x32xi1, #blocked2> + %83 = tt.broadcast %81 : (tensor<1x32xi1, #blocked2>) -> tensor<64x32xi1, #blocked2> + %84 = arith.andi %82, %83 : tensor<64x32xi1, #blocked2> + %85 = triton_gpu.convert_layout %66#0 : (tensor<64x32xf16, #blocked>) -> tensor<64x32xf16, #blocked2> + tt.store %77, %85, %84 {cache = 1 : i32, evict = 1 : i32} : tensor<64x32xf16, #blocked2> + tt.return +} diff --git a/test/TritonGPU/wsdecomposing.mlir b/test/TritonGPU/wsdecomposing.mlir new file mode 100644 index 000000000000..059554a59195 --- /dev/null +++ b/test/TritonGPU/wsdecomposing.mlir @@ -0,0 +1,754 @@ +// RUN: triton-opt -split-input-file -triton-nvidia-gpu-ws-decomposing='compute-capability=90' %s 2>&1 | FileCheck %s +// XFAIL: * +// TODO: change this test to not check for a fixed IR. + +// Check if all opereations are labeled with appropriate attributes. +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 128, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.enable-warp-specialization" = 1 : i32} { + // CHECK-LABEL: @simple_gemm + tt.func public @simple_gemm(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %cst_0 = arith.constant dense<32> : tensor<32x128xi32, #blocked> + %cst_1 = arith.constant dense<32> : tensor<128x32xi32, #blocked1> + %c31_i32 = arith.constant 31 : i32 + %c127_i32 = arith.constant 127 : i32 + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c32_i32 = arith.constant 32 : i32 + %c128_i32 = arith.constant 128 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = tt.get_program_id x : i32 + %1 = tt.get_program_id y : i32 + %2 = arith.addi %arg3, %c127_i32 : i32 + %3 = arith.divsi %2, %c128_i32 : i32 + %4 = arith.addi %arg4, %c127_i32 : i32 + %5 = arith.divsi %4, %c128_i32 : i32 + %6 = arith.muli %5, %c8_i32 : i32 + %7 = arith.divsi %0, %6 : i32 + %8 = arith.muli %7, %c8_i32 : i32 + %9 = arith.subi %3, %8 : i32 + %10 = arith.cmpi slt, %9, %c8_i32 : i32 + %11 = arith.select %10, %9, %c8_i32 : i32 + %12 = arith.remsi %0, %11 : i32 + %13 = arith.addi %8, %12 : i32 + %14 = arith.remsi %0, %6 : i32 + %15 = arith.divsi %14, %11 : i32 + %16 = arith.muli %13, %c128_i32 : i32 + %17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %19 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %20 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %21 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %22 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %23 = tt.splat %16 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %24 = tt.splat %16 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %25 = tt.splat %16 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %26 = arith.addi %23, %17 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %27 = arith.addi %24, %19 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %28 = arith.addi %25, %21 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %29 = arith.muli %15, %c128_i32 : i32 + %30 = tt.splat %29 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %31 = tt.splat %29 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %32 = tt.splat %29 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %33 = arith.addi %30, %18 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %34 = arith.addi %31, %20 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %35 = arith.addi %32, %22 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %36 = tt.splat %arg3 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %37 = tt.splat %arg3 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %38 = arith.remsi %26, %36 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %39 = tt.splat %arg4 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %40 = tt.splat %arg4 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %41 = arith.remsi %33, %39 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %42 = arith.muli %1, %c32_i32 : i32 + %43 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %44 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %45 = tt.splat %42 : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %46 = tt.splat %42 : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %47 = arith.addi %45, %43 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %48 = arith.addi %46, %44 : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %49 = tt.expand_dims %38 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xi32, #blocked1> + %50 = tt.splat %arg6 : (i32) -> tensor<128x1xi32, #blocked1> + %51 = arith.muli %49, %50 : tensor<128x1xi32, #blocked1> + %52 = tt.expand_dims %47 {axis = 0 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x32xi32, #blocked1> + %53 = tt.broadcast %51 : (tensor<128x1xi32, #blocked1>) -> tensor<128x32xi32, #blocked1> + %54 = tt.broadcast %52 : (tensor<1x32xi32, #blocked1>) -> tensor<128x32xi32, #blocked1> + %55 = arith.addi %53, %54 : tensor<128x32xi32, #blocked1> + %56 = tt.splat %arg0 : (!tt.ptr) -> tensor<128x32x!tt.ptr, #blocked1> + %57 = tt.addptr %56, %55 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %58 = tt.expand_dims %48 {axis = 1 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32x1xi32, #blocked> + %59 = tt.expand_dims %41 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x128xi32, #blocked> + %60 = tt.splat %arg7 : (i32) -> tensor<1x128xi32, #blocked> + %61 = arith.muli %59, %60 : tensor<1x128xi32, #blocked> + %62 = tt.broadcast %58 : (tensor<32x1xi32, #blocked>) -> tensor<32x128xi32, #blocked> + %63 = tt.broadcast %61 : (tensor<1x128xi32, #blocked>) -> tensor<32x128xi32, #blocked> + %64 = arith.addi %62, %63 : tensor<32x128xi32, #blocked> + %65 = tt.splat %arg1 : (!tt.ptr) -> tensor<32x128x!tt.ptr, #blocked> + %66 = tt.addptr %65, %64 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %67 = arith.addi %arg5, %c31_i32 : i32 + %68 = arith.divsi %67, %c32_i32 : i32 + %69 = arith.index_cast %68 : i32 to index + %70:3 = scf.for %arg9 = %c0 to %69 step %c1 iter_args(%arg10 = %cst, %arg11 = %57, %arg12 = %66) -> (tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>) { + %89 = tt.load %arg11 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #blocked1> + %90 = tt.load %arg12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked> + %91 = triton_gpu.convert_layout %89 : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #shared> + %92 = triton_gpu.convert_layout %90 : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1> + %93 = tt.dot %91, %92, %arg10 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma> + %94 = tt.addptr %arg11, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %95 = tt.addptr %arg12, %cst_0 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + scf.yield %93, %94, %95 : tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked> + } + %71 = arith.truncf %70#0 : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma> + %72 = tt.expand_dims %27 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi32, #blocked2> + %73 = tt.splat %arg8 : (i32) -> tensor<128x1xi32, #blocked2> + %74 = arith.muli %72, %73 : tensor<128x1xi32, #blocked2> + %75 = tt.expand_dims %34 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi32, #blocked2> + %76 = tt.broadcast %74 : (tensor<128x1xi32, #blocked2>) -> tensor<128x128xi32, #blocked2> + %77 = tt.broadcast %75 : (tensor<1x128xi32, #blocked2>) -> tensor<128x128xi32, #blocked2> + %78 = arith.addi %76, %77 : tensor<128x128xi32, #blocked2> + %79 = tt.splat %arg2 : (!tt.ptr) -> tensor<128x128x!tt.ptr, #blocked2> + %80 = tt.addptr %79, %78 : tensor<128x128x!tt.ptr, #blocked2>, tensor<128x128xi32, #blocked2> + %81 = "triton_gpu.cmpi"(%28, %37) {predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %82 = tt.expand_dims %81 {axis = 1 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi1, #blocked2> + %83 = "triton_gpu.cmpi"(%35, %40) {predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %84 = tt.expand_dims %83 {axis = 0 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi1, #blocked2> + %85 = tt.broadcast %82 : (tensor<128x1xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> + %86 = tt.broadcast %84 : (tensor<1x128xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> + %87 = arith.andi %85, %86 : tensor<128x128xi1, #blocked2> + %88 = triton_gpu.convert_layout %71 : (tensor<128x128xf16, #mma>) -> tensor<128x128xf16, #blocked2> + tt.store %80, %88, %87 {cache = 1 : i32, evict = 1 : i32} : tensor<128x128xf16, #blocked2> + tt.return + + // CHECK-NEXT: %cst = arith.constant {async_agent = dense<1> : vector<1xi32>} dense<0.000000e+00> : tensor<128x128xf32, #mma> + // CHECK-NEXT: %cst_0 = arith.constant {async_agent = dense<0> : vector<1xi32>} dense<32> : tensor<32x128xi32, #blocked> + // CHECK-NEXT: %cst_1 = arith.constant {async_agent = dense<0> : vector<1xi32>} dense<32> : tensor<128x32xi32, #blocked1> + // CHECK-NEXT: %c31_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 31 : i32 + // CHECK-NEXT: %c127_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 127 : i32 + // CHECK-NEXT: %c1 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 1 : index + // CHECK-NEXT: %c0 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : index + // CHECK-NEXT: %c32_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 32 : i32 + // CHECK-NEXT: %c128_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 128 : i32 + // CHECK-NEXT: %c8_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 8 : i32 + // CHECK-NEXT: %0 = tt.get_program_id {async_agent = dense<[0, 1]> : vector<2xi32>, axis = 0 : i32} : i32 + // CHECK-NEXT: %1 = tt.get_program_id {async_agent = dense<0> : vector<1xi32>, axis = 1 : i32} : i32 + // CHECK-NEXT: %2 = arith.addi %arg3, %c127_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %3 = arith.divsi %2, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %4 = arith.addi %arg4, %c127_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %5 = arith.divsi %4, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %6 = arith.muli %5, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %7 = arith.divsi %0, %6 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %8 = arith.muli %7, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %9 = arith.subi %3, %8 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %10 = arith.cmpi slt, %9, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %11 = arith.select %10, %9, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %12 = arith.remsi %0, %11 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %13 = arith.addi %8, %12 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %14 = arith.remsi %0, %6 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %15 = arith.divsi %14, %11 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %16 = arith.muli %13, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %17 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + // CHECK-NEXT: %18 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + // CHECK-NEXT: %19 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + // CHECK-NEXT: %20 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + // CHECK-NEXT: %21 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + // CHECK-NEXT: %22 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + // CHECK-NEXT: %23 = tt.splat %16 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + // CHECK-NEXT: %24 = tt.splat %16 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + // CHECK-NEXT: %25 = tt.splat %16 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + // CHECK-NEXT: %26 = arith.addi %23, %17 {async_agent = dense<0> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + // CHECK-NEXT: %27 = arith.addi %24, %19 {async_agent = dense<1> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + // CHECK-NEXT: %28 = arith.addi %25, %21 {async_agent = dense<1> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + // CHECK-NEXT: %29 = arith.muli %15, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %30 = tt.splat %29 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + // CHECK-NEXT: %31 = tt.splat %29 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + // CHECK-NEXT: %32 = tt.splat %29 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + // CHECK-NEXT: %33 = arith.addi %30, %18 {async_agent = dense<0> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + // CHECK-NEXT: %34 = arith.addi %31, %20 {async_agent = dense<1> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + // CHECK-NEXT: %35 = arith.addi %32, %22 {async_agent = dense<1> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + // CHECK-NEXT: %36 = tt.splat %arg3 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + // CHECK-NEXT: %37 = tt.splat %arg3 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + // CHECK-NEXT: %38 = arith.remsi %26, %36 {async_agent = dense<0> : vector<1xi32>, tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + // CHECK-NEXT: %39 = tt.splat %arg4 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + // CHECK-NEXT: %40 = tt.splat %arg4 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + // CHECK-NEXT: %41 = arith.remsi %33, %39 {async_agent = dense<0> : vector<1xi32>, tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + // CHECK-NEXT: %42 = arith.muli %1, %c32_i32 {async_agent = dense<0> : vector<1xi32>} : i32 + // CHECK-NEXT: %43 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + // CHECK-NEXT: %44 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + // CHECK-NEXT: %45 = tt.splat %42 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + // CHECK-NEXT: %46 = tt.splat %42 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + // CHECK-NEXT: %47 = arith.addi %45, %43 {async_agent = dense<0> : vector<1xi32>} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + // CHECK-NEXT: %48 = arith.addi %46, %44 {async_agent = dense<0> : vector<1xi32>} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + // CHECK-NEXT: %49 = tt.expand_dims %38 {async_agent = dense<0> : vector<1xi32>, axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xi32, #blocked1> + // CHECK-NEXT: %50 = tt.splat %arg6 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<128x1xi32, #blocked1> + // CHECK-NEXT: %51 = arith.muli %49, %50 {async_agent = dense<0> : vector<1xi32>} : tensor<128x1xi32, #blocked1> + // CHECK-NEXT: %52 = tt.expand_dims %47 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x32xi32, #blocked1> + // CHECK-NEXT: %53 = tt.broadcast %51 {async_agent = dense<0> : vector<1xi32>} : (tensor<128x1xi32, #blocked1>) -> tensor<128x32xi32, #blocked1> + // CHECK-NEXT: %54 = tt.broadcast %52 {async_agent = dense<0> : vector<1xi32>} : (tensor<1x32xi32, #blocked1>) -> tensor<128x32xi32, #blocked1> + // CHECK-NEXT: %55 = arith.addi %53, %54 {async_agent = dense<0> : vector<1xi32>} : tensor<128x32xi32, #blocked1> + // CHECK-NEXT: %56 = tt.splat %arg0 {async_agent = dense<0> : vector<1xi32>} : (!tt.ptr) -> tensor<128x32x!tt.ptr, #blocked1> + // CHECK-NEXT: %57 = tt.addptr %56, %55 {async_agent = dense<0> : vector<1xi32>} : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + // CHECK-NEXT: %58 = tt.expand_dims %48 {async_agent = dense<0> : vector<1xi32>, axis = 1 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32x1xi32, #blocked> + // CHECK-NEXT: %59 = tt.expand_dims %41 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x128xi32, #blocked> + // CHECK-NEXT: %60 = tt.splat %arg7 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<1x128xi32, #blocked> + // CHECK-NEXT: %61 = arith.muli %59, %60 {async_agent = dense<0> : vector<1xi32>} : tensor<1x128xi32, #blocked> + // CHECK-NEXT: %62 = tt.broadcast %58 {async_agent = dense<0> : vector<1xi32>} : (tensor<32x1xi32, #blocked>) -> tensor<32x128xi32, #blocked> + // CHECK-NEXT: %63 = tt.broadcast %61 {async_agent = dense<0> : vector<1xi32>} : (tensor<1x128xi32, #blocked>) -> tensor<32x128xi32, #blocked> + // CHECK-NEXT: %64 = arith.addi %62, %63 {async_agent = dense<0> : vector<1xi32>} : tensor<32x128xi32, #blocked> + // CHECK-NEXT: %65 = tt.splat %arg1 {async_agent = dense<0> : vector<1xi32>} : (!tt.ptr) -> tensor<32x128x!tt.ptr, #blocked> + // CHECK-NEXT: %66 = tt.addptr %65, %64 {async_agent = dense<0> : vector<1xi32>} : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + // CHECK-NEXT: %67 = arith.addi %arg5, %c31_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %68 = arith.divsi %67, %c32_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %69 = arith.index_cast %68 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 to index + // CHECK-NEXT: %70:3 = scf.for %arg9 = %c0 to %69 step %c1 iter_args(%arg10 = %cst, %arg11 = %57, %arg12 = %66) -> (tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>) { + // CHECK-NEXT: %89 = tt.load %arg11 {async_agent = dense<0> : vector<1xi32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #blocked1> + // CHECK-NEXT: %90 = tt.load %arg12 {async_agent = dense<0> : vector<1xi32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked> + // CHECK-NEXT: %91 = triton_gpu.convert_layout %89 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #shared> + // CHECK-NEXT: %92 = triton_gpu.convert_layout %90 {async_agent = dense<1> : vector<1xi32>} : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1> + // CHECK-NEXT: %93 = tt.dot %91, %92, %arg10 {allowTF32 = true, async_agent = dense<1> : vector<1xi32>} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma> + // CHECK-NEXT: %94 = tt.addptr %arg11, %cst_1 {async_agent = dense<0> : vector<1xi32>} : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + // CHECK-NEXT: %95 = tt.addptr %arg12, %cst_0 {async_agent = dense<0> : vector<1xi32>} : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + // CHECK-NEXT: scf.yield {async_agent = dense<[0, 1]> : vector<2xi32>} %93, %94, %95 : tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked> + // CHECK-NEXT: } {async_agent = dense<[0, 1]> : vector<2xi32>} + // CHECK-NEXT: %71 = arith.truncf %70#0 {async_agent = dense<1> : vector<1xi32>} : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma> + // CHECK-NEXT: %72 = tt.expand_dims %27 {async_agent = dense<1> : vector<1xi32>, axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi32, #blocked2> + // CHECK-NEXT: %73 = tt.splat %arg8 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128x1xi32, #blocked2> + // CHECK-NEXT: %74 = arith.muli %72, %73 {async_agent = dense<1> : vector<1xi32>} : tensor<128x1xi32, #blocked2> + // CHECK-NEXT: %75 = tt.expand_dims %34 {async_agent = dense<1> : vector<1xi32>, axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi32, #blocked2> + // CHECK-NEXT: %76 = tt.broadcast %74 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x1xi32, #blocked2>) -> tensor<128x128xi32, #blocked2> + // CHECK-NEXT: %77 = tt.broadcast %75 {async_agent = dense<1> : vector<1xi32>} : (tensor<1x128xi32, #blocked2>) -> tensor<128x128xi32, #blocked2> + // CHECK-NEXT: %78 = arith.addi %76, %77 {async_agent = dense<1> : vector<1xi32>} : tensor<128x128xi32, #blocked2> + // CHECK-NEXT: %79 = tt.splat %arg2 {async_agent = dense<1> : vector<1xi32>} : (!tt.ptr) -> tensor<128x128x!tt.ptr, #blocked2> + // CHECK-NEXT: %80 = tt.addptr %79, %78 {async_agent = dense<1> : vector<1xi32>} : tensor<128x128x!tt.ptr, #blocked2>, tensor<128x128xi32, #blocked2> + // CHECK-NEXT: %81 = "triton_gpu.cmpi"(%28, %37) {async_agent = dense<1> : vector<1xi32>, predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + // CHECK-NEXT: %82 = tt.expand_dims %81 {async_agent = dense<1> : vector<1xi32>, axis = 1 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi1, #blocked2> + // CHECK-NEXT: %83 = "triton_gpu.cmpi"(%35, %40) {async_agent = dense<1> : vector<1xi32>, predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + // CHECK-NEXT: %84 = tt.expand_dims %83 {async_agent = dense<1> : vector<1xi32>, axis = 0 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi1, #blocked2> + // CHECK-NEXT: %85 = tt.broadcast %82 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x1xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> + // CHECK-NEXT: %86 = tt.broadcast %84 {async_agent = dense<1> : vector<1xi32>} : (tensor<1x128xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> + // CHECK-NEXT: %87 = arith.andi %85, %86 {async_agent = dense<1> : vector<1xi32>} : tensor<128x128xi1, #blocked2> + // CHECK-NEXT: %88 = triton_gpu.convert_layout %71 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x128xf16, #mma>) -> tensor<128x128xf16, #blocked2> + // CHECK-NEXT: tt.store %80, %88, %87 {async_agent = dense<1> : vector<1xi32>, cache = 1 : i32, evict = 1 : i32} : tensor<128x128xf16, #blocked2> + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 128, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.enable-warp-specialization" = 1 : i32} { + // CHECK-LABEL: @nested_for_gemm + tt.func public @nested_for_gemm(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %cst_0 = arith.constant dense<32> : tensor<32x128xi32, #blocked> + %cst_1 = arith.constant dense<32> : tensor<128x32xi32, #blocked1> + %c31_i32 = arith.constant 31 : i32 + %c127_i32 = arith.constant 127 : i32 + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c32_i32 = arith.constant 32 : i32 + %c128_i32 = arith.constant 128 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = tt.get_program_id x : i32 + %1 = tt.get_program_id y : i32 + %2 = arith.addi %arg3, %c127_i32 : i32 + %3 = arith.divsi %2, %c128_i32 : i32 + %4 = arith.addi %arg4, %c127_i32 : i32 + %5 = arith.divsi %4, %c128_i32 : i32 + %6 = arith.muli %5, %c8_i32 : i32 + %7 = arith.divsi %0, %6 : i32 + %8 = arith.muli %7, %c8_i32 : i32 + %9 = arith.subi %3, %8 : i32 + %10 = arith.cmpi slt, %9, %c8_i32 : i32 + %11 = arith.select %10, %9, %c8_i32 : i32 + %12 = arith.remsi %0, %11 : i32 + %13 = arith.addi %8, %12 : i32 + %14 = arith.remsi %0, %6 : i32 + %15 = arith.divsi %14, %11 : i32 + %16 = arith.muli %13, %c128_i32 : i32 + %17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %19 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %20 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %21 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %22 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %23 = tt.splat %16 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %24 = tt.splat %16 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %25 = tt.splat %16 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %26 = arith.addi %23, %17 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %27 = arith.addi %24, %19 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %28 = arith.addi %25, %21 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %29 = arith.muli %15, %c128_i32 : i32 + %30 = tt.splat %29 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %31 = tt.splat %29 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %32 = tt.splat %29 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %33 = arith.addi %30, %18 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %34 = arith.addi %31, %20 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %35 = arith.addi %32, %22 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %36 = tt.splat %arg3 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %37 = tt.splat %arg3 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %38 = arith.remsi %26, %36 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %39 = tt.splat %arg4 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %40 = tt.splat %arg4 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %41 = arith.remsi %33, %39 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %42 = arith.muli %1, %c32_i32 : i32 + %43 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %44 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %45 = tt.splat %42 : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %46 = tt.splat %42 : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %47 = arith.addi %45, %43 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %48 = arith.addi %46, %44 : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %49 = tt.expand_dims %38 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xi32, #blocked1> + %50 = tt.splat %arg6 : (i32) -> tensor<128x1xi32, #blocked1> + %51 = arith.muli %49, %50 : tensor<128x1xi32, #blocked1> + %52 = tt.expand_dims %47 {axis = 0 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x32xi32, #blocked1> + %53 = tt.broadcast %51 : (tensor<128x1xi32, #blocked1>) -> tensor<128x32xi32, #blocked1> + %54 = tt.broadcast %52 : (tensor<1x32xi32, #blocked1>) -> tensor<128x32xi32, #blocked1> + %55 = arith.addi %53, %54 : tensor<128x32xi32, #blocked1> + %56 = tt.splat %arg0 : (!tt.ptr) -> tensor<128x32x!tt.ptr, #blocked1> + %57 = tt.addptr %56, %55 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %58 = tt.expand_dims %48 {axis = 1 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32x1xi32, #blocked> + %59 = tt.expand_dims %41 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x128xi32, #blocked> + %60 = tt.splat %arg7 : (i32) -> tensor<1x128xi32, #blocked> + %61 = arith.muli %59, %60 : tensor<1x128xi32, #blocked> + %62 = tt.broadcast %58 : (tensor<32x1xi32, #blocked>) -> tensor<32x128xi32, #blocked> + %63 = tt.broadcast %61 : (tensor<1x128xi32, #blocked>) -> tensor<32x128xi32, #blocked> + %64 = arith.addi %62, %63 : tensor<32x128xi32, #blocked> + %65 = tt.splat %arg1 : (!tt.ptr) -> tensor<32x128x!tt.ptr, #blocked> + %66 = tt.addptr %65, %64 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %67 = arith.addi %arg5, %c31_i32 : i32 + %68 = arith.divsi %67, %c32_i32 : i32 + %69 = arith.index_cast %68 : i32 to index + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #shared> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<32x128xf16, #shared1> + %70:3 = scf.for %arg9 = %c0 to %69 step %c1 iter_args(%arg10 = %cst, %arg11 = %57, %arg12 = %66) -> (tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>) { + %96:2 = scf.for %arg13 = %c0 to %69 step %c1 iter_args(%arg14 = %cst_2, %arg15 = %cst_3) -> (tensor<128x32xf16, #shared>, tensor<32x128xf16, #shared1>) { + %89 = tt.load %arg11 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #blocked1> + %90 = tt.load %arg12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked> + %91 = triton_gpu.convert_layout %89 : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #shared> + %92 = triton_gpu.convert_layout %90 : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1> + scf.yield %91, %92 : tensor<128x32xf16, #shared>, tensor<32x128xf16, #shared1> + } + %93 = tt.dot %96#0, %96#1, %arg10 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma> + %94 = tt.addptr %arg11, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %95 = tt.addptr %arg12, %cst_0 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + scf.yield %93, %94, %95 : tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked> + } + %71 = arith.truncf %70#0 : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma> + %72 = tt.expand_dims %27 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi32, #blocked2> + %73 = tt.splat %arg8 : (i32) -> tensor<128x1xi32, #blocked2> + %74 = arith.muli %72, %73 : tensor<128x1xi32, #blocked2> + %75 = tt.expand_dims %34 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi32, #blocked2> + %76 = tt.broadcast %74 : (tensor<128x1xi32, #blocked2>) -> tensor<128x128xi32, #blocked2> + %77 = tt.broadcast %75 : (tensor<1x128xi32, #blocked2>) -> tensor<128x128xi32, #blocked2> + %78 = arith.addi %76, %77 : tensor<128x128xi32, #blocked2> + %79 = tt.splat %arg2 : (!tt.ptr) -> tensor<128x128x!tt.ptr, #blocked2> + %80 = tt.addptr %79, %78 : tensor<128x128x!tt.ptr, #blocked2>, tensor<128x128xi32, #blocked2> + %81 = "triton_gpu.cmpi"(%28, %37) {predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %82 = tt.expand_dims %81 {axis = 1 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi1, #blocked2> + %83 = "triton_gpu.cmpi"(%35, %40) {predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %84 = tt.expand_dims %83 {axis = 0 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi1, #blocked2> + %85 = tt.broadcast %82 : (tensor<128x1xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> + %86 = tt.broadcast %84 : (tensor<1x128xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> + %87 = arith.andi %85, %86 : tensor<128x128xi1, #blocked2> + %88 = triton_gpu.convert_layout %71 : (tensor<128x128xf16, #mma>) -> tensor<128x128xf16, #blocked2> + tt.store %80, %88, %87 {cache = 1 : i32, evict = 1 : i32} : tensor<128x128xf16, #blocked2> + tt.return + + // CHECK-NEXT: %cst = arith.constant {async_agent = dense<1> : vector<1xi32>} dense<0.000000e+00> : tensor<128x128xf32, #mma> + // CHECK-NEXT: %cst_0 = arith.constant {async_agent = dense<0> : vector<1xi32>} dense<32> : tensor<32x128xi32, #blocked> + // CHECK-NEXT: %cst_1 = arith.constant {async_agent = dense<0> : vector<1xi32>} dense<32> : tensor<128x32xi32, #blocked1> + // CHECK-NEXT: %c31_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 31 : i32 + // CHECK-NEXT: %c127_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 127 : i32 + // CHECK-NEXT: %c1 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 1 : index + // CHECK-NEXT: %c0 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : index + // CHECK-NEXT: %c32_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 32 : i32 + // CHECK-NEXT: %c128_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 128 : i32 + // CHECK-NEXT: %c8_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 8 : i32 + // CHECK-NEXT: %0 = tt.get_program_id {async_agent = dense<[0, 1]> : vector<2xi32>, axis = 0 : i32} : i32 + // CHECK-NEXT: %1 = tt.get_program_id {async_agent = dense<0> : vector<1xi32>, axis = 1 : i32} : i32 + // CHECK-NEXT: %2 = arith.addi %arg3, %c127_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %3 = arith.divsi %2, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %4 = arith.addi %arg4, %c127_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %5 = arith.divsi %4, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %6 = arith.muli %5, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %7 = arith.divsi %0, %6 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %8 = arith.muli %7, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %9 = arith.subi %3, %8 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %10 = arith.cmpi slt, %9, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %11 = arith.select %10, %9, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %12 = arith.remsi %0, %11 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %13 = arith.addi %8, %12 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %14 = arith.remsi %0, %6 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %15 = arith.divsi %14, %11 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %16 = arith.muli %13, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %17 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + // CHECK-NEXT: %18 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + // CHECK-NEXT: %19 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + // CHECK-NEXT: %20 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + // CHECK-NEXT: %21 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + // CHECK-NEXT: %22 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + // CHECK-NEXT: %23 = tt.splat %16 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + // CHECK-NEXT: %24 = tt.splat %16 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + // CHECK-NEXT: %25 = tt.splat %16 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + // CHECK-NEXT: %26 = arith.addi %23, %17 {async_agent = dense<0> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + // CHECK-NEXT: %27 = arith.addi %24, %19 {async_agent = dense<1> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + // CHECK-NEXT: %28 = arith.addi %25, %21 {async_agent = dense<1> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + // CHECK-NEXT: %29 = arith.muli %15, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %30 = tt.splat %29 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + // CHECK-NEXT: %31 = tt.splat %29 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + // CHECK-NEXT: %32 = tt.splat %29 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + // CHECK-NEXT: %33 = arith.addi %30, %18 {async_agent = dense<0> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + // CHECK-NEXT: %34 = arith.addi %31, %20 {async_agent = dense<1> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + // CHECK-NEXT: %35 = arith.addi %32, %22 {async_agent = dense<1> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + // CHECK-NEXT: %36 = tt.splat %arg3 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + // CHECK-NEXT: %37 = tt.splat %arg3 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + // CHECK-NEXT: %38 = arith.remsi %26, %36 {async_agent = dense<0> : vector<1xi32>, tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + // CHECK-NEXT: %39 = tt.splat %arg4 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + // CHECK-NEXT: %40 = tt.splat %arg4 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + // CHECK-NEXT: %41 = arith.remsi %33, %39 {async_agent = dense<0> : vector<1xi32>, tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + // CHECK-NEXT: %42 = arith.muli %1, %c32_i32 {async_agent = dense<0> : vector<1xi32>} : i32 + // CHECK-NEXT: %43 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + // CHECK-NEXT: %44 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + // CHECK-NEXT: %45 = tt.splat %42 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + // CHECK-NEXT: %46 = tt.splat %42 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + // CHECK-NEXT: %47 = arith.addi %45, %43 {async_agent = dense<0> : vector<1xi32>} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + // CHECK-NEXT: %48 = arith.addi %46, %44 {async_agent = dense<0> : vector<1xi32>} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + // CHECK-NEXT: %49 = tt.expand_dims %38 {async_agent = dense<0> : vector<1xi32>, axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xi32, #blocked1> + // CHECK-NEXT: %50 = tt.splat %arg6 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<128x1xi32, #blocked1> + // CHECK-NEXT: %51 = arith.muli %49, %50 {async_agent = dense<0> : vector<1xi32>} : tensor<128x1xi32, #blocked1> + // CHECK-NEXT: %52 = tt.expand_dims %47 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x32xi32, #blocked1> + // CHECK-NEXT: %53 = tt.broadcast %51 {async_agent = dense<0> : vector<1xi32>} : (tensor<128x1xi32, #blocked1>) -> tensor<128x32xi32, #blocked1> + // CHECK-NEXT: %54 = tt.broadcast %52 {async_agent = dense<0> : vector<1xi32>} : (tensor<1x32xi32, #blocked1>) -> tensor<128x32xi32, #blocked1> + // CHECK-NEXT: %55 = arith.addi %53, %54 {async_agent = dense<0> : vector<1xi32>} : tensor<128x32xi32, #blocked1> + // CHECK-NEXT: %56 = tt.splat %arg0 {async_agent = dense<0> : vector<1xi32>} : (!tt.ptr) -> tensor<128x32x!tt.ptr, #blocked1> + // CHECK-NEXT: %57 = tt.addptr %56, %55 {async_agent = dense<0> : vector<1xi32>} : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + // CHECK-NEXT: %58 = tt.expand_dims %48 {async_agent = dense<0> : vector<1xi32>, axis = 1 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32x1xi32, #blocked> + // CHECK-NEXT: %59 = tt.expand_dims %41 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x128xi32, #blocked> + // CHECK-NEXT: %60 = tt.splat %arg7 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<1x128xi32, #blocked> + // CHECK-NEXT: %61 = arith.muli %59, %60 {async_agent = dense<0> : vector<1xi32>} : tensor<1x128xi32, #blocked> + // CHECK-NEXT: %62 = tt.broadcast %58 {async_agent = dense<0> : vector<1xi32>} : (tensor<32x1xi32, #blocked>) -> tensor<32x128xi32, #blocked> + // CHECK-NEXT: %63 = tt.broadcast %61 {async_agent = dense<0> : vector<1xi32>} : (tensor<1x128xi32, #blocked>) -> tensor<32x128xi32, #blocked> + // CHECK-NEXT: %64 = arith.addi %62, %63 {async_agent = dense<0> : vector<1xi32>} : tensor<32x128xi32, #blocked> + // CHECK-NEXT: %65 = tt.splat %arg1 {async_agent = dense<0> : vector<1xi32>} : (!tt.ptr) -> tensor<32x128x!tt.ptr, #blocked> + // CHECK-NEXT: %66 = tt.addptr %65, %64 {async_agent = dense<0> : vector<1xi32>} : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + // CHECK-NEXT: %67 = arith.addi %arg5, %c31_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %68 = arith.divsi %67, %c32_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %69 = arith.index_cast %68 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 to index + // CHECK-NEXT: %cst_2 = arith.constant {async_agent = dense<1> : vector<1xi32>} dense<0.000000e+00> : tensor<128x32xf16, #shared> + // CHECK-NEXT: %cst_3 = arith.constant {async_agent = dense<1> : vector<1xi32>} dense<0.000000e+00> : tensor<32x128xf16, #shared1> + // CHECK-NEXT: %70:3 = scf.for %arg9 = %c0 to %69 step %c1 iter_args(%arg10 = %cst, %arg11 = %57, %arg12 = %66) -> (tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>) { + // CHECK-NEXT: %89:2 = scf.for %arg13 = %c0 to %69 step %c1 iter_args(%arg14 = %cst_2, %arg15 = %cst_3) -> (tensor<128x32xf16, #shared>, tensor<32x128xf16, #shared1>) { + // CHECK-NEXT: %93 = tt.load %arg11 {async_agent = dense<0> : vector<1xi32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #blocked1> + // CHECK-NEXT: %94 = tt.load %arg12 {async_agent = dense<0> : vector<1xi32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked> + // CHECK-NEXT: %95 = triton_gpu.convert_layout %93 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #shared> + // CHECK-NEXT: %96 = triton_gpu.convert_layout %94 {async_agent = dense<1> : vector<1xi32>} : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1> + // CHECK-NEXT: scf.yield {async_agent = dense<[0, 1]> : vector<2xi32>} %95, %96 : tensor<128x32xf16, #shared>, tensor<32x128xf16, #shared1> + // CHECK-NEXT: } {async_agent = dense<[0, 1]> : vector<2xi32>} + // CHECK-NEXT: %90 = tt.dot %89#0, %89#1, %arg10 {allowTF32 = true, async_agent = dense<1> : vector<1xi32>} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma> + // CHECK-NEXT: %91 = tt.addptr %arg11, %cst_1 {async_agent = dense<0> : vector<1xi32>} : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + // CHECK-NEXT: %92 = tt.addptr %arg12, %cst_0 {async_agent = dense<0> : vector<1xi32>} : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + // CHECK-NEXT: scf.yield {async_agent = dense<[0, 1]> : vector<2xi32>} %90, %91, %92 : tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked> + // CHECK-NEXT: } {async_agent = dense<[0, 1]> : vector<2xi32>} + // CHECK-NEXT: %71 = arith.truncf %70#0 {async_agent = dense<1> : vector<1xi32>} : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma> + // CHECK-NEXT: %72 = tt.expand_dims %27 {async_agent = dense<1> : vector<1xi32>, axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi32, #blocked2> + // CHECK-NEXT: %73 = tt.splat %arg8 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128x1xi32, #blocked2> + // CHECK-NEXT: %74 = arith.muli %72, %73 {async_agent = dense<1> : vector<1xi32>} : tensor<128x1xi32, #blocked2> + // CHECK-NEXT: %75 = tt.expand_dims %34 {async_agent = dense<1> : vector<1xi32>, axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi32, #blocked2> + // CHECK-NEXT: %76 = tt.broadcast %74 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x1xi32, #blocked2>) -> tensor<128x128xi32, #blocked2> + // CHECK-NEXT: %77 = tt.broadcast %75 {async_agent = dense<1> : vector<1xi32>} : (tensor<1x128xi32, #blocked2>) -> tensor<128x128xi32, #blocked2> + // CHECK-NEXT: %78 = arith.addi %76, %77 {async_agent = dense<1> : vector<1xi32>} : tensor<128x128xi32, #blocked2> + // CHECK-NEXT: %79 = tt.splat %arg2 {async_agent = dense<1> : vector<1xi32>} : (!tt.ptr) -> tensor<128x128x!tt.ptr, #blocked2> + // CHECK-NEXT: %80 = tt.addptr %79, %78 {async_agent = dense<1> : vector<1xi32>} : tensor<128x128x!tt.ptr, #blocked2>, tensor<128x128xi32, #blocked2> + // CHECK-NEXT: %81 = "triton_gpu.cmpi"(%28, %37) {async_agent = dense<1> : vector<1xi32>, predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + // CHECK-NEXT: %82 = tt.expand_dims %81 {async_agent = dense<1> : vector<1xi32>, axis = 1 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi1, #blocked2> + // CHECK-NEXT: %83 = "triton_gpu.cmpi"(%35, %40) {async_agent = dense<1> : vector<1xi32>, predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + // CHECK-NEXT: %84 = tt.expand_dims %83 {async_agent = dense<1> : vector<1xi32>, axis = 0 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi1, #blocked2> + // CHECK-NEXT: %85 = tt.broadcast %82 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x1xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> + // CHECK-NEXT: %86 = tt.broadcast %84 {async_agent = dense<1> : vector<1xi32>} : (tensor<1x128xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> + // CHECK-NEXT: %87 = arith.andi %85, %86 {async_agent = dense<1> : vector<1xi32>} : tensor<128x128xi1, #blocked2> + // CHECK-NEXT: %88 = triton_gpu.convert_layout %71 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x128xf16, #mma>) -> tensor<128x128xf16, #blocked2> + // CHECK-NEXT: tt.store %80, %88, %87 {async_agent = dense<1> : vector<1xi32>, cache = 1 : i32, evict = 1 : i32} : tensor<128x128xf16, #blocked2> + + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 128, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.enable-warp-specialization" = 1 : i32} { + // CHECK-LABEL: @if_in_for_gemm + tt.func public @if_in_for_gemm(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %cst_0 = arith.constant dense<32> : tensor<32x128xi32, #blocked> + %cst_1 = arith.constant dense<32> : tensor<128x32xi32, #blocked1> + %c31_i32 = arith.constant 31 : i32 + %c127_i32 = arith.constant 127 : i32 + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c32_i32 = arith.constant 32 : i32 + %c128_i32 = arith.constant 128 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = tt.get_program_id x : i32 + %1 = tt.get_program_id y : i32 + %2 = arith.addi %arg3, %c127_i32 : i32 + %3 = arith.divsi %2, %c128_i32 : i32 + %4 = arith.addi %arg4, %c127_i32 : i32 + %5 = arith.divsi %4, %c128_i32 : i32 + %6 = arith.muli %5, %c8_i32 : i32 + %7 = arith.divsi %0, %6 : i32 + %8 = arith.muli %7, %c8_i32 : i32 + %9 = arith.subi %3, %8 : i32 + %10 = arith.cmpi slt, %9, %c8_i32 : i32 + %11 = arith.select %10, %9, %c8_i32 : i32 + %12 = arith.remsi %0, %11 : i32 + %13 = arith.addi %8, %12 : i32 + %14 = arith.remsi %0, %6 : i32 + %15 = arith.divsi %14, %11 : i32 + %16 = arith.muli %13, %c128_i32 : i32 + %17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %19 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %20 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %21 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %22 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %23 = tt.splat %16 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %24 = tt.splat %16 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %25 = tt.splat %16 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %26 = arith.addi %23, %17 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %27 = arith.addi %24, %19 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %28 = arith.addi %25, %21 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %29 = arith.muli %15, %c128_i32 : i32 + %30 = tt.splat %29 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %31 = tt.splat %29 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %32 = tt.splat %29 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %33 = arith.addi %30, %18 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %34 = arith.addi %31, %20 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %35 = arith.addi %32, %22 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %36 = tt.splat %arg3 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %37 = tt.splat %arg3 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %38 = arith.remsi %26, %36 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %39 = tt.splat %arg4 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %40 = tt.splat %arg4 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %41 = arith.remsi %33, %39 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %42 = arith.muli %1, %c32_i32 : i32 + %43 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %44 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %45 = tt.splat %42 : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %46 = tt.splat %42 : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %47 = arith.addi %45, %43 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %48 = arith.addi %46, %44 : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %49 = tt.expand_dims %38 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xi32, #blocked1> + %50 = tt.splat %arg6 : (i32) -> tensor<128x1xi32, #blocked1> + %51 = arith.muli %49, %50 : tensor<128x1xi32, #blocked1> + %52 = tt.expand_dims %47 {axis = 0 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x32xi32, #blocked1> + %53 = tt.broadcast %51 : (tensor<128x1xi32, #blocked1>) -> tensor<128x32xi32, #blocked1> + %54 = tt.broadcast %52 : (tensor<1x32xi32, #blocked1>) -> tensor<128x32xi32, #blocked1> + %55 = arith.addi %53, %54 : tensor<128x32xi32, #blocked1> + %56 = tt.splat %arg0 : (!tt.ptr) -> tensor<128x32x!tt.ptr, #blocked1> + %57 = tt.addptr %56, %55 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %58 = tt.expand_dims %48 {axis = 1 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32x1xi32, #blocked> + %59 = tt.expand_dims %41 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x128xi32, #blocked> + %60 = tt.splat %arg7 : (i32) -> tensor<1x128xi32, #blocked> + %61 = arith.muli %59, %60 : tensor<1x128xi32, #blocked> + %62 = tt.broadcast %58 : (tensor<32x1xi32, #blocked>) -> tensor<32x128xi32, #blocked> + %63 = tt.broadcast %61 : (tensor<1x128xi32, #blocked>) -> tensor<32x128xi32, #blocked> + %64 = arith.addi %62, %63 : tensor<32x128xi32, #blocked> + %65 = tt.splat %arg1 : (!tt.ptr) -> tensor<32x128x!tt.ptr, #blocked> + %66 = tt.addptr %65, %64 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %67 = arith.addi %arg5, %c31_i32 : i32 + %68 = arith.divsi %67, %c32_i32 : i32 + %69 = arith.index_cast %68 : i32 to index + %70:3 = scf.for %arg9 = %c0 to %69 step %c1 iter_args(%arg10 = %cst, %arg11 = %57, %arg12 = %66) -> (tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>) { + %arg9_i32 = arith.index_cast %arg9 : index to i32 + %96 = arith.cmpi ne, %arg9_i32, %c31_i32 : i32 + %89 = scf.if %96 -> (tensor<128x32xf16, #blocked1>) { + %r0_0 = arith.select %96, %c31_i32, %c127_i32 : i32 + %r0_1 = tt.splat %r0_0 : (i32) -> tensor<128x32xi32, #blocked1> + %new_addr = tt.addptr %arg11, %r0_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %new_89 = tt.load %new_addr {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #blocked1> + scf.yield %new_89 : tensor<128x32xf16, #blocked1> + } else { + %new_89 = tt.load %arg11 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #blocked1> + scf.yield %new_89 : tensor<128x32xf16, #blocked1> + } + %90 = tt.load %arg12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked> + %91 = triton_gpu.convert_layout %89 : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #shared> + %92 = triton_gpu.convert_layout %90 : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1> + %93 = tt.dot %91, %92, %arg10 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma> + %base_94 = scf.if %96 -> (tensor<128x32x!tt.ptr, #blocked1>) { + %r1_0 = arith.select %96, %c31_i32, %c127_i32 : i32 + %r1_1 = tt.splat %r1_0 : (i32) -> tensor<128x32xi32, #blocked1> + %98 = tt.addptr %arg11, %r1_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + scf.yield %98 : tensor<128x32x!tt.ptr, #blocked1> + } else { + %98 = tt.addptr %arg11, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + scf.yield %98 : tensor<128x32x!tt.ptr, #blocked1> + } + %94 = tt.addptr %base_94, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %95 = tt.addptr %arg12, %cst_0 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + scf.yield %93, %94, %95 : tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked> + } + %71 = arith.truncf %70#0 : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma> + %72 = tt.expand_dims %27 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi32, #blocked2> + %73 = tt.splat %arg8 : (i32) -> tensor<128x1xi32, #blocked2> + %74 = arith.muli %72, %73 : tensor<128x1xi32, #blocked2> + %75 = tt.expand_dims %34 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi32, #blocked2> + %76 = tt.broadcast %74 : (tensor<128x1xi32, #blocked2>) -> tensor<128x128xi32, #blocked2> + %77 = tt.broadcast %75 : (tensor<1x128xi32, #blocked2>) -> tensor<128x128xi32, #blocked2> + %78 = arith.addi %76, %77 : tensor<128x128xi32, #blocked2> + %79 = tt.splat %arg2 : (!tt.ptr) -> tensor<128x128x!tt.ptr, #blocked2> + %80 = tt.addptr %79, %78 : tensor<128x128x!tt.ptr, #blocked2>, tensor<128x128xi32, #blocked2> + %81 = "triton_gpu.cmpi"(%28, %37) {predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %82 = tt.expand_dims %81 {axis = 1 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi1, #blocked2> + %83 = "triton_gpu.cmpi"(%35, %40) {predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %84 = tt.expand_dims %83 {axis = 0 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi1, #blocked2> + %85 = tt.broadcast %82 : (tensor<128x1xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> + %86 = tt.broadcast %84 : (tensor<1x128xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> + %87 = arith.andi %85, %86 : tensor<128x128xi1, #blocked2> + %88 = triton_gpu.convert_layout %71 : (tensor<128x128xf16, #mma>) -> tensor<128x128xf16, #blocked2> + tt.store %80, %88, %87 {cache = 1 : i32, evict = 1 : i32} : tensor<128x128xf16, #blocked2> + tt.return + // CHECK-NEXT: %cst = arith.constant {async_agent = dense<1> : vector<1xi32>} dense<0.000000e+00> : tensor<128x128xf32, #mma> + // CHECK-NEXT: %cst_0 = arith.constant {async_agent = dense<0> : vector<1xi32>} dense<32> : tensor<32x128xi32, #blocked> + // CHECK-NEXT: %cst_1 = arith.constant {async_agent = dense<1> : vector<1xi32>} dense<32> : tensor<128x32xi32, #blocked1> + // CHECK-NEXT: %c31_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 31 : i32 + // CHECK-NEXT: %c127_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 127 : i32 + // CHECK-NEXT: %c1 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 1 : index + // CHECK-NEXT: %c0 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : index + // CHECK-NEXT: %c32_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 32 : i32 + // CHECK-NEXT: %c128_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 128 : i32 + // CHECK-NEXT: %c8_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 8 : i32 + // CHECK-NEXT: %0 = tt.get_program_id {async_agent = dense<[0, 1]> : vector<2xi32>, axis = 0 : i32} : i32 + // CHECK-NEXT: %1 = tt.get_program_id {async_agent = dense<[0, 1]> : vector<2xi32>, axis = 1 : i32} : i32 + // CHECK-NEXT: %2 = arith.addi %arg3, %c127_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %3 = arith.divsi %2, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %4 = arith.addi %arg4, %c127_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %5 = arith.divsi %4, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %6 = arith.muli %5, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %7 = arith.divsi %0, %6 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %8 = arith.muli %7, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %9 = arith.subi %3, %8 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %10 = arith.cmpi slt, %9, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %11 = arith.select %10, %9, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %12 = arith.remsi %0, %11 {async_agent = dense<1> : vector<1xi32>} : i32 + // CHECK-NEXT: %13 = arith.addi %8, %12 {async_agent = dense<1> : vector<1xi32>} : i32 + // CHECK-NEXT: %14 = arith.remsi %0, %6 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %15 = arith.divsi %14, %11 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %16 = arith.muli %13, %c128_i32 {async_agent = dense<1> : vector<1xi32>} : i32 + // CHECK-NEXT: %17 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + // CHECK-NEXT: %18 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + // CHECK-NEXT: %19 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + // CHECK-NEXT: %20 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + // CHECK-NEXT: %21 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + // CHECK-NEXT: %22 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + // CHECK-NEXT: %23 = tt.splat %16 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + // CHECK-NEXT: %24 = tt.splat %16 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + // CHECK-NEXT: %25 = tt.splat %16 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + // CHECK-NEXT: %26 = arith.addi %23, %17 {async_agent = dense<1> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + // CHECK-NEXT: %27 = arith.addi %24, %19 {async_agent = dense<1> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + // CHECK-NEXT: %28 = arith.addi %25, %21 {async_agent = dense<1> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + // CHECK-NEXT: %29 = arith.muli %15, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %30 = tt.splat %29 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + // CHECK-NEXT: %31 = tt.splat %29 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + // CHECK-NEXT: %32 = tt.splat %29 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + // CHECK-NEXT: %33 = arith.addi %30, %18 {async_agent = dense<0> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + // CHECK-NEXT: %34 = arith.addi %31, %20 {async_agent = dense<1> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + // CHECK-NEXT: %35 = arith.addi %32, %22 {async_agent = dense<1> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + // CHECK-NEXT: %36 = tt.splat %arg3 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + // CHECK-NEXT: %37 = tt.splat %arg3 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + // CHECK-NEXT: %38 = arith.remsi %26, %36 {async_agent = dense<1> : vector<1xi32>, tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + // CHECK-NEXT: %39 = tt.splat %arg4 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + // CHECK-NEXT: %40 = tt.splat %arg4 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + // CHECK-NEXT: %41 = arith.remsi %33, %39 {async_agent = dense<0> : vector<1xi32>, tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + // CHECK-NEXT: %42 = arith.muli %1, %c32_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %43 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + // CHECK-NEXT: %44 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + // CHECK-NEXT: %45 = tt.splat %42 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + // CHECK-NEXT: %46 = tt.splat %42 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + // CHECK-NEXT: %47 = arith.addi %45, %43 {async_agent = dense<1> : vector<1xi32>} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + // CHECK-NEXT: %48 = arith.addi %46, %44 {async_agent = dense<0> : vector<1xi32>} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + // CHECK-NEXT: %49 = tt.expand_dims %38 {async_agent = dense<1> : vector<1xi32>, axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xi32, #blocked1> + // CHECK-NEXT: %50 = tt.splat %arg6 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128x1xi32, #blocked1> + // CHECK-NEXT: %51 = arith.muli %49, %50 {async_agent = dense<1> : vector<1xi32>} : tensor<128x1xi32, #blocked1> + // CHECK-NEXT: %52 = tt.expand_dims %47 {async_agent = dense<1> : vector<1xi32>, axis = 0 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x32xi32, #blocked1> + // CHECK-NEXT: %53 = tt.broadcast %51 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x1xi32, #blocked1>) -> tensor<128x32xi32, #blocked1> + // CHECK-NEXT: %54 = tt.broadcast %52 {async_agent = dense<1> : vector<1xi32>} : (tensor<1x32xi32, #blocked1>) -> tensor<128x32xi32, #blocked1> + // CHECK-NEXT: %55 = arith.addi %53, %54 {async_agent = dense<1> : vector<1xi32>} : tensor<128x32xi32, #blocked1> + // CHECK-NEXT: %56 = tt.splat %arg0 {async_agent = dense<1> : vector<1xi32>} : (!tt.ptr) -> tensor<128x32x!tt.ptr, #blocked1> + // CHECK-NEXT: %57 = tt.addptr %56, %55 {async_agent = dense<1> : vector<1xi32>} : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + // CHECK-NEXT: %58 = tt.expand_dims %48 {async_agent = dense<0> : vector<1xi32>, axis = 1 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32x1xi32, #blocked> + // CHECK-NEXT: %59 = tt.expand_dims %41 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x128xi32, #blocked> + // CHECK-NEXT: %60 = tt.splat %arg7 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<1x128xi32, #blocked> + // CHECK-NEXT: %61 = arith.muli %59, %60 {async_agent = dense<0> : vector<1xi32>} : tensor<1x128xi32, #blocked> + // CHECK-NEXT: %62 = tt.broadcast %58 {async_agent = dense<0> : vector<1xi32>} : (tensor<32x1xi32, #blocked>) -> tensor<32x128xi32, #blocked> + // CHECK-NEXT: %63 = tt.broadcast %61 {async_agent = dense<0> : vector<1xi32>} : (tensor<1x128xi32, #blocked>) -> tensor<32x128xi32, #blocked> + // CHECK-NEXT: %64 = arith.addi %62, %63 {async_agent = dense<0> : vector<1xi32>} : tensor<32x128xi32, #blocked> + // CHECK-NEXT: %65 = tt.splat %arg1 {async_agent = dense<0> : vector<1xi32>} : (!tt.ptr) -> tensor<32x128x!tt.ptr, #blocked> + // CHECK-NEXT: %66 = tt.addptr %65, %64 {async_agent = dense<0> : vector<1xi32>} : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + // CHECK-NEXT: %67 = arith.addi %arg5, %c31_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %68 = arith.divsi %67, %c32_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + // CHECK-NEXT: %69 = arith.index_cast %68 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 to index + // CHECK-NEXT: %70:3 = scf.for %arg9 = %c0 to %69 step %c1 iter_args(%arg10 = %cst, %arg11 = %57, %arg12 = %66) -> (tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>) { + // CHECK-NEXT: %89 = arith.index_cast %arg9 {async_agent = dense<1> : vector<1xi32>} : index to i32 + // CHECK-NEXT: %90 = arith.cmpi ne, %89, %c31_i32 {async_agent = dense<1> : vector<1xi32>} : i32 + // CHECK-NEXT: %91 = scf.if %90 -> (tensor<128x32xf16, #blocked1>) { + // CHECK-NEXT: %99 = arith.select %90, %c31_i32, %c127_i32 {async_agent = dense<1> : vector<1xi32>} : i32 + // CHECK-NEXT: %100 = tt.splat %99 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128x32xi32, #blocked1> + // CHECK-NEXT: %101 = tt.addptr %arg11, %100 {async_agent = dense<1> : vector<1xi32>} : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + // CHECK-NEXT: %102 = tt.load %101 {async_agent = dense<1> : vector<1xi32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #blocked1> + // CHECK-NEXT: scf.yield {async_agent = dense<1> : vector<1xi32>} %102 : tensor<128x32xf16, #blocked1> + // CHECK-NEXT: } else { + // CHECK-NEXT: %99 = tt.load %arg11 {async_agent = dense<1> : vector<1xi32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #blocked1> + // CHECK-NEXT: scf.yield {async_agent = dense<1> : vector<1xi32>} %99 : tensor<128x32xf16, #blocked1> + // CHECK-NEXT: } {async_agent = dense<1> : vector<1xi32>} + // CHECK-NEXT: %92 = tt.load %arg12 {async_agent = dense<0> : vector<1xi32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked> + // CHECK-NEXT: %93 = triton_gpu.convert_layout %91 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #shared> + // CHECK-NEXT: %94 = triton_gpu.convert_layout %92 {async_agent = dense<1> : vector<1xi32>} : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1> + // CHECK-NEXT: %95 = tt.dot %93, %94, %arg10 {allowTF32 = true, async_agent = dense<1> : vector<1xi32>} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma> + // CHECK-NEXT: %96 = scf.if %90 -> (tensor<128x32x!tt.ptr, #blocked1>) { + // CHECK-NEXT: %99 = arith.select %90, %c31_i32, %c127_i32 {async_agent = dense<1> : vector<1xi32>} : i32 + // CHECK-NEXT: %100 = tt.splat %99 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128x32xi32, #blocked1> + // CHECK-NEXT: %101 = tt.addptr %arg11, %100 {async_agent = dense<1> : vector<1xi32>} : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + // CHECK-NEXT: scf.yield {async_agent = dense<1> : vector<1xi32>} %101 : tensor<128x32x!tt.ptr, #blocked1> + // CHECK-NEXT: } else { + // CHECK-NEXT: %99 = tt.addptr %arg11, %cst_1 {async_agent = dense<1> : vector<1xi32>} : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + // CHECK-NEXT: scf.yield {async_agent = dense<1> : vector<1xi32>} %99 : tensor<128x32x!tt.ptr, #blocked1> + // CHECK-NEXT: } {async_agent = dense<1> : vector<1xi32>} + // CHECK-NEXT: %97 = tt.addptr %96, %cst_1 {async_agent = dense<1> : vector<1xi32>} : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + // CHECK-NEXT: %98 = tt.addptr %arg12, %cst_0 {async_agent = dense<0> : vector<1xi32>} : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + // CHECK-NEXT: scf.yield {async_agent = dense<[0, 1]> : vector<2xi32>} %95, %97, %98 : tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked> + // CHECK-NEXT: } {async_agent = dense<[0, 1]> : vector<2xi32>} + // CHECK-NEXT: %71 = arith.truncf %70#0 {async_agent = dense<1> : vector<1xi32>} : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma> + // CHECK-NEXT: %72 = tt.expand_dims %27 {async_agent = dense<1> : vector<1xi32>, axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi32, #blocked2> + // CHECK-NEXT: %73 = tt.splat %arg8 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128x1xi32, #blocked2> + // CHECK-NEXT: %74 = arith.muli %72, %73 {async_agent = dense<1> : vector<1xi32>} : tensor<128x1xi32, #blocked2> + // CHECK-NEXT: %75 = tt.expand_dims %34 {async_agent = dense<1> : vector<1xi32>, axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi32, #blocked2> + // CHECK-NEXT: %76 = tt.broadcast %74 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x1xi32, #blocked2>) -> tensor<128x128xi32, #blocked2> + // CHECK-NEXT: %77 = tt.broadcast %75 {async_agent = dense<1> : vector<1xi32>} : (tensor<1x128xi32, #blocked2>) -> tensor<128x128xi32, #blocked2> + // CHECK-NEXT: %78 = arith.addi %76, %77 {async_agent = dense<1> : vector<1xi32>} : tensor<128x128xi32, #blocked2> + // CHECK-NEXT: %79 = tt.splat %arg2 {async_agent = dense<1> : vector<1xi32>} : (!tt.ptr) -> tensor<128x128x!tt.ptr, #blocked2> + // CHECK-NEXT: %80 = tt.addptr %79, %78 {async_agent = dense<1> : vector<1xi32>} : tensor<128x128x!tt.ptr, #blocked2>, tensor<128x128xi32, #blocked2> + // CHECK-NEXT: %81 = "triton_gpu.cmpi"(%28, %37) {async_agent = dense<1> : vector<1xi32>, predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + // CHECK-NEXT: %82 = tt.expand_dims %81 {async_agent = dense<1> : vector<1xi32>, axis = 1 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi1, #blocked2> + // CHECK-NEXT: %83 = "triton_gpu.cmpi"(%35, %40) {async_agent = dense<1> : vector<1xi32>, predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + // CHECK-NEXT: %84 = tt.expand_dims %83 {async_agent = dense<1> : vector<1xi32>, axis = 0 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi1, #blocked2> + // CHECK-NEXT: %85 = tt.broadcast %82 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x1xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> + // CHECK-NEXT: %86 = tt.broadcast %84 {async_agent = dense<1> : vector<1xi32>} : (tensor<1x128xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> + // CHECK-NEXT: %87 = arith.andi %85, %86 {async_agent = dense<1> : vector<1xi32>} : tensor<128x128xi1, #blocked2> + // CHECK-NEXT: %88 = triton_gpu.convert_layout %71 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x128xf16, #mma>) -> tensor<128x128xf16, #blocked2> + // CHECK-NEXT: tt.store %80, %88, %87 {async_agent = dense<1> : vector<1xi32>, cache = 1 : i32, evict = 1 : i32} : tensor<128x128xf16, #blocked2> + } +} diff --git a/test/TritonGPU/wsmaterialization.mlir b/test/TritonGPU/wsmaterialization.mlir new file mode 100644 index 000000000000..4ab8be6c5d96 --- /dev/null +++ b/test/TritonGPU/wsmaterialization.mlir @@ -0,0 +1,414 @@ +// RUN: triton-opt -split-input-file -triton-nvidia-gpu-ws-materialization='compute-capability=90' %s | FileCheck %s + +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 128, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.enable-warp-specialization" = 1 : i32} { + // CHECK-LABEL: @simple_gemm + // CHECK: triton_nvidia_gpu.alloc_mbarrier + // CHECK: scf.if + // CHECK: scf.for + // CHECK: triton_nvidia_gpu.extract_mbarrier + // CHECK: triton_nvidia_gpu.mbarrier_wait + // CHECK: triton_gpu.insert_slice + // CHECK: triton_gpu.insert_slice + // CHECK: triton_nvidia_gpu.extract_mbarrier + // CHECK: triton_nvidia_gpu.mbarrier_arrive + // CHECK: scf.yield + // CHECK: scf.if + // CHECK: triton_nvidia_gpu.bar_wait + // CHECK: scf.for + // CHECK: triton_nvidia_gpu.extract_mbarrier + // CHECK: triton_nvidia_gpu.mbarrier_wait + // CHECK: triton_gpu.extract_slice + // CHECK: triton_gpu.extract_slice + // CHECK: tt.dot + // CHECK: triton_nvidia_gpu.extract_mbarrier + // CHECK: triton_nvidia_gpu.mbarrier_arrive + // CHECK: scf.yield + // CHECK: triton_nvidia_gpu.bar_arrive + // CHECK: triton_nvidia_gpu.bar_wait + // CHECK: tt.store + // CHECK: triton_nvidia_gpu.bar_arrive + tt.func public @simple_gemm(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) { + %0 = triton_gpu.alloc_tensor : tensor<3x32x128xf16, #shared> + %1 = triton_gpu.alloc_tensor : tensor<3x128x32xf16, #shared1> + %2 = triton_nvidia_gpu.create_token {num = 3 : i32} : tensor<3x!triton_nvidia_gpu.token> + %3 = triton_nvidia_gpu.create_mutex : !triton_nvidia_gpu.mutex + %4 = triton_nvidia_gpu.create_mutex : !triton_nvidia_gpu.mutex + %5 = triton_nvidia_gpu.get_agent_id : i32 + %c0_i32 = arith.constant 0 : i32 + %6 = arith.cmpi eq, %5, %c0_i32 : i32 + scf.if %6 { + %cst = arith.constant {async_agent = dense<0> : vector<1xi32>} dense<32> : tensor<32x128xi32, #blocked> + %cst_1 = arith.constant {async_agent = dense<0> : vector<1xi32>} dense<32> : tensor<128x32xi32, #blocked1> + %c31_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 31 : i32 + %c127_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 127 : i32 + %c1 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 1 : index + %c0 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : index + %c32_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 32 : i32 + %c128_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 128 : i32 + %c8_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 8 : i32 + %8 = tt.get_program_id x {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %9 = tt.get_program_id y {async_agent = dense<0> : vector<1xi32>} : i32 + %10 = arith.addi %arg3, %c127_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %11 = arith.divsi %10, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %12 = arith.addi %arg4, %c127_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %13 = arith.divsi %12, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %14 = arith.muli %13, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %15 = arith.divsi %8, %14 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %16 = arith.muli %15, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %17 = arith.subi %11, %16 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %18 = arith.cmpi slt, %17, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %19 = arith.select %18, %17, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %20 = arith.remsi %8, %19 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %21 = arith.addi %16, %20 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %22 = arith.remsi %8, %14 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %23 = arith.divsi %22, %19 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %24 = arith.muli %21, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %25 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %26 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %27 = tt.splat %24 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %28 = arith.addi %27, %25 {async_agent = dense<0> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %29 = arith.muli %23, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %30 = tt.splat %29 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %31 = arith.addi %30, %26 {async_agent = dense<0> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %32 = tt.splat %arg3 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %33 = arith.remsi %28, %32 {async_agent = dense<0> : vector<1xi32>, tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %34 = tt.splat %arg4 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %35 = arith.remsi %31, %34 {async_agent = dense<0> : vector<1xi32>, tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %36 = arith.muli %9, %c32_i32 {async_agent = dense<0> : vector<1xi32>} : i32 + %37 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %38 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %39 = tt.splat %36 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %40 = tt.splat %36 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %41 = arith.addi %39, %37 {async_agent = dense<0> : vector<1xi32>} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %42 = arith.addi %40, %38 {async_agent = dense<0> : vector<1xi32>} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %43 = tt.expand_dims %33 {async_agent = dense<0> : vector<1xi32>, axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xi32, #blocked1> + %44 = tt.splat %arg6 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<128x1xi32, #blocked1> + %45 = arith.muli %43, %44 {async_agent = dense<0> : vector<1xi32>} : tensor<128x1xi32, #blocked1> + %46 = tt.expand_dims %41 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x32xi32, #blocked1> + %47 = tt.broadcast %45 {async_agent = dense<0> : vector<1xi32>} : (tensor<128x1xi32, #blocked1>) -> tensor<128x32xi32, #blocked1> + %48 = tt.broadcast %46 {async_agent = dense<0> : vector<1xi32>} : (tensor<1x32xi32, #blocked1>) -> tensor<128x32xi32, #blocked1> + %49 = arith.addi %47, %48 {async_agent = dense<0> : vector<1xi32>} : tensor<128x32xi32, #blocked1> + %50 = tt.splat %arg0 {async_agent = dense<0> : vector<1xi32>} : (!tt.ptr) -> tensor<128x32x!tt.ptr, #blocked1> + %51 = tt.addptr %50, %49 {async_agent = dense<0> : vector<1xi32>} : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %52 = tt.expand_dims %42 {async_agent = dense<0> : vector<1xi32>, axis = 1 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32x1xi32, #blocked> + %53 = tt.expand_dims %35 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x128xi32, #blocked> + %54 = tt.splat %arg7 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<1x128xi32, #blocked> + %55 = arith.muli %53, %54 {async_agent = dense<0> : vector<1xi32>} : tensor<1x128xi32, #blocked> + %56 = tt.broadcast %52 {async_agent = dense<0> : vector<1xi32>} : (tensor<32x1xi32, #blocked>) -> tensor<32x128xi32, #blocked> + %57 = tt.broadcast %55 {async_agent = dense<0> : vector<1xi32>} : (tensor<1x128xi32, #blocked>) -> tensor<32x128xi32, #blocked> + %58 = arith.addi %56, %57 {async_agent = dense<0> : vector<1xi32>} : tensor<32x128xi32, #blocked> + %59 = tt.splat %arg1 {async_agent = dense<0> : vector<1xi32>} : (!tt.ptr) -> tensor<32x128x!tt.ptr, #blocked> + %60 = tt.addptr %59, %58 {async_agent = dense<0> : vector<1xi32>} : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %61 = arith.addi %arg5, %c31_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %62 = arith.divsi %61, %c32_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %63 = arith.index_cast %62 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 to index + %c0_i32_2 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : i32 + %64:3 = scf.for %arg9 = %c0 to %63 step %c1 iter_args(%arg10 = %51, %arg11 = %60, %arg12 = %c0_i32_2) -> (tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, i32) { + triton_nvidia_gpu.producer_acquire %2, %arg12 {async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32 + %65 = triton_gpu.insert_slice %arg10, %1, %arg12 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32x!tt.ptr, #blocked1> -> tensor<3x128x32xf16, #shared1> + %66 = triton_gpu.insert_slice %arg11, %0, %arg12 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128x!tt.ptr, #blocked> -> tensor<3x32x128xf16, #shared> + %67 = tt.addptr %arg10, %cst_1 {async_agent = dense<0> : vector<1xi32>} : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %68 = tt.addptr %arg11, %cst {async_agent = dense<0> : vector<1xi32>} : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %c1_i32_3 = arith.constant {async_agent = dense<0> : vector<1xi32>} 1 : i32 + %c3_i32 = arith.constant {async_agent = dense<0> : vector<1xi32>} 3 : i32 + %69 = arith.addi %arg12, %c1_i32_3 {async_agent = dense<0> : vector<1xi32>} : i32 + %70 = arith.remsi %69, %c3_i32 {async_agent = dense<0> : vector<1xi32>} : i32 + triton_nvidia_gpu.producer_commit %2, %arg12 {async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32 + scf.yield %67, %68, %70 : tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, i32 + } {async_agent = dense<0> : vector<1xi32>} + } + %c1_i32 = arith.constant 1 : i32 + %c1_i32_0 = arith.constant 1 : i32 + %7 = arith.cmpi sge, %5, %c1_i32_0 : i32 + scf.if %7 { + %cst = arith.constant {async_agent = dense<1> : vector<1xi32>} dense<0.000000e+00> : tensor<128x128xf32, #mma> + %c31_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 31 : i32 + %c127_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 127 : i32 + %c1 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 1 : index + %c0 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : index + %c32_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 32 : i32 + %c128_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 128 : i32 + %c8_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 8 : i32 + %8 = tt.get_program_id x {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %9 = arith.addi %arg3, %c127_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %10 = arith.divsi %9, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %11 = arith.addi %arg4, %c127_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %12 = arith.divsi %11, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %13 = arith.muli %12, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %14 = arith.divsi %8, %13 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %15 = arith.muli %14, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %16 = arith.subi %10, %15 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %17 = arith.cmpi slt, %16, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %18 = arith.select %17, %16, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %19 = arith.remsi %8, %18 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %20 = arith.addi %15, %19 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %21 = arith.remsi %8, %13 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %22 = arith.divsi %21, %18 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %23 = arith.muli %20, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %24 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %25 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %26 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %27 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %28 = tt.splat %23 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %29 = tt.splat %23 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %30 = arith.addi %28, %24 {async_agent = dense<1> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %31 = arith.addi %29, %26 {async_agent = dense<1> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %32 = arith.muli %22, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %33 = tt.splat %32 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %34 = tt.splat %32 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %35 = arith.addi %33, %25 {async_agent = dense<1> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %36 = arith.addi %34, %27 {async_agent = dense<1> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %37 = tt.splat %arg3 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %38 = tt.splat %arg4 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %39 = arith.addi %arg5, %c31_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %40 = arith.divsi %39, %c32_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %41 = arith.index_cast %40 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 to index + %c0_i32_1 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : i32 + triton_nvidia_gpu.lock %3 {mutex.barId = dense<1> : vector<1xi32>, mutex.numThreads = dense<256> : vector<1xi32>} : !triton_nvidia_gpu.mutex + %42:2 = scf.for %arg9 = %c0 to %41 step %c1 iter_args(%arg10 = %cst, %arg11 = %c0_i32_1) -> (tensor<128x128xf32, #mma>, i32) { + triton_nvidia_gpu.consumer_wait %2, %arg11 {async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32 + %62 = triton_gpu.extract_slice %1[%arg11, 0, 0] [1, 128, 32] [1, 1, 1] {async_agent = dense<1> : vector<1xi32>} : tensor<3x128x32xf16, #shared1> to tensor<128x32xf16, #shared1> + %63 = triton_gpu.extract_slice %0[%arg11, 0, 0] [1, 32, 128] [1, 1, 1] {async_agent = dense<1> : vector<1xi32>} : tensor<3x32x128xf16, #shared> to tensor<32x128xf16, #shared> + %64 = triton_gpu.convert_layout %62 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x32xf16, #shared1>) -> tensor<128x32xf16, #shared1> + %65 = triton_gpu.convert_layout %63 {async_agent = dense<1> : vector<1xi32>} : (tensor<32x128xf16, #shared>) -> tensor<32x128xf16, #shared> + %66 = tt.dot %64, %65, %arg10 {allowTF32 = true, async_agent = dense<1> : vector<1xi32>} : tensor<128x32xf16, #shared1> * tensor<32x128xf16, #shared> -> tensor<128x128xf32, #mma> + %c1_i32_2 = arith.constant {async_agent = dense<1> : vector<1xi32>} 1 : i32 + %c3_i32 = arith.constant {async_agent = dense<1> : vector<1xi32>} 3 : i32 + %67 = arith.addi %arg11, %c1_i32_2 {async_agent = dense<1> : vector<1xi32>} : i32 + %68 = arith.remsi %67, %c3_i32 {async_agent = dense<1> : vector<1xi32>} : i32 + triton_nvidia_gpu.consumer_release %2, %arg11 {async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32 + scf.yield %66, %68 : tensor<128x128xf32, #mma>, i32 + } {async_agent = dense<1> : vector<1xi32>} + triton_nvidia_gpu.unlock %3 {mutex.barId = dense<2> : vector<1xi32>, mutex.numThreads = dense<256> : vector<1xi32>} : !triton_nvidia_gpu.mutex + triton_nvidia_gpu.lock %4 {mutex.barId = dense<3> : vector<1xi32>, mutex.numThreads = dense<256> : vector<1xi32>} : !triton_nvidia_gpu.mutex + %43 = arith.truncf %42#0 {async_agent = dense<1> : vector<1xi32>} : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma> + %44 = tt.expand_dims %30 {async_agent = dense<1> : vector<1xi32>, axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi32, #blocked2> + %45 = tt.splat %arg8 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128x1xi32, #blocked2> + %46 = arith.muli %44, %45 {async_agent = dense<1> : vector<1xi32>} : tensor<128x1xi32, #blocked2> + %47 = tt.expand_dims %35 {async_agent = dense<1> : vector<1xi32>, axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi32, #blocked2> + %48 = tt.broadcast %46 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x1xi32, #blocked2>) -> tensor<128x128xi32, #blocked2> + %49 = tt.broadcast %47 {async_agent = dense<1> : vector<1xi32>} : (tensor<1x128xi32, #blocked2>) -> tensor<128x128xi32, #blocked2> + %50 = arith.addi %48, %49 {async_agent = dense<1> : vector<1xi32>} : tensor<128x128xi32, #blocked2> + %51 = tt.splat %arg2 {async_agent = dense<1> : vector<1xi32>} : (!tt.ptr) -> tensor<128x128x!tt.ptr, #blocked2> + %52 = tt.addptr %51, %50 {async_agent = dense<1> : vector<1xi32>} : tensor<128x128x!tt.ptr, #blocked2>, tensor<128x128xi32, #blocked2> + %53 = "triton_gpu.cmpi"(%31, %37) {async_agent = dense<1> : vector<1xi32>, predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %54 = tt.expand_dims %53 {async_agent = dense<1> : vector<1xi32>, axis = 1 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi1, #blocked2> + %55 = "triton_gpu.cmpi"(%36, %38) {async_agent = dense<1> : vector<1xi32>, predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %56 = tt.expand_dims %55 {async_agent = dense<1> : vector<1xi32>, axis = 0 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi1, #blocked2> + %57 = tt.broadcast %54 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x1xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> + %58 = tt.broadcast %56 {async_agent = dense<1> : vector<1xi32>} : (tensor<1x128xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> + %59 = arith.andi %57, %58 {async_agent = dense<1> : vector<1xi32>} : tensor<128x128xi1, #blocked2> + %60 = triton_gpu.convert_layout %43 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x128xf16, #mma>) -> tensor<128x128xf16, #blocked2> + tt.store %52, %60, %59 {async_agent = dense<1> : vector<1xi32>, cache = 1 : i32, evict = 1 : i32} : tensor<128x128xf16, #blocked2> + triton_nvidia_gpu.unlock %4 {mutex.barId = dense<4> : vector<1xi32>, mutex.numThreads = dense<256> : vector<1xi32>} : !triton_nvidia_gpu.mutex + } + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"async.num-agents" = 2 : i32, "triton_gpu.compute-capability" = 90 : i32, "triton_gpu.enable-warp-specialization" = 1 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: @matmal_from_wsmutex + // CHECK: triton_nvidia_gpu.alloc_mbarrier + // CHECK: scf.if + // CHECK: scf.for + // CHECK: triton_nvidia_gpu.extract_mbarrier + // CHECK: triton_nvidia_gpu.mbarrier_wait + // CHECK: triton_gpu.insert_slice + // CHECK: triton_gpu.insert_slice + // CHECK: triton_nvidia_gpu.extract_mbarrier + // CHECK: triton_nvidia_gpu.mbarrier_arrive + // CHECK: scf.yield + // CHECK: scf.if + // CHECK: triton_nvidia_gpu.bar_wait + // CHECK: scf.for + // CHECK: triton_nvidia_gpu.extract_mbarrier + // CHECK: triton_nvidia_gpu.mbarrier_wait + // CHECK: triton_gpu.extract_slice + // CHECK: triton_gpu.extract_slice + // CHECK: tt.dot + // CHECK: triton_nvidia_gpu.extract_mbarrier + // CHECK: triton_nvidia_gpu.mbarrier_arrive + // CHECK: scf.yield + // CHECK: triton_nvidia_gpu.bar_arrive + // CHECK: triton_nvidia_gpu.bar_wait + // CHECK: tt.store + // CHECK: triton_nvidia_gpu.bar_arrive + tt.func public @matmal_from_wsmutex(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg4: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg7: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg8: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) { + %0 = triton_gpu.alloc_tensor : tensor<3x64x16xf16, #shared> + %1 = triton_gpu.alloc_tensor : tensor<3x16x64xf16, #shared1> + %2 = triton_nvidia_gpu.create_token {num = 3 : i32} : tensor<3x!triton_nvidia_gpu.token> + %3 = triton_nvidia_gpu.get_agent_id : i32 + %c0_i32 = arith.constant 0 : i32 + %4 = arith.cmpi eq, %3, %c0_i32 : i32 + scf.if %4 { + %cst = arith.constant {async_agent = dense<0> : vector<1xi32>} dense<16> : tensor<16x64xi32, #blocked> + %cst_0 = arith.constant {async_agent = dense<0> : vector<1xi32>} dense<16> : tensor<64x16xi32, #blocked1> + %c63_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 63 : i32 + %c114_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 114 : i32 + %c16_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 16 : i32 + %c0_i32_1 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : i32 + %c64_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 64 : i32 + %6 = tt.get_program_id x {async_agent = dense<[0, 1]> : vector<2xi32>, axis = 0 : i32} : i32 + %7 = arith.addi %arg3, %c63_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %8 = arith.divsi %7, %c64_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %9 = arith.addi %arg4, %c63_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %10 = arith.divsi %9, %c64_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %11 = arith.muli %8, %10 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %12 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %13 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %14 = tt.splat %arg3 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %15 = tt.splat %arg4 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %16 = tt.splat %arg6 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<64x1xi32, #blocked1> + %17 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %18 = tt.expand_dims %17 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x16xi32, #blocked1> + %19 = tt.broadcast %18 {async_agent = dense<0> : vector<1xi32>} : (tensor<1x16xi32, #blocked1>) -> tensor<64x16xi32, #blocked1> + %20 = tt.splat %arg0 {async_agent = dense<0> : vector<1xi32>} : (!tt.ptr) -> tensor<64x16x!tt.ptr, #blocked1> + %21 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %22 = tt.expand_dims %21 {async_agent = dense<0> : vector<1xi32>, axis = 1 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<16x1xi32, #blocked> + %23 = tt.splat %arg7 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<1x64xi32, #blocked> + %24 = tt.broadcast %22 {async_agent = dense<0> : vector<1xi32>} : (tensor<16x1xi32, #blocked>) -> tensor<16x64xi32, #blocked> + %25 = tt.splat %arg1 {async_agent = dense<0> : vector<1xi32>} : (!tt.ptr) -> tensor<16x64x!tt.ptr, #blocked> + %c3_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 3 : i32 + %c0_i32_2 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : i32 + %26 = scf.for %arg9 = %6 to %11 step %c114_i32 iter_args(%arg10 = %c0_i32_2) -> (i32) : i32 { + %27 = arith.divsi %arg9, %10 {async_agent = dense<0> : vector<1xi32>} : i32 + %28 = arith.remsi %arg9, %10 {async_agent = dense<0> : vector<1xi32>} : i32 + %29 = arith.muli %27, %c64_i32 {async_agent = dense<0> : vector<1xi32>} : i32 + %30 = tt.splat %29 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %31 = arith.addi %30, %12 {async_agent = dense<0> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %32 = arith.remsi %31, %14 {async_agent = dense<0> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %33 = arith.muli %28, %c64_i32 {async_agent = dense<0> : vector<1xi32>} : i32 + %34 = tt.splat %33 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %35 = arith.addi %34, %13 {async_agent = dense<0> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %36 = arith.remsi %35, %15 {async_agent = dense<0> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %37 = tt.expand_dims %32 {async_agent = dense<0> : vector<1xi32>, axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xi32, #blocked1> + %38 = arith.muli %37, %16 {async_agent = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked1> + %39 = tt.broadcast %38 {async_agent = dense<0> : vector<1xi32>} : (tensor<64x1xi32, #blocked1>) -> tensor<64x16xi32, #blocked1> + %40 = arith.addi %39, %19 {async_agent = dense<0> : vector<1xi32>} : tensor<64x16xi32, #blocked1> + %41 = tt.addptr %20, %40 {async_agent = dense<0> : vector<1xi32>} : tensor<64x16x!tt.ptr, #blocked1>, tensor<64x16xi32, #blocked1> + %42 = tt.expand_dims %36 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x64xi32, #blocked> + %43 = arith.muli %42, %23 {async_agent = dense<0> : vector<1xi32>} : tensor<1x64xi32, #blocked> + %44 = tt.broadcast %43 {async_agent = dense<0> : vector<1xi32>} : (tensor<1x64xi32, #blocked>) -> tensor<16x64xi32, #blocked> + %45 = arith.addi %24, %44 {async_agent = dense<0> : vector<1xi32>} : tensor<16x64xi32, #blocked> + %46 = tt.addptr %25, %45 {async_agent = dense<0> : vector<1xi32>} : tensor<16x64x!tt.ptr, #blocked>, tensor<16x64xi32, #blocked> + %c3_i32_3 = arith.constant {async_agent = dense<0> : vector<1xi32>} 3 : i32 + %47 = arith.subi %arg5, %c0_i32_1 {async_agent = dense<0> : vector<1xi32>} : i32 + %48 = arith.divui %47, %c16_i32 {async_agent = dense<0> : vector<1xi32>} : i32 + %49 = arith.muli %arg10, %48 {async_agent = dense<0> : vector<1xi32>} : i32 + %c3_i32_4 = arith.constant {async_agent = dense<0> : vector<1xi32>} 3 : i32 + %50:3 = scf.for %arg11 = %c0_i32_1 to %arg5 step %c16_i32 iter_args(%arg12 = %41, %arg13 = %46, %arg14 = %49) -> (tensor<64x16x!tt.ptr, #blocked1>, tensor<16x64x!tt.ptr, #blocked>, i32) : i32 { + %52 = arith.remsi %arg14, %c3_i32_4 {async_agent = dense<0> : vector<1xi32>} : i32 + triton_nvidia_gpu.producer_acquire %2, %52 {async_agent = dense<0> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32 + %53 = triton_gpu.insert_slice %arg12, %0, %52 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x16x!tt.ptr, #blocked1> -> tensor<3x64x16xf16, #shared> + %54 = triton_gpu.insert_slice %arg13, %1, %52 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64x!tt.ptr, #blocked> -> tensor<3x16x64xf16, #shared1> + triton_nvidia_gpu.producer_commit %2, %52 {async_agent = dense<0> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32 + %55 = tt.addptr %arg12, %cst_0 {async_agent = dense<0> : vector<1xi32>} : tensor<64x16x!tt.ptr, #blocked1>, tensor<64x16xi32, #blocked1> + %56 = tt.addptr %arg13, %cst {async_agent = dense<0> : vector<1xi32>} : tensor<16x64x!tt.ptr, #blocked>, tensor<16x64xi32, #blocked> + %c1_i32_6 = arith.constant {async_agent = dense<0> : vector<1xi32>} 1 : i32 + %57 = arith.addi %arg14, %c1_i32_6 {async_agent = dense<0> : vector<1xi32>} : i32 + scf.yield {async_agent = dense<0> : vector<1xi32>} %55, %56, %57 : tensor<64x16x!tt.ptr, #blocked1>, tensor<16x64x!tt.ptr, #blocked>, i32 + } {async_agent = dense<0> : vector<1xi32>} + %c1_i32_5 = arith.constant {async_agent = dense<0> : vector<1xi32>} 1 : i32 + %51 = arith.addi %arg10, %c1_i32_5 {async_agent = dense<0> : vector<1xi32>} : i32 + scf.yield {async_agent = dense<0> : vector<1xi32>} %51 : i32 + } {async_agent = dense<0> : vector<1xi32>} + } {async_agent = dense<0> : vector<1xi32>} + %c1_i32 = arith.constant 1 : i32 + %5 = arith.cmpi eq, %3, %c1_i32 : i32 + scf.if %5 { + %c0_i32_0 = arith.constant 0 : i32 + %6 = triton_nvidia_gpu.get_mutex_role_id {async_agent = dense<1> : vector<1xi32>, num = 2 : i32} : i32 + %7 = arith.cmpi ne, %6, %c0_i32_0 : i32 + %8 = triton_nvidia_gpu.create_mutex {async_agent = dense<1> : vector<1xi32>} : !triton_nvidia_gpu.mutex + %9 = triton_nvidia_gpu.create_mutex {async_agent = dense<1> : vector<1xi32>} : !triton_nvidia_gpu.mutex + %cst = arith.constant {async_agent = dense<1> : vector<1xi32>} dense<0.000000e+00> : tensor<64x64xf32, #mma> + %c63_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 63 : i32 + %c114_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 114 : i32 + %c16_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 16 : i32 + %c0_i32_1 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : i32 + %c64_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 64 : i32 + %10 = tt.get_program_id x {async_agent = dense<[0, 1]> : vector<2xi32>, axis = 0 : i32} : i32 + %11 = arith.addi %arg3, %c63_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %12 = arith.divsi %11, %c64_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %13 = arith.addi %arg4, %c63_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %14 = arith.divsi %13, %c64_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %15 = arith.muli %12, %14 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %16 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %17 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %18 = tt.splat %arg8 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<64x1xi32, #blocked2> + %19 = tt.splat %arg2 {async_agent = dense<1> : vector<1xi32>} : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked2> + %c3_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 3 : i32 + %c0_i32_2 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : i32 + %20 = arith.muli %c114_i32, %6 {async_agent = dense<1> : vector<1xi32>} : i32 + %21 = arith.addi %10, %20 {async_agent = dense<1> : vector<1xi32>} : i32 + %c2_i32 = arith.constant {async_agent = dense<1> : vector<1xi32>} 2 : i32 + %22 = arith.muli %c114_i32, %c2_i32 {async_agent = dense<1> : vector<1xi32>} : i32 + %23 = arith.addi %c0_i32_2, %6 {async_agent = dense<1> : vector<1xi32>} : i32 + %24 = scf.for %arg9 = %21 to %15 step %22 iter_args(%arg10 = %23) -> (i32) : i32 { + %25 = arith.cmpi ne, %arg9, %10 : i32 + %26 = arith.ori %25, %7 {agent.mutex_role = 0 : i32} : i1 + scf.if %26 { + triton_nvidia_gpu.lock %8 {agent.mutex_role = 0 : i32} : !triton_nvidia_gpu.mutex + } {agent.mutex_role = 0 : i32} + %27 = arith.divsi %arg9, %14 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32 + %28 = arith.remsi %arg9, %14 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32 + %29 = arith.muli %27, %c64_i32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32 + %30 = tt.splat %29 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %31 = arith.addi %30, %17 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %32 = arith.muli %28, %c64_i32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32 + %33 = tt.splat %32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %34 = arith.addi %33, %16 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %c3_i32_3 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 3 : i32 + %35 = arith.subi %arg5, %c0_i32_1 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32 + %36 = arith.divui %35, %c16_i32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32 + %37 = arith.muli %arg10, %36 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32 + %c3_i32_4 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 3 : i32 + %38:2 = scf.for %arg11 = %c0_i32_1 to %arg5 step %c16_i32 iter_args(%arg12 = %cst, %arg13 = %37) -> (tensor<64x64xf32, #mma>, i32) : i32 { + %48 = arith.remsi %arg13, %c3_i32_4 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32 + triton_nvidia_gpu.consumer_wait %2, %48 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32 + %49 = triton_gpu.extract_slice %0[%48, 0, 0] [1, 64, 16] [1, 1, 1] {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<3x64x16xf16, #shared> to tensor<64x16xf16, #shared> + %50 = triton_gpu.convert_layout %49 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : (tensor<64x16xf16, #shared>) -> tensor<64x16xf16, #shared> + %51 = triton_gpu.extract_slice %1[%48, 0, 0] [1, 16, 64] [1, 1, 1] {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<3x16x64xf16, #shared1> to tensor<16x64xf16, #shared1> + %52 = triton_gpu.convert_layout %51 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : (tensor<16x64xf16, #shared1>) -> tensor<16x64xf16, #shared1> + %53 = tt.dot %50, %52, %arg12 {agent.mutex_role = 0 : i32, allowTF32 = true, async_agent = dense<1> : vector<1xi32>} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> + triton_nvidia_gpu.consumer_release %2, %48 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32 + %c1_i32_6 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 1 : i32 + %54 = arith.addi %arg13, %c1_i32_6 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32 + scf.yield {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} %53, %54 : tensor<64x64xf32, #mma>, i32 + } {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} + triton_nvidia_gpu.unlock %8 : !triton_nvidia_gpu.mutex + scf.if %26 { + triton_nvidia_gpu.lock %9 {agent.mutex_role = 1 : i32} : !triton_nvidia_gpu.mutex + } {agent.mutex_role = 1 : i32} + %39 = tt.expand_dims %31 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>, axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<64x1xi32, #blocked2> + %40 = arith.muli %39, %18 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<64x1xi32, #blocked2> + %41 = tt.addptr %19, %40 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<64x1x!tt.ptr, #blocked2>, tensor<64x1xi32, #blocked2> + %42 = tt.expand_dims %34 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>, axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x64xi32, #blocked2> + %43 = tt.broadcast %41 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : (tensor<64x1x!tt.ptr, #blocked2>) -> tensor<64x64x!tt.ptr, #blocked2> + %44 = tt.broadcast %42 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2> + %45 = tt.addptr %43, %44 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<64x64x!tt.ptr, #blocked2>, tensor<64x64xi32, #blocked2> + %46 = triton_gpu.convert_layout %38#0 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : (tensor<64x64xf32, #mma>) -> tensor<64x64xf32, #blocked2> + tt.store %45, %46 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>, cache = 1 : i32, evict = 1 : i32} : tensor<64x64xf32, #blocked2> + triton_nvidia_gpu.unlock %9 : !triton_nvidia_gpu.mutex + %c1_i32_5 = arith.constant {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} 1 : i32 + %47 = arith.addi %arg10, %c2_i32 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : i32 + scf.yield {async_agent = dense<1> : vector<1xi32>} %47 : i32 + } {async_agent = dense<1> : vector<1xi32>} + } {"agent.num-roles" = 2 : i32, async_agent = dense<1> : vector<1xi32>} + tt.return + } +} diff --git a/test/TritonGPU/wsmutex.mlir b/test/TritonGPU/wsmutex.mlir new file mode 100644 index 000000000000..78b9037c51fa --- /dev/null +++ b/test/TritonGPU/wsmutex.mlir @@ -0,0 +1,166 @@ +// RUN: triton-opt -triton-nvidia-gpu-ws-mutex='compute-capability=90' %s | FileCheck %s + + +// CHECK: scf.if +// CHECK: scf.for +// CHECK: triton_nvidia_gpu.create_mutex +// CHECK: triton_nvidia_gpu.create_mutex +// CHECK: triton_nvidia_gpu.lock +// CHECK: agent.mutex_role = 0 : i32 +// CHECK: triton_nvidia_gpu.unlock +// CHECK: triton_nvidia_gpu.lock +// CHECK: agent.mutex_role = 1 : i32 +// CHECK: triton_nvidia_gpu.unlock + +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"async.num-agents" = 2 : i32, "triton_gpu.compute-capability" = 90 : i32, "triton_gpu.enable-warp-specialization" = 1 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + tt.func public @static_persistent_warp_specialized_matmul_kernel_0d1d2d3de4de5de6de7c8c9de10de11c(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg4: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg7: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg8: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) { + %0 = triton_gpu.alloc_tensor : tensor<3x64x16xf16, #shared> + %1 = triton_gpu.alloc_tensor : tensor<3x16x64xf16, #shared1> + %2 = triton_nvidia_gpu.create_token {num = 3 : i32} : tensor<3x!triton_nvidia_gpu.token> + %3 = triton_nvidia_gpu.get_agent_id : i32 + %c0_i32 = arith.constant 0 : i32 + %4 = arith.cmpi eq, %3, %c0_i32 : i32 + scf.if %4 { + %cst = arith.constant {async_agent = dense<0> : vector<1xi32>} dense<16> : tensor<16x64xi32, #blocked> + %cst_0 = arith.constant {async_agent = dense<0> : vector<1xi32>} dense<16> : tensor<64x16xi32, #blocked1> + %c63_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 63 : i32 + %c114_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 114 : i32 + %c16_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 16 : i32 + %c0_i32_1 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : i32 + %c64_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 64 : i32 + %6 = tt.get_program_id x {async_agent = dense<[0, 1]> : vector<2xi32>, axis = 0 : i32} : i32 + %7 = arith.addi %arg3, %c63_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %8 = arith.divsi %7, %c64_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %9 = arith.addi %arg4, %c63_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %10 = arith.divsi %9, %c64_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %11 = arith.muli %8, %10 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %12 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %13 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %14 = tt.splat %arg3 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %15 = tt.splat %arg4 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %16 = tt.splat %arg6 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<64x1xi32, #blocked1> + %17 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %18 = tt.expand_dims %17 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x16xi32, #blocked1> + %19 = tt.broadcast %18 {async_agent = dense<0> : vector<1xi32>} : (tensor<1x16xi32, #blocked1>) -> tensor<64x16xi32, #blocked1> + %20 = tt.splat %arg0 {async_agent = dense<0> : vector<1xi32>} : (!tt.ptr) -> tensor<64x16x!tt.ptr, #blocked1> + %21 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %22 = tt.expand_dims %21 {async_agent = dense<0> : vector<1xi32>, axis = 1 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<16x1xi32, #blocked> + %23 = tt.splat %arg7 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<1x64xi32, #blocked> + %24 = tt.broadcast %22 {async_agent = dense<0> : vector<1xi32>} : (tensor<16x1xi32, #blocked>) -> tensor<16x64xi32, #blocked> + %25 = tt.splat %arg1 {async_agent = dense<0> : vector<1xi32>} : (!tt.ptr) -> tensor<16x64x!tt.ptr, #blocked> + %c3_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 3 : i32 + %c0_i32_2 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : i32 + %26 = scf.for %arg9 = %6 to %11 step %c114_i32 iter_args(%arg10 = %c0_i32_2) -> (i32) : i32 { + %27 = arith.divsi %arg9, %10 {async_agent = dense<0> : vector<1xi32>} : i32 + %28 = arith.remsi %arg9, %10 {async_agent = dense<0> : vector<1xi32>} : i32 + %29 = arith.muli %27, %c64_i32 {async_agent = dense<0> : vector<1xi32>} : i32 + %30 = tt.splat %29 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %31 = arith.addi %30, %12 {async_agent = dense<0> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %32 = arith.remsi %31, %14 {async_agent = dense<0> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %33 = arith.muli %28, %c64_i32 {async_agent = dense<0> : vector<1xi32>} : i32 + %34 = tt.splat %33 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %35 = arith.addi %34, %13 {async_agent = dense<0> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %36 = arith.remsi %35, %15 {async_agent = dense<0> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %37 = tt.expand_dims %32 {async_agent = dense<0> : vector<1xi32>, axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xi32, #blocked1> + %38 = arith.muli %37, %16 {async_agent = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked1> + %39 = tt.broadcast %38 {async_agent = dense<0> : vector<1xi32>} : (tensor<64x1xi32, #blocked1>) -> tensor<64x16xi32, #blocked1> + %40 = arith.addi %39, %19 {async_agent = dense<0> : vector<1xi32>} : tensor<64x16xi32, #blocked1> + %41 = tt.addptr %20, %40 {async_agent = dense<0> : vector<1xi32>} : tensor<64x16x!tt.ptr, #blocked1>, tensor<64x16xi32, #blocked1> + %42 = tt.expand_dims %36 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x64xi32, #blocked> + %43 = arith.muli %42, %23 {async_agent = dense<0> : vector<1xi32>} : tensor<1x64xi32, #blocked> + %44 = tt.broadcast %43 {async_agent = dense<0> : vector<1xi32>} : (tensor<1x64xi32, #blocked>) -> tensor<16x64xi32, #blocked> + %45 = arith.addi %24, %44 {async_agent = dense<0> : vector<1xi32>} : tensor<16x64xi32, #blocked> + %46 = tt.addptr %25, %45 {async_agent = dense<0> : vector<1xi32>} : tensor<16x64x!tt.ptr, #blocked>, tensor<16x64xi32, #blocked> + %c3_i32_3 = arith.constant {async_agent = dense<0> : vector<1xi32>} 3 : i32 + %47 = arith.subi %arg5, %c0_i32_1 {async_agent = dense<0> : vector<1xi32>} : i32 + %48 = arith.divui %47, %c16_i32 {async_agent = dense<0> : vector<1xi32>} : i32 + %49 = arith.muli %arg10, %48 {async_agent = dense<0> : vector<1xi32>} : i32 + %c3_i32_4 = arith.constant {async_agent = dense<0> : vector<1xi32>} 3 : i32 + %50:3 = scf.for %arg11 = %c0_i32_1 to %arg5 step %c16_i32 iter_args(%arg12 = %41, %arg13 = %46, %arg14 = %49) -> (tensor<64x16x!tt.ptr, #blocked1>, tensor<16x64x!tt.ptr, #blocked>, i32) : i32 { + %52 = arith.remsi %arg14, %c3_i32_4 {async_agent = dense<0> : vector<1xi32>} : i32 + triton_nvidia_gpu.producer_acquire %2, %52 {async_agent = dense<0> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32 + %53 = triton_gpu.insert_slice %arg12, %0, %52 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x16x!tt.ptr, #blocked1> -> tensor<3x64x16xf16, #shared> + %54 = triton_gpu.insert_slice %arg13, %1, %52 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64x!tt.ptr, #blocked> -> tensor<3x16x64xf16, #shared1> + triton_nvidia_gpu.producer_commit %2, %52 {async_agent = dense<0> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32 + %55 = tt.addptr %arg12, %cst_0 {async_agent = dense<0> : vector<1xi32>} : tensor<64x16x!tt.ptr, #blocked1>, tensor<64x16xi32, #blocked1> + %56 = tt.addptr %arg13, %cst {async_agent = dense<0> : vector<1xi32>} : tensor<16x64x!tt.ptr, #blocked>, tensor<16x64xi32, #blocked> + %c1_i32_6 = arith.constant {async_agent = dense<0> : vector<1xi32>} 1 : i32 + %57 = arith.addi %arg14, %c1_i32_6 {async_agent = dense<0> : vector<1xi32>} : i32 + scf.yield {async_agent = dense<0> : vector<1xi32>} %55, %56, %57 : tensor<64x16x!tt.ptr, #blocked1>, tensor<16x64x!tt.ptr, #blocked>, i32 + } {async_agent = dense<0> : vector<1xi32>} + %c1_i32_5 = arith.constant {async_agent = dense<0> : vector<1xi32>} 1 : i32 + %51 = arith.addi %arg10, %c1_i32_5 {async_agent = dense<0> : vector<1xi32>} : i32 + scf.yield {async_agent = dense<0> : vector<1xi32>} %51 : i32 + } {async_agent = dense<0> : vector<1xi32>} + } {async_agent = dense<0> : vector<1xi32>} + %c1_i32 = arith.constant 1 : i32 + %5 = arith.cmpi eq, %3, %c1_i32 : i32 + scf.if %5 { + %cst = arith.constant {async_agent = dense<1> : vector<1xi32>} dense<0.000000e+00> : tensor<64x64xf32, #mma> + %c63_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 63 : i32 + %c114_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 114 : i32 + %c16_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 16 : i32 + %c0_i32_0 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : i32 + %c64_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 64 : i32 + %6 = tt.get_program_id x {async_agent = dense<[0, 1]> : vector<2xi32>, axis = 0 : i32} : i32 + %7 = arith.addi %arg3, %c63_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %8 = arith.divsi %7, %c64_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %9 = arith.addi %arg4, %c63_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %10 = arith.divsi %9, %c64_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %11 = arith.muli %8, %10 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 + %12 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %13 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %14 = tt.splat %arg8 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<64x1xi32, #blocked2> + %15 = tt.splat %arg2 {async_agent = dense<1> : vector<1xi32>} : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked2> + %c3_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 3 : i32 + %c0_i32_1 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : i32 + %16 = scf.for %arg9 = %6 to %11 step %c114_i32 iter_args(%arg10 = %c0_i32_1) -> (i32) : i32 { + %17 = arith.divsi %arg9, %10 {async_agent = dense<1> : vector<1xi32>} : i32 + %18 = arith.remsi %arg9, %10 {async_agent = dense<1> : vector<1xi32>} : i32 + %19 = arith.muli %17, %c64_i32 {async_agent = dense<1> : vector<1xi32>} : i32 + %20 = tt.splat %19 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %21 = arith.addi %20, %13 {async_agent = dense<1> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %22 = arith.muli %18, %c64_i32 {async_agent = dense<1> : vector<1xi32>} : i32 + %23 = tt.splat %22 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %24 = arith.addi %23, %12 {async_agent = dense<1> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %c3_i32_2 = arith.constant {async_agent = dense<1> : vector<1xi32>} 3 : i32 + %25 = arith.subi %arg5, %c0_i32_0 {async_agent = dense<1> : vector<1xi32>} : i32 + %26 = arith.divui %25, %c16_i32 {async_agent = dense<1> : vector<1xi32>} : i32 + %27 = arith.muli %arg10, %26 {async_agent = dense<1> : vector<1xi32>} : i32 + %c3_i32_3 = arith.constant {async_agent = dense<1> : vector<1xi32>} 3 : i32 + %28:2 = scf.for %arg11 = %c0_i32_0 to %arg5 step %c16_i32 iter_args(%arg12 = %cst, %arg13 = %27) -> (tensor<64x64xf32, #mma>, i32) : i32 { + %38 = arith.remsi %arg13, %c3_i32_3 {async_agent = dense<1> : vector<1xi32>} : i32 + triton_nvidia_gpu.consumer_wait %2, %38 {async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32 + %39 = triton_gpu.extract_slice %0[%38, 0, 0] [1, 64, 16] [1, 1, 1] {async_agent = dense<1> : vector<1xi32>} : tensor<3x64x16xf16, #shared> to tensor<64x16xf16, #shared> + %40 = triton_gpu.convert_layout %39 {async_agent = dense<1> : vector<1xi32>} : (tensor<64x16xf16, #shared>) -> tensor<64x16xf16, #shared> + %41 = triton_gpu.extract_slice %1[%38, 0, 0] [1, 16, 64] [1, 1, 1] {async_agent = dense<1> : vector<1xi32>} : tensor<3x16x64xf16, #shared1> to tensor<16x64xf16, #shared1> + %42 = triton_gpu.convert_layout %41 {async_agent = dense<1> : vector<1xi32>} : (tensor<16x64xf16, #shared1>) -> tensor<16x64xf16, #shared1> + %43 = tt.dot %40, %42, %arg12 {allowTF32 = true, async_agent = dense<1> : vector<1xi32>} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> + triton_nvidia_gpu.consumer_release %2, %38 {async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32 + %c1_i32_5 = arith.constant {async_agent = dense<1> : vector<1xi32>} 1 : i32 + %44 = arith.addi %arg13, %c1_i32_5 {async_agent = dense<1> : vector<1xi32>} : i32 + scf.yield {async_agent = dense<1> : vector<1xi32>} %43, %44 : tensor<64x64xf32, #mma>, i32 + } {async_agent = dense<1> : vector<1xi32>} + %29 = tt.expand_dims %21 {async_agent = dense<1> : vector<1xi32>, axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<64x1xi32, #blocked2> + %30 = arith.muli %29, %14 {async_agent = dense<1> : vector<1xi32>} : tensor<64x1xi32, #blocked2> + %31 = tt.addptr %15, %30 {async_agent = dense<1> : vector<1xi32>} : tensor<64x1x!tt.ptr, #blocked2>, tensor<64x1xi32, #blocked2> + %32 = tt.expand_dims %24 {async_agent = dense<1> : vector<1xi32>, axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x64xi32, #blocked2> + %33 = tt.broadcast %31 {async_agent = dense<1> : vector<1xi32>} : (tensor<64x1x!tt.ptr, #blocked2>) -> tensor<64x64x!tt.ptr, #blocked2> + %34 = tt.broadcast %32 {async_agent = dense<1> : vector<1xi32>} : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2> + %35 = tt.addptr %33, %34 {async_agent = dense<1> : vector<1xi32>} : tensor<64x64x!tt.ptr, #blocked2>, tensor<64x64xi32, #blocked2> + %36 = triton_gpu.convert_layout %28#0 {async_agent = dense<1> : vector<1xi32>} : (tensor<64x64xf32, #mma>) -> tensor<64x64xf32, #blocked2> + tt.store %35, %36 {async_agent = dense<1> : vector<1xi32>, cache = 1 : i32, evict = 1 : i32} : tensor<64x64xf32, #blocked2> + %c1_i32_4 = arith.constant {async_agent = dense<1> : vector<1xi32>} 1 : i32 + %37 = arith.addi %arg10, %c1_i32_4 {async_agent = dense<1> : vector<1xi32>} : i32 + scf.yield {async_agent = dense<1> : vector<1xi32>} %37 : i32 + } {async_agent = dense<1> : vector<1xi32>} + } {async_agent = dense<1> : vector<1xi32>} + tt.return + } +} diff --git a/test/TritonGPU/wspipeline.mlir b/test/TritonGPU/wspipeline.mlir new file mode 100644 index 000000000000..c2b0a1b70813 --- /dev/null +++ b/test/TritonGPU/wspipeline.mlir @@ -0,0 +1,148 @@ +// RUN: triton-opt %s --triton-nvidia-gpu-ws-decomposing='compute-capability=90' --triton-nvidia-gpu-ws-pipeline='compute-capability=90' | FileCheck %s + +// CHECK: triton_gpu.alloc_tensor +// CHECK: triton_gpu.alloc_tensor +// CHECK: triton_nvidia_gpu.create_token +// CHECK: triton_nvidia_gpu.get_agent_id + +// CHECK: arith.cmpi eq +// CHECK: scf.if +// CHECK: scf.for +// CHECK: triton_nvidia_gpu.producer_acquire +// CHECK: triton_gpu.insert_slice +// CHECK: triton_gpu.insert_slice +// CHECK: triton_nvidia_gpu.producer_commit +// CHECK: scf.yield +// CHECK: async_agent = dense<0> : vector<1xi32> + +// CHECK: arith.cmpi eq +// CHECK: scf.if +// CHECK: scf.for +// CHECK: triton_nvidia_gpu.consumer_wait +// CHECK: triton_gpu.extract_slice +// CHECK: triton_gpu.extract_slice +// CHECK: tt.dot +// CHECK: triton_nvidia_gpu.consumer_release +// CHECK: scf.yield +// CHECK: async_agent = dense<1> : vector<1xi32> + +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 128, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.enable-warp-specialization" = 1 : i32} { + tt.func public @_kernel_0d1d2d3d4d5d6d7c8c9d10d11c(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %cst_0 = arith.constant dense<32> : tensor<32x128xi32, #blocked> + %cst_1 = arith.constant dense<32> : tensor<128x32xi32, #blocked1> + %c31_i32 = arith.constant 31 : i32 + %c127_i32 = arith.constant 127 : i32 + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c32_i32 = arith.constant 32 : i32 + %c128_i32 = arith.constant 128 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = tt.get_program_id x : i32 + %1 = tt.get_program_id y : i32 + %2 = arith.addi %arg3, %c127_i32 : i32 + %3 = arith.divsi %2, %c128_i32 : i32 + %4 = arith.addi %arg4, %c127_i32 : i32 + %5 = arith.divsi %4, %c128_i32 : i32 + %6 = arith.muli %5, %c8_i32 : i32 + %7 = arith.divsi %0, %6 : i32 + %8 = arith.muli %7, %c8_i32 : i32 + %9 = arith.subi %3, %8 : i32 + %10 = arith.cmpi slt, %9, %c8_i32 : i32 + %11 = arith.select %10, %9, %c8_i32 : i32 + %12 = arith.remsi %0, %11 : i32 + %13 = arith.addi %8, %12 : i32 + %14 = arith.remsi %0, %6 : i32 + %15 = arith.divsi %14, %11 : i32 + %16 = arith.muli %13, %c128_i32 : i32 + %17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %19 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %20 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %21 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %22 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %23 = tt.splat %16 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %24 = tt.splat %16 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %25 = tt.splat %16 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %26 = arith.addi %23, %17 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %27 = arith.addi %24, %19 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %28 = arith.addi %25, %21 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %29 = arith.muli %15, %c128_i32 : i32 + %30 = tt.splat %29 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %31 = tt.splat %29 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %32 = tt.splat %29 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %33 = arith.addi %30, %18 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %34 = arith.addi %31, %20 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %35 = arith.addi %32, %22 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %36 = tt.splat %arg3 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %37 = tt.splat %arg3 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %38 = arith.remsi %26, %36 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %39 = tt.splat %arg4 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %40 = tt.splat %arg4 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %41 = arith.remsi %33, %39 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %42 = arith.muli %1, %c32_i32 : i32 + %43 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %44 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %45 = tt.splat %42 : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %46 = tt.splat %42 : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %47 = arith.addi %45, %43 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %48 = arith.addi %46, %44 : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %49 = tt.expand_dims %38 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xi32, #blocked1> + %50 = tt.splat %arg6 : (i32) -> tensor<128x1xi32, #blocked1> + %51 = arith.muli %49, %50 : tensor<128x1xi32, #blocked1> + %52 = tt.expand_dims %47 {axis = 0 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x32xi32, #blocked1> + %53 = tt.broadcast %51 : (tensor<128x1xi32, #blocked1>) -> tensor<128x32xi32, #blocked1> + %54 = tt.broadcast %52 : (tensor<1x32xi32, #blocked1>) -> tensor<128x32xi32, #blocked1> + %55 = arith.addi %53, %54 : tensor<128x32xi32, #blocked1> + %56 = tt.splat %arg0 : (!tt.ptr) -> tensor<128x32x!tt.ptr, #blocked1> + %57 = tt.addptr %56, %55 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %58 = tt.expand_dims %48 {axis = 1 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32x1xi32, #blocked> + %59 = tt.expand_dims %41 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x128xi32, #blocked> + %60 = tt.splat %arg7 : (i32) -> tensor<1x128xi32, #blocked> + %61 = arith.muli %59, %60 : tensor<1x128xi32, #blocked> + %62 = tt.broadcast %58 : (tensor<32x1xi32, #blocked>) -> tensor<32x128xi32, #blocked> + %63 = tt.broadcast %61 : (tensor<1x128xi32, #blocked>) -> tensor<32x128xi32, #blocked> + %64 = arith.addi %62, %63 : tensor<32x128xi32, #blocked> + %65 = tt.splat %arg1 : (!tt.ptr) -> tensor<32x128x!tt.ptr, #blocked> + %66 = tt.addptr %65, %64 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %67 = arith.addi %arg5, %c31_i32 : i32 + %68 = arith.divsi %67, %c32_i32 : i32 + %69 = arith.index_cast %68 : i32 to index + %70:3 = scf.for %arg9 = %c0 to %69 step %c1 iter_args(%arg10 = %cst, %arg11 = %57, %arg12 = %66) -> (tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>) { + %89 = tt.load %arg11 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #blocked1> + %90 = tt.load %arg12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked> + %91 = triton_gpu.convert_layout %89 : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #shared> + %92 = triton_gpu.convert_layout %90 : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1> + %93 = tt.dot %91, %92, %arg10 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma> + %94 = tt.addptr %arg11, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %95 = tt.addptr %arg12, %cst_0 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + scf.yield %93, %94, %95 : tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked> + } + %71 = arith.truncf %70#0 : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma> + %72 = tt.expand_dims %27 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi32, #blocked2> + %73 = tt.splat %arg8 : (i32) -> tensor<128x1xi32, #blocked2> + %74 = arith.muli %72, %73 : tensor<128x1xi32, #blocked2> + %75 = tt.expand_dims %34 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi32, #blocked2> + %76 = tt.broadcast %74 : (tensor<128x1xi32, #blocked2>) -> tensor<128x128xi32, #blocked2> + %77 = tt.broadcast %75 : (tensor<1x128xi32, #blocked2>) -> tensor<128x128xi32, #blocked2> + %78 = arith.addi %76, %77 : tensor<128x128xi32, #blocked2> + %79 = tt.splat %arg2 : (!tt.ptr) -> tensor<128x128x!tt.ptr, #blocked2> + %80 = tt.addptr %79, %78 : tensor<128x128x!tt.ptr, #blocked2>, tensor<128x128xi32, #blocked2> + %81 = "triton_gpu.cmpi"(%28, %37) {predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %82 = tt.expand_dims %81 {axis = 1 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi1, #blocked2> + %83 = "triton_gpu.cmpi"(%35, %40) {predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %84 = tt.expand_dims %83 {axis = 0 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi1, #blocked2> + %85 = tt.broadcast %82 : (tensor<128x1xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> + %86 = tt.broadcast %84 : (tensor<1x128xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> + %87 = arith.andi %85, %86 : tensor<128x128xi1, #blocked2> + %88 = triton_gpu.convert_layout %71 : (tensor<128x128xf16, #mma>) -> tensor<128x128xf16, #blocked2> + tt.store %80, %88, %87 {cache = 1 : i32, evict = 1 : i32} : tensor<128x128xf16, #blocked2> + tt.return + } +} diff --git a/test/TritonNvidiaGPU/ws-feasibility-checking.mlir b/test/TritonNvidiaGPU/ws-feasibility-checking.mlir new file mode 100644 index 000000000000..0eec6889f8f7 --- /dev/null +++ b/test/TritonNvidiaGPU/ws-feasibility-checking.mlir @@ -0,0 +1,1035 @@ +// RUN: triton-opt -split-input-file -triton-nvidia-gpu-ws-feasibility-checking='compute-capability=90' %s 2>&1 | FileCheck %s + +// Check if all opereations are labeled with appropriate attributes. +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 128, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +// CHECK: "triton_gpu.enable-warp-specialization" = 1 : i32 +// CHECK-LABEL @simple_gemm +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + tt.func public @simple_gemm(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %cst_0 = arith.constant dense<32> : tensor<32x128xi32, #blocked> + %cst_1 = arith.constant dense<32> : tensor<128x32xi32, #blocked1> + %c31_i32 = arith.constant 31 : i32 + %c127_i32 = arith.constant 127 : i32 + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c32_i32 = arith.constant 32 : i32 + %c128_i32 = arith.constant 128 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = tt.get_program_id x : i32 + %1 = tt.get_program_id y : i32 + %2 = arith.addi %arg3, %c127_i32 : i32 + %3 = arith.divsi %2, %c128_i32 : i32 + %4 = arith.addi %arg4, %c127_i32 : i32 + %5 = arith.divsi %4, %c128_i32 : i32 + %6 = arith.muli %5, %c8_i32 : i32 + %7 = arith.divsi %0, %6 : i32 + %8 = arith.muli %7, %c8_i32 : i32 + %9 = arith.subi %3, %8 : i32 + %10 = arith.cmpi slt, %9, %c8_i32 : i32 + %11 = arith.select %10, %9, %c8_i32 : i32 + %12 = arith.remsi %0, %11 : i32 + %13 = arith.addi %8, %12 : i32 + %14 = arith.remsi %0, %6 : i32 + %15 = arith.divsi %14, %11 : i32 + %16 = arith.muli %13, %c128_i32 : i32 + %17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %19 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %20 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %21 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %22 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %23 = tt.splat %16 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %24 = tt.splat %16 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %25 = tt.splat %16 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %26 = arith.addi %23, %17 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %27 = arith.addi %24, %19 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %28 = arith.addi %25, %21 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %29 = arith.muli %15, %c128_i32 : i32 + %30 = tt.splat %29 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %31 = tt.splat %29 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %32 = tt.splat %29 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %33 = arith.addi %30, %18 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %34 = arith.addi %31, %20 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %35 = arith.addi %32, %22 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %36 = tt.splat %arg3 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %37 = tt.splat %arg3 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %38 = arith.remsi %26, %36 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %39 = tt.splat %arg4 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %40 = tt.splat %arg4 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %41 = arith.remsi %33, %39 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %42 = arith.muli %1, %c32_i32 : i32 + %43 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %44 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %45 = tt.splat %42 : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %46 = tt.splat %42 : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %47 = arith.addi %45, %43 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %48 = arith.addi %46, %44 : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %49 = tt.expand_dims %38 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xi32, #blocked1> + %50 = tt.splat %arg6 : (i32) -> tensor<128x1xi32, #blocked1> + %51 = arith.muli %49, %50 : tensor<128x1xi32, #blocked1> + %52 = tt.expand_dims %47 {axis = 0 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x32xi32, #blocked1> + %53 = tt.broadcast %51 : (tensor<128x1xi32, #blocked1>) -> tensor<128x32xi32, #blocked1> + %54 = tt.broadcast %52 : (tensor<1x32xi32, #blocked1>) -> tensor<128x32xi32, #blocked1> + %55 = arith.addi %53, %54 : tensor<128x32xi32, #blocked1> + %56 = tt.splat %arg0 : (!tt.ptr) -> tensor<128x32x!tt.ptr, #blocked1> + %57 = tt.addptr %56, %55 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %58 = tt.expand_dims %48 {axis = 1 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32x1xi32, #blocked> + %59 = tt.expand_dims %41 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x128xi32, #blocked> + %60 = tt.splat %arg7 : (i32) -> tensor<1x128xi32, #blocked> + %61 = arith.muli %59, %60 : tensor<1x128xi32, #blocked> + %62 = tt.broadcast %58 : (tensor<32x1xi32, #blocked>) -> tensor<32x128xi32, #blocked> + %63 = tt.broadcast %61 : (tensor<1x128xi32, #blocked>) -> tensor<32x128xi32, #blocked> + %64 = arith.addi %62, %63 : tensor<32x128xi32, #blocked> + %65 = tt.splat %arg1 : (!tt.ptr) -> tensor<32x128x!tt.ptr, #blocked> + %66 = tt.addptr %65, %64 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %67 = arith.addi %arg5, %c31_i32 : i32 + %68 = arith.divsi %67, %c32_i32 : i32 + %69 = arith.index_cast %68 : i32 to index + %70:3 = scf.for %arg9 = %c0 to %69 step %c1 iter_args(%arg10 = %cst, %arg11 = %57, %arg12 = %66) -> (tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>) { + %89 = tt.load %arg11 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #blocked1> + %90 = tt.load %arg12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked> + %91 = triton_gpu.convert_layout %89 : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #shared> + %92 = triton_gpu.convert_layout %90 : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1> + %93 = tt.dot %91, %92, %arg10 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma> + %94 = tt.addptr %arg11, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %95 = tt.addptr %arg12, %cst_0 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + scf.yield %93, %94, %95 : tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked> + } + %71 = arith.truncf %70#0 : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma> + %72 = tt.expand_dims %27 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi32, #blocked2> + %73 = tt.splat %arg8 : (i32) -> tensor<128x1xi32, #blocked2> + %74 = arith.muli %72, %73 : tensor<128x1xi32, #blocked2> + %75 = tt.expand_dims %34 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi32, #blocked2> + %76 = tt.broadcast %74 : (tensor<128x1xi32, #blocked2>) -> tensor<128x128xi32, #blocked2> + %77 = tt.broadcast %75 : (tensor<1x128xi32, #blocked2>) -> tensor<128x128xi32, #blocked2> + %78 = arith.addi %76, %77 : tensor<128x128xi32, #blocked2> + %79 = tt.splat %arg2 : (!tt.ptr) -> tensor<128x128x!tt.ptr, #blocked2> + %80 = tt.addptr %79, %78 : tensor<128x128x!tt.ptr, #blocked2>, tensor<128x128xi32, #blocked2> + %81 = "triton_gpu.cmpi"(%28, %37) {predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %82 = tt.expand_dims %81 {axis = 1 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi1, #blocked2> + %83 = "triton_gpu.cmpi"(%35, %40) {predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %84 = tt.expand_dims %83 {axis = 0 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi1, #blocked2> + %85 = tt.broadcast %82 : (tensor<128x1xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> + %86 = tt.broadcast %84 : (tensor<1x128xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> + %87 = arith.andi %85, %86 : tensor<128x128xi1, #blocked2> + %88 = triton_gpu.convert_layout %71 : (tensor<128x128xf16, #mma>) -> tensor<128x128xf16, #blocked2> + tt.store %80, %88, %87 {cache = 1 : i32, evict = 1 : i32} : tensor<128x128xf16, #blocked2> + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 128, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK: "triton_gpu.enable-warp-specialization" = 1 : i32 + // CHECK-LABEL: @nested_for_gemm + tt.func public @nested_for_gemm(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %cst_0 = arith.constant dense<32> : tensor<32x128xi32, #blocked> + %cst_1 = arith.constant dense<32> : tensor<128x32xi32, #blocked1> + %c31_i32 = arith.constant 31 : i32 + %c127_i32 = arith.constant 127 : i32 + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c32_i32 = arith.constant 32 : i32 + %c128_i32 = arith.constant 128 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = tt.get_program_id x : i32 + %1 = tt.get_program_id y : i32 + %2 = arith.addi %arg3, %c127_i32 : i32 + %3 = arith.divsi %2, %c128_i32 : i32 + %4 = arith.addi %arg4, %c127_i32 : i32 + %5 = arith.divsi %4, %c128_i32 : i32 + %6 = arith.muli %5, %c8_i32 : i32 + %7 = arith.divsi %0, %6 : i32 + %8 = arith.muli %7, %c8_i32 : i32 + %9 = arith.subi %3, %8 : i32 + %10 = arith.cmpi slt, %9, %c8_i32 : i32 + %11 = arith.select %10, %9, %c8_i32 : i32 + %12 = arith.remsi %0, %11 : i32 + %13 = arith.addi %8, %12 : i32 + %14 = arith.remsi %0, %6 : i32 + %15 = arith.divsi %14, %11 : i32 + %16 = arith.muli %13, %c128_i32 : i32 + %17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %19 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %20 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %21 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %22 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %23 = tt.splat %16 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %24 = tt.splat %16 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %25 = tt.splat %16 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %26 = arith.addi %23, %17 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %27 = arith.addi %24, %19 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %28 = arith.addi %25, %21 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %29 = arith.muli %15, %c128_i32 : i32 + %30 = tt.splat %29 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %31 = tt.splat %29 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %32 = tt.splat %29 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %33 = arith.addi %30, %18 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %34 = arith.addi %31, %20 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %35 = arith.addi %32, %22 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %36 = tt.splat %arg3 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %37 = tt.splat %arg3 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %38 = arith.remsi %26, %36 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %39 = tt.splat %arg4 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %40 = tt.splat %arg4 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %41 = arith.remsi %33, %39 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %42 = arith.muli %1, %c32_i32 : i32 + %43 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %44 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %45 = tt.splat %42 : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %46 = tt.splat %42 : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %47 = arith.addi %45, %43 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %48 = arith.addi %46, %44 : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %49 = tt.expand_dims %38 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xi32, #blocked1> + %50 = tt.splat %arg6 : (i32) -> tensor<128x1xi32, #blocked1> + %51 = arith.muli %49, %50 : tensor<128x1xi32, #blocked1> + %52 = tt.expand_dims %47 {axis = 0 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x32xi32, #blocked1> + %53 = tt.broadcast %51 : (tensor<128x1xi32, #blocked1>) -> tensor<128x32xi32, #blocked1> + %54 = tt.broadcast %52 : (tensor<1x32xi32, #blocked1>) -> tensor<128x32xi32, #blocked1> + %55 = arith.addi %53, %54 : tensor<128x32xi32, #blocked1> + %56 = tt.splat %arg0 : (!tt.ptr) -> tensor<128x32x!tt.ptr, #blocked1> + %57 = tt.addptr %56, %55 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %58 = tt.expand_dims %48 {axis = 1 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32x1xi32, #blocked> + %59 = tt.expand_dims %41 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x128xi32, #blocked> + %60 = tt.splat %arg7 : (i32) -> tensor<1x128xi32, #blocked> + %61 = arith.muli %59, %60 : tensor<1x128xi32, #blocked> + %62 = tt.broadcast %58 : (tensor<32x1xi32, #blocked>) -> tensor<32x128xi32, #blocked> + %63 = tt.broadcast %61 : (tensor<1x128xi32, #blocked>) -> tensor<32x128xi32, #blocked> + %64 = arith.addi %62, %63 : tensor<32x128xi32, #blocked> + %65 = tt.splat %arg1 : (!tt.ptr) -> tensor<32x128x!tt.ptr, #blocked> + %66 = tt.addptr %65, %64 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %67 = arith.addi %arg5, %c31_i32 : i32 + %68 = arith.divsi %67, %c32_i32 : i32 + %69 = arith.index_cast %68 : i32 to index + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #shared> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<32x128xf16, #shared1> + %70:3 = scf.for %arg9 = %c0 to %69 step %c1 iter_args(%arg10 = %cst, %arg11 = %57, %arg12 = %66) -> (tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>) { + %96:2 = scf.for %arg13 = %c0 to %69 step %c1 iter_args(%arg14 = %cst_2, %arg15 = %cst_3) -> (tensor<128x32xf16, #shared>, tensor<32x128xf16, #shared1>) { + %89 = tt.load %arg11 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #blocked1> + %90 = tt.load %arg12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked> + %91 = triton_gpu.convert_layout %89 : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #shared> + %92 = triton_gpu.convert_layout %90 : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1> + scf.yield %91, %92 : tensor<128x32xf16, #shared>, tensor<32x128xf16, #shared1> + } + %93 = tt.dot %96#0, %96#1, %arg10 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma> + %94 = tt.addptr %arg11, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %95 = tt.addptr %arg12, %cst_0 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + scf.yield %93, %94, %95 : tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked> + } + %71 = arith.truncf %70#0 : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma> + %72 = tt.expand_dims %27 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi32, #blocked2> + %73 = tt.splat %arg8 : (i32) -> tensor<128x1xi32, #blocked2> + %74 = arith.muli %72, %73 : tensor<128x1xi32, #blocked2> + %75 = tt.expand_dims %34 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi32, #blocked2> + %76 = tt.broadcast %74 : (tensor<128x1xi32, #blocked2>) -> tensor<128x128xi32, #blocked2> + %77 = tt.broadcast %75 : (tensor<1x128xi32, #blocked2>) -> tensor<128x128xi32, #blocked2> + %78 = arith.addi %76, %77 : tensor<128x128xi32, #blocked2> + %79 = tt.splat %arg2 : (!tt.ptr) -> tensor<128x128x!tt.ptr, #blocked2> + %80 = tt.addptr %79, %78 : tensor<128x128x!tt.ptr, #blocked2>, tensor<128x128xi32, #blocked2> + %81 = "triton_gpu.cmpi"(%28, %37) {predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %82 = tt.expand_dims %81 {axis = 1 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi1, #blocked2> + %83 = "triton_gpu.cmpi"(%35, %40) {predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %84 = tt.expand_dims %83 {axis = 0 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi1, #blocked2> + %85 = tt.broadcast %82 : (tensor<128x1xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> + %86 = tt.broadcast %84 : (tensor<1x128xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> + %87 = arith.andi %85, %86 : tensor<128x128xi1, #blocked2> + %88 = triton_gpu.convert_layout %71 : (tensor<128x128xf16, #mma>) -> tensor<128x128xf16, #blocked2> + tt.store %80, %88, %87 {cache = 1 : i32, evict = 1 : i32} : tensor<128x128xf16, #blocked2> + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 128, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK: "triton_gpu.enable-warp-specialization" = 1 : i32 + // CHECK-LABEL: @if_in_for_gemm + tt.func public @if_in_for_gemm(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %cst_0 = arith.constant dense<32> : tensor<32x128xi32, #blocked> + %cst_1 = arith.constant dense<32> : tensor<128x32xi32, #blocked1> + %c31_i32 = arith.constant 31 : i32 + %c127_i32 = arith.constant 127 : i32 + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c32_i32 = arith.constant 32 : i32 + %c128_i32 = arith.constant 128 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = tt.get_program_id x : i32 + %1 = tt.get_program_id y : i32 + %2 = arith.addi %arg3, %c127_i32 : i32 + %3 = arith.divsi %2, %c128_i32 : i32 + %4 = arith.addi %arg4, %c127_i32 : i32 + %5 = arith.divsi %4, %c128_i32 : i32 + %6 = arith.muli %5, %c8_i32 : i32 + %7 = arith.divsi %0, %6 : i32 + %8 = arith.muli %7, %c8_i32 : i32 + %9 = arith.subi %3, %8 : i32 + %10 = arith.cmpi slt, %9, %c8_i32 : i32 + %11 = arith.select %10, %9, %c8_i32 : i32 + %12 = arith.remsi %0, %11 : i32 + %13 = arith.addi %8, %12 : i32 + %14 = arith.remsi %0, %6 : i32 + %15 = arith.divsi %14, %11 : i32 + %16 = arith.muli %13, %c128_i32 : i32 + %17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %19 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %20 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %21 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %22 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %23 = tt.splat %16 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %24 = tt.splat %16 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %25 = tt.splat %16 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %26 = arith.addi %23, %17 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %27 = arith.addi %24, %19 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %28 = arith.addi %25, %21 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %29 = arith.muli %15, %c128_i32 : i32 + %30 = tt.splat %29 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %31 = tt.splat %29 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %32 = tt.splat %29 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %33 = arith.addi %30, %18 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %34 = arith.addi %31, %20 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %35 = arith.addi %32, %22 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %36 = tt.splat %arg3 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %37 = tt.splat %arg3 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %38 = arith.remsi %26, %36 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %39 = tt.splat %arg4 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %40 = tt.splat %arg4 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %41 = arith.remsi %33, %39 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %42 = arith.muli %1, %c32_i32 : i32 + %43 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %44 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %45 = tt.splat %42 : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %46 = tt.splat %42 : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %47 = arith.addi %45, %43 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %48 = arith.addi %46, %44 : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %49 = tt.expand_dims %38 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xi32, #blocked1> + %50 = tt.splat %arg6 : (i32) -> tensor<128x1xi32, #blocked1> + %51 = arith.muli %49, %50 : tensor<128x1xi32, #blocked1> + %52 = tt.expand_dims %47 {axis = 0 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x32xi32, #blocked1> + %53 = tt.broadcast %51 : (tensor<128x1xi32, #blocked1>) -> tensor<128x32xi32, #blocked1> + %54 = tt.broadcast %52 : (tensor<1x32xi32, #blocked1>) -> tensor<128x32xi32, #blocked1> + %55 = arith.addi %53, %54 : tensor<128x32xi32, #blocked1> + %56 = tt.splat %arg0 : (!tt.ptr) -> tensor<128x32x!tt.ptr, #blocked1> + %57 = tt.addptr %56, %55 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %58 = tt.expand_dims %48 {axis = 1 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32x1xi32, #blocked> + %59 = tt.expand_dims %41 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x128xi32, #blocked> + %60 = tt.splat %arg7 : (i32) -> tensor<1x128xi32, #blocked> + %61 = arith.muli %59, %60 : tensor<1x128xi32, #blocked> + %62 = tt.broadcast %58 : (tensor<32x1xi32, #blocked>) -> tensor<32x128xi32, #blocked> + %63 = tt.broadcast %61 : (tensor<1x128xi32, #blocked>) -> tensor<32x128xi32, #blocked> + %64 = arith.addi %62, %63 : tensor<32x128xi32, #blocked> + %65 = tt.splat %arg1 : (!tt.ptr) -> tensor<32x128x!tt.ptr, #blocked> + %66 = tt.addptr %65, %64 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %67 = arith.addi %arg5, %c31_i32 : i32 + %68 = arith.divsi %67, %c32_i32 : i32 + %69 = arith.index_cast %68 : i32 to index + %70:3 = scf.for %arg9 = %c0 to %69 step %c1 iter_args(%arg10 = %cst, %arg11 = %57, %arg12 = %66) -> (tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>) { + %arg9_i32 = arith.index_cast %arg9 : index to i32 + %96 = arith.cmpi ne, %arg9_i32, %c31_i32 : i32 + %89 = scf.if %96 -> (tensor<128x32xf16, #blocked1>) { + %r0_0 = arith.select %96, %c31_i32, %c127_i32 : i32 + %r0_1 = tt.splat %r0_0 : (i32) -> tensor<128x32xi32, #blocked1> + %new_addr = tt.addptr %arg11, %r0_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %new_89 = tt.load %new_addr {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #blocked1> + scf.yield %new_89 : tensor<128x32xf16, #blocked1> + } else { + %new_89 = tt.load %arg11 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #blocked1> + scf.yield %new_89 : tensor<128x32xf16, #blocked1> + } + %90 = tt.load %arg12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked> + %91 = triton_gpu.convert_layout %89 : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #shared> + %92 = triton_gpu.convert_layout %90 : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1> + %93 = tt.dot %91, %92, %arg10 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma> + %base_94 = scf.if %96 -> (tensor<128x32x!tt.ptr, #blocked1>) { + %r1_0 = arith.select %96, %c31_i32, %c127_i32 : i32 + %r1_1 = tt.splat %r1_0 : (i32) -> tensor<128x32xi32, #blocked1> + %98 = tt.addptr %arg11, %r1_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + scf.yield %98 : tensor<128x32x!tt.ptr, #blocked1> + } else { + %98 = tt.addptr %arg11, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + scf.yield %98 : tensor<128x32x!tt.ptr, #blocked1> + } + %94 = tt.addptr %base_94, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %95 = tt.addptr %arg12, %cst_0 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + scf.yield %93, %94, %95 : tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked> + } + %71 = arith.truncf %70#0 : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma> + %72 = tt.expand_dims %27 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi32, #blocked2> + %73 = tt.splat %arg8 : (i32) -> tensor<128x1xi32, #blocked2> + %74 = arith.muli %72, %73 : tensor<128x1xi32, #blocked2> + %75 = tt.expand_dims %34 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi32, #blocked2> + %76 = tt.broadcast %74 : (tensor<128x1xi32, #blocked2>) -> tensor<128x128xi32, #blocked2> + %77 = tt.broadcast %75 : (tensor<1x128xi32, #blocked2>) -> tensor<128x128xi32, #blocked2> + %78 = arith.addi %76, %77 : tensor<128x128xi32, #blocked2> + %79 = tt.splat %arg2 : (!tt.ptr) -> tensor<128x128x!tt.ptr, #blocked2> + %80 = tt.addptr %79, %78 : tensor<128x128x!tt.ptr, #blocked2>, tensor<128x128xi32, #blocked2> + %81 = "triton_gpu.cmpi"(%28, %37) {predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %82 = tt.expand_dims %81 {axis = 1 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi1, #blocked2> + %83 = "triton_gpu.cmpi"(%35, %40) {predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %84 = tt.expand_dims %83 {axis = 0 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi1, #blocked2> + %85 = tt.broadcast %82 : (tensor<128x1xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> + %86 = tt.broadcast %84 : (tensor<1x128xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> + %87 = arith.andi %85, %86 : tensor<128x128xi1, #blocked2> + %88 = triton_gpu.convert_layout %71 : (tensor<128x128xf16, #mma>) -> tensor<128x128xf16, #blocked2> + tt.store %80, %88, %87 {cache = 1 : i32, evict = 1 : i32} : tensor<128x128xf16, #blocked2> + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked3 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked4 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK: "triton_gpu.enable-warp-specialization" = 1 : i32 + // CHECK-LABEL: @tma_warp_specialized_matmul + tt.func public @tma_warp_specialized_matmul(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma> + %c63_i32 = arith.constant 63 : i32 + %c16_i32 = arith.constant 16 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i64 = arith.constant 1 : i64 + %c64_i32 = arith.constant 64 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg4, %c63_i32 : i32 + %2 = arith.divsi %1, %c64_i32 : i32 + %3 = arith.divsi %0, %2 : i32 + %4 = arith.remsi %0, %2 : i32 + %5 = arith.muli %3, %c64_i32 : i32 + %6 = arith.muli %4, %c64_i32 : i32 + %7 = arith.extsi %arg3 : i32 to i64 + %8 = arith.extsi %arg5 : i32 to i64 + %9 = arith.extsi %arg6 : i32 to i64 + %10 = tt.make_tensor_ptr %arg0, [%7, %8], [%9, %c1_i64], [%5, %c0_i32] {order = array} : , 1> + %11 = arith.extsi %arg4 : i32 to i64 + %12 = arith.extsi %arg7 : i32 to i64 + %13 = tt.make_tensor_ptr %arg1, [%8, %11], [%c1_i64, %12], [%c0_i32, %6] {order = array} : , 1> + %14:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c16_i32 iter_args(%arg10 = %cst, %arg11 = %10, %arg12 = %13) -> (tensor<64x64xf32, #mma>, !tt.ptr, 1>, !tt.ptr, 1>) : i32 { + %46 = tt.load %arg11 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<64x16xf16, #blocked2> + %47 = tt.load %arg12 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<16x64xf16, #blocked3> + %48 = triton_gpu.convert_layout %46 : (tensor<64x16xf16, #blocked2>) -> tensor<64x16xf16, #shared> + %49 = triton_gpu.convert_layout %47 : (tensor<16x64xf16, #blocked3>) -> tensor<16x64xf16, #shared1> + %50 = tt.dot %48, %49, %arg10 {allowTF32 = true} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> + %51 = tt.advance %arg11, [%c0_i32, %c16_i32] : , 1> + %52 = tt.advance %arg12, [%c16_i32, %c0_i32] : , 1> + scf.yield %50, %51, %52 : tensor<64x64xf32, #mma>, !tt.ptr, 1>, !tt.ptr, 1> + } + %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> + %16 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %17 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> + %18 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %19 = tt.splat %5 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %20 = tt.splat %5 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %21 = arith.addi %16, %19 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %22 = arith.addi %18, %20 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %23 = tt.splat %6 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> + %24 = tt.splat %6 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> + %25 = arith.addi %15, %23 : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> + %26 = arith.addi %17, %24 : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> + %27 = tt.expand_dims %21 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>>) -> tensor<64x1xi32, #blocked4> + %28 = tt.splat %arg8 : (i32) -> tensor<64x1xi32, #blocked4> + %29 = arith.muli %27, %28 : tensor<64x1xi32, #blocked4> + %30 = tt.splat %arg2 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked4> + %31 = tt.addptr %30, %29 : tensor<64x1x!tt.ptr, #blocked4>, tensor<64x1xi32, #blocked4> + %32 = tt.expand_dims %25 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>>) -> tensor<1x64xi32, #blocked4> + %33 = tt.broadcast %31 : (tensor<64x1x!tt.ptr, #blocked4>) -> tensor<64x64x!tt.ptr, #blocked4> + %34 = tt.broadcast %32 : (tensor<1x64xi32, #blocked4>) -> tensor<64x64xi32, #blocked4> + %35 = tt.addptr %33, %34 : tensor<64x64x!tt.ptr, #blocked4>, tensor<64x64xi32, #blocked4> + %36 = tt.splat %arg3 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %37 = "triton_gpu.cmpi"(%22, %36) {predicate = 2 : i64} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>>, tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>>) -> tensor<64xi1, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %38 = tt.expand_dims %37 {axis = 1 : i32} : (tensor<64xi1, #triton_gpu.slice<{dim = 1, parent = #blocked4}>>) -> tensor<64x1xi1, #blocked4> + %39 = tt.splat %arg4 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> + %40 = "triton_gpu.cmpi"(%26, %39) {predicate = 2 : i64} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>>, tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>>) -> tensor<64xi1, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> + %41 = tt.expand_dims %40 {axis = 0 : i32} : (tensor<64xi1, #triton_gpu.slice<{dim = 0, parent = #blocked4}>>) -> tensor<1x64xi1, #blocked4> + %42 = tt.broadcast %38 : (tensor<64x1xi1, #blocked4>) -> tensor<64x64xi1, #blocked4> + %43 = tt.broadcast %41 : (tensor<1x64xi1, #blocked4>) -> tensor<64x64xi1, #blocked4> + %44 = arith.andi %42, %43 : tensor<64x64xi1, #blocked4> + %45 = triton_gpu.convert_layout %14#0 : (tensor<64x64xf32, #mma>) -> tensor<64x64xf32, #blocked4> + tt.store %35, %45, %44 {cache = 1 : i32, evict = 1 : i32} : tensor<64x64xf32, #blocked4> + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked3 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked4 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK: "triton_gpu.enable-warp-specialization" = 1 : i32 + // CHECK-LABEL: @store_after_load + tt.func public @store_after_load(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma> + %c63_i32 = arith.constant 63 : i32 + %c16_i32 = arith.constant 16 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i64 = arith.constant 1 : i64 + %c64_i32 = arith.constant 64 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg4, %c63_i32 : i32 + %2 = arith.divsi %1, %c64_i32 : i32 + %3 = arith.divsi %0, %2 : i32 + %4 = arith.remsi %0, %2 : i32 + %5 = arith.muli %3, %c64_i32 : i32 + %6 = arith.muli %4, %c64_i32 : i32 + %7 = arith.extsi %arg3 : i32 to i64 + %8 = arith.extsi %arg5 : i32 to i64 + %9 = arith.extsi %arg6 : i32 to i64 + %10 = tt.make_tensor_ptr %arg0, [%7, %8], [%9, %c1_i64], [%5, %c0_i32] {order = array} : , 1> + %11 = arith.extsi %arg4 : i32 to i64 + %12 = arith.extsi %arg7 : i32 to i64 + %13 = tt.make_tensor_ptr %arg1, [%8, %11], [%c1_i64, %12], [%c0_i32, %6] {order = array} : , 1> + %14:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c16_i32 iter_args(%arg10 = %cst, %arg11 = %10, %arg12 = %13) -> (tensor<64x64xf32, #mma>, !tt.ptr, 1>, !tt.ptr, 1>) : i32 { + %46 = tt.load %arg11 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<64x16xf16, #blocked2> + %47 = tt.load %arg12 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<16x64xf16, #blocked3> + %48 = triton_gpu.convert_layout %46 : (tensor<64x16xf16, #blocked2>) -> tensor<64x16xf16, #shared> + %49 = triton_gpu.convert_layout %47 : (tensor<16x64xf16, #blocked3>) -> tensor<16x64xf16, #shared1> + %50 = tt.dot %48, %49, %arg10 {allowTF32 = true} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> + %51 = tt.advance %arg11, [%c0_i32, %c16_i32] : , 1> + %52 = tt.advance %arg12, [%c16_i32, %c0_i32] : , 1> + scf.yield %50, %51, %52 : tensor<64x64xf32, #mma>, !tt.ptr, 1>, !tt.ptr, 1> + } + %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> + %16 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %17 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> + %18 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %19 = tt.splat %5 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %20 = tt.splat %5 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %21 = arith.addi %16, %19 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %22 = arith.addi %18, %20 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %23 = tt.splat %6 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> + %24 = tt.splat %6 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> + %25 = arith.addi %15, %23 : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> + %26 = arith.addi %17, %24 : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> + %27 = tt.expand_dims %21 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>>) -> tensor<64x1xi32, #blocked4> + %28 = tt.splat %arg8 : (i32) -> tensor<64x1xi32, #blocked4> + %29 = arith.muli %27, %28 : tensor<64x1xi32, #blocked4> + %30 = tt.splat %arg2 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked4> + %31 = tt.addptr %30, %29 : tensor<64x1x!tt.ptr, #blocked4>, tensor<64x1xi32, #blocked4> + %32 = tt.expand_dims %25 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>>) -> tensor<1x64xi32, #blocked4> + %33 = tt.broadcast %31 : (tensor<64x1x!tt.ptr, #blocked4>) -> tensor<64x64x!tt.ptr, #blocked4> + %34 = tt.broadcast %32 : (tensor<1x64xi32, #blocked4>) -> tensor<64x64xi32, #blocked4> + %35 = tt.addptr %33, %34 : tensor<64x64x!tt.ptr, #blocked4>, tensor<64x64xi32, #blocked4> + %36 = tt.splat %arg3 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %37 = "triton_gpu.cmpi"(%22, %36) {predicate = 2 : i64} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>>, tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>>) -> tensor<64xi1, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %38 = tt.expand_dims %37 {axis = 1 : i32} : (tensor<64xi1, #triton_gpu.slice<{dim = 1, parent = #blocked4}>>) -> tensor<64x1xi1, #blocked4> + %39 = tt.splat %arg4 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> + %40 = "triton_gpu.cmpi"(%26, %39) {predicate = 2 : i64} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>>, tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>>) -> tensor<64xi1, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> + %41 = tt.expand_dims %40 {axis = 0 : i32} : (tensor<64xi1, #triton_gpu.slice<{dim = 0, parent = #blocked4}>>) -> tensor<1x64xi1, #blocked4> + %42 = tt.broadcast %38 : (tensor<64x1xi1, #blocked4>) -> tensor<64x64xi1, #blocked4> + %43 = tt.broadcast %41 : (tensor<1x64xi1, #blocked4>) -> tensor<64x64xi1, #blocked4> + %44 = arith.andi %42, %43 : tensor<64x64xi1, #blocked4> + %45 = triton_gpu.convert_layout %14#0 : (tensor<64x64xf32, #mma>) -> tensor<64x64xf32, #blocked4> + %46 = tt.load %35, %44 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, #blocked4> + %47 = arith.addf %45, %46 : tensor<64x64xf32, #blocked4> + tt.store %35, %47, %44 {cache = 1 : i32, evict = 1 : i32} : tensor<64x64xf32, #blocked4> + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked3 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked4 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK: "triton_gpu.enable-warp-specialization" = 0 : i32 + // CHECK-LABEL: @load_after_store + tt.func public @load_after_store(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma> + %c63_i32 = arith.constant 63 : i32 + %c16_i32 = arith.constant 16 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i64 = arith.constant 1 : i64 + %c64_i32 = arith.constant 64 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg4, %c63_i32 : i32 + %2 = arith.divsi %1, %c64_i32 : i32 + %3 = arith.divsi %0, %2 : i32 + %4 = arith.remsi %0, %2 : i32 + %5 = arith.muli %3, %c64_i32 : i32 + %6 = arith.muli %4, %c64_i32 : i32 + %7 = arith.extsi %arg3 : i32 to i64 + %8 = arith.extsi %arg5 : i32 to i64 + %9 = arith.extsi %arg6 : i32 to i64 + %10 = tt.make_tensor_ptr %arg0, [%7, %8], [%9, %c1_i64], [%5, %c0_i32] {order = array} : , 1> + %11 = arith.extsi %arg4 : i32 to i64 + %12 = arith.extsi %arg7 : i32 to i64 + %13 = tt.make_tensor_ptr %arg1, [%8, %11], [%c1_i64, %12], [%c0_i32, %6] {order = array} : , 1> + %14:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c16_i32 iter_args(%arg10 = %cst, %arg11 = %10, %arg12 = %13) -> (tensor<64x64xf32, #mma>, !tt.ptr, 1>, !tt.ptr, 1>) : i32 { + %46 = tt.load %arg11 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<64x16xf16, #blocked2> + %47 = tt.load %arg12 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<16x64xf16, #blocked3> + %48 = triton_gpu.convert_layout %46 : (tensor<64x16xf16, #blocked2>) -> tensor<64x16xf16, #shared> + %49 = triton_gpu.convert_layout %47 : (tensor<16x64xf16, #blocked3>) -> tensor<16x64xf16, #shared1> + %50 = tt.dot %48, %49, %arg10 {allowTF32 = true} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> + %51 = tt.advance %arg11, [%c0_i32, %c16_i32] : , 1> + %52 = tt.advance %arg12, [%c16_i32, %c0_i32] : , 1> + scf.yield %50, %51, %52 : tensor<64x64xf32, #mma>, !tt.ptr, 1>, !tt.ptr, 1> + } + %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> + %16 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %17 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> + %18 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %19 = tt.splat %5 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %20 = tt.splat %5 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %21 = arith.addi %16, %19 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %22 = arith.addi %18, %20 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %23 = tt.splat %6 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> + %24 = tt.splat %6 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> + %25 = arith.addi %15, %23 : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> + %26 = arith.addi %17, %24 : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> + %27 = tt.expand_dims %21 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>>) -> tensor<64x1xi32, #blocked4> + %28 = tt.splat %arg8 : (i32) -> tensor<64x1xi32, #blocked4> + %29 = arith.muli %27, %28 : tensor<64x1xi32, #blocked4> + %30 = tt.splat %arg2 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked4> + %31 = tt.addptr %30, %29 : tensor<64x1x!tt.ptr, #blocked4>, tensor<64x1xi32, #blocked4> + %32 = tt.expand_dims %25 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>>) -> tensor<1x64xi32, #blocked4> + %33 = tt.broadcast %31 : (tensor<64x1x!tt.ptr, #blocked4>) -> tensor<64x64x!tt.ptr, #blocked4> + %34 = tt.broadcast %32 : (tensor<1x64xi32, #blocked4>) -> tensor<64x64xi32, #blocked4> + %35 = tt.addptr %33, %34 : tensor<64x64x!tt.ptr, #blocked4>, tensor<64x64xi32, #blocked4> + %36 = tt.splat %arg3 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %37 = "triton_gpu.cmpi"(%22, %36) {predicate = 2 : i64} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>>, tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>>) -> tensor<64xi1, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %38 = tt.expand_dims %37 {axis = 1 : i32} : (tensor<64xi1, #triton_gpu.slice<{dim = 1, parent = #blocked4}>>) -> tensor<64x1xi1, #blocked4> + %39 = tt.splat %arg4 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> + %40 = "triton_gpu.cmpi"(%26, %39) {predicate = 2 : i64} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>>, tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>>) -> tensor<64xi1, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> + %41 = tt.expand_dims %40 {axis = 0 : i32} : (tensor<64xi1, #triton_gpu.slice<{dim = 0, parent = #blocked4}>>) -> tensor<1x64xi1, #blocked4> + %42 = tt.broadcast %38 : (tensor<64x1xi1, #blocked4>) -> tensor<64x64xi1, #blocked4> + %43 = tt.broadcast %41 : (tensor<1x64xi1, #blocked4>) -> tensor<64x64xi1, #blocked4> + %44 = arith.andi %42, %43 : tensor<64x64xi1, #blocked4> + %45 = triton_gpu.convert_layout %14#0 : (tensor<64x64xf32, #mma>) -> tensor<64x64xf32, #blocked4> + %46 = tt.load %35, %44 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, #blocked4> + %47 = arith.addf %45, %46 : tensor<64x64xf32, #blocked4> + tt.store %35, %47, %44 {cache = 1 : i32, evict = 1 : i32} : tensor<64x64xf32, #blocked4> + %48 = tt.load %35, %44 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, #blocked4> + %49 = arith.addf %45, %48 : tensor<64x64xf32, #blocked4> + tt.store %35, %49, %44 {cache = 1 : i32, evict = 1 : i32} : tensor<64x64xf32, #blocked4> + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked3 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked4 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK: "triton_gpu.enable-warp-specialization" = 0 : i32 + // CHECK-LABEL: @global_bar + tt.func public @global_bar(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg100: !tt.ptr {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma> + %c63_i32 = arith.constant 63 : i32 + %c16_i32 = arith.constant 16 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c1_i64 = arith.constant 1 : i64 + %c64_i32 = arith.constant 64 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg4, %c63_i32 : i32 + %2 = arith.divsi %1, %c64_i32 : i32 + %3 = arith.divsi %0, %2 : i32 + %4 = arith.remsi %0, %2 : i32 + %5 = arith.muli %3, %c64_i32 : i32 + %6 = arith.muli %4, %c64_i32 : i32 + %7 = arith.extsi %arg3 : i32 to i64 + %8 = arith.extsi %arg5 : i32 to i64 + %9 = arith.extsi %arg6 : i32 to i64 + %10 = tt.make_tensor_ptr %arg0, [%7, %8], [%9, %c1_i64], [%5, %c0_i32] {order = array} : , 1> + %11 = arith.extsi %arg4 : i32 to i64 + %12 = arith.extsi %arg7 : i32 to i64 + %13 = tt.make_tensor_ptr %arg1, [%8, %11], [%c1_i64, %12], [%c0_i32, %6] {order = array} : , 1> + %14:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c16_i32 iter_args(%arg10 = %cst, %arg11 = %10, %arg12 = %13) -> (tensor<64x64xf32, #mma>, !tt.ptr, 1>, !tt.ptr, 1>) : i32 { + %46 = tt.load %arg11 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<64x16xf16, #blocked2> + %47 = tt.load %arg12 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<16x64xf16, #blocked3> + %48 = triton_gpu.convert_layout %46 : (tensor<64x16xf16, #blocked2>) -> tensor<64x16xf16, #shared> + %49 = triton_gpu.convert_layout %47 : (tensor<16x64xf16, #blocked3>) -> tensor<16x64xf16, #shared1> + %50 = tt.dot %48, %49, %arg10 {allowTF32 = true} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> + %51 = tt.advance %arg11, [%c0_i32, %c16_i32] : , 1> + %52 = tt.advance %arg12, [%c16_i32, %c0_i32] : , 1> + scf.yield %50, %51, %52 : tensor<64x64xf32, #mma>, !tt.ptr, 1>, !tt.ptr, 1> + } + %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> + %16 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %17 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> + %18 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %19 = tt.splat %5 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %20 = tt.splat %5 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %21 = arith.addi %16, %19 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %22 = arith.addi %18, %20 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %23 = tt.splat %6 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> + %24 = tt.splat %6 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> + %25 = arith.addi %15, %23 : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> + %26 = arith.addi %17, %24 : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> + %27 = tt.expand_dims %21 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>>) -> tensor<64x1xi32, #blocked4> + %28 = tt.splat %arg8 : (i32) -> tensor<64x1xi32, #blocked4> + %29 = arith.muli %27, %28 : tensor<64x1xi32, #blocked4> + %30 = tt.splat %arg2 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked4> + %31 = tt.addptr %30, %29 : tensor<64x1x!tt.ptr, #blocked4>, tensor<64x1xi32, #blocked4> + %32 = tt.expand_dims %25 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>>) -> tensor<1x64xi32, #blocked4> + %33 = tt.broadcast %31 : (tensor<64x1x!tt.ptr, #blocked4>) -> tensor<64x64x!tt.ptr, #blocked4> + %34 = tt.broadcast %32 : (tensor<1x64xi32, #blocked4>) -> tensor<64x64xi32, #blocked4> + %35 = tt.addptr %33, %34 : tensor<64x64x!tt.ptr, #blocked4>, tensor<64x64xi32, #blocked4> + %36 = tt.splat %arg3 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %37 = "triton_gpu.cmpi"(%22, %36) {predicate = 2 : i64} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>>, tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>>) -> tensor<64xi1, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %38 = tt.expand_dims %37 {axis = 1 : i32} : (tensor<64xi1, #triton_gpu.slice<{dim = 1, parent = #blocked4}>>) -> tensor<64x1xi1, #blocked4> + %39 = tt.splat %arg4 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> + %40 = "triton_gpu.cmpi"(%26, %39) {predicate = 2 : i64} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>>, tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>>) -> tensor<64xi1, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> + %41 = tt.expand_dims %40 {axis = 0 : i32} : (tensor<64xi1, #triton_gpu.slice<{dim = 0, parent = #blocked4}>>) -> tensor<1x64xi1, #blocked4> + %42 = tt.broadcast %38 : (tensor<64x1xi1, #blocked4>) -> tensor<64x64xi1, #blocked4> + %43 = tt.broadcast %41 : (tensor<1x64xi1, #blocked4>) -> tensor<64x64xi1, #blocked4> + %44 = arith.andi %42, %43 : tensor<64x64xi1, #blocked4> + %45 = triton_gpu.convert_layout %14#0 : (tensor<64x64xf32, #mma>) -> tensor<64x64xf32, #blocked4> + "tt.atomic_cas"(%arg100, %c0_i32, %c1_i32) {sem = 1 : i32}: (!tt.ptr, i32, i32) -> i32 + %46 = tt.load %35, %44 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, #blocked4> + %47 = arith.addf %45, %46 : tensor<64x64xf32, #blocked4> + tt.store %35, %47, %44 {cache = 1 : i32, evict = 1 : i32} : tensor<64x64xf32, #blocked4> + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked4 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 128, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK: "triton_gpu.enable-warp-specialization" = 1 : i32 + // CHECK-LABEL: @store_in_nested_for + tt.func public @store_in_nested_for(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mma> + %c63_i32 = arith.constant 63 : i32 + %c127_i32 = arith.constant 127 : i32 + %c255_i32 = arith.constant 255 : i32 + %c114_i32 = arith.constant 114 : i32 + %c64_i32 = arith.constant 64 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i64 = arith.constant 1 : i64 + %c128_i32 = arith.constant 128 : i32 + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg3, %c255_i32 : i32 + %2 = arith.divsi %1, %c256_i32 : i32 + %3 = arith.addi %arg4, %c127_i32 : i32 + %4 = arith.divsi %3, %c128_i32 : i32 + %5 = arith.addi %arg5, %c63_i32 : i32 + %6 = arith.divsi %5, %c64_i32 : i32 + %7 = arith.muli %2, %4 : i32 + %8 = arith.divsi %0, %4 : i32 + %9 = arith.remsi %0, %4 : i32 + %10 = arith.muli %8, %c256_i32 : i32 + %11 = arith.muli %9, %c128_i32 : i32 + %12 = arith.extsi %arg3 : i32 to i64 + %13 = arith.extsi %arg5 : i32 to i64 + %14 = arith.extsi %arg6 : i32 to i64 + %15 = tt.make_tensor_ptr %arg0, [%12, %13], [%14, %c1_i64], [%10, %c0_i32] {order = array} : , 1> + %16 = arith.extsi %arg4 : i32 to i64 + %17 = arith.extsi %arg7 : i32 to i64 + %18 = tt.make_tensor_ptr %arg1, [%13, %16], [%c1_i64, %17], [%c0_i32, %11] {order = array} : , 1> + %19 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %20 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %21 = tt.splat %arg8 : (i32) -> tensor<256x1xi32, #blocked2> + %22 = tt.splat %arg2 : (!tt.ptr) -> tensor<256x1x!tt.ptr, #blocked2> + %23:4 = scf.for %arg9 = %0 to %7 step %c114_i32 iter_args(%arg10 = %15, %arg11 = %18, %arg12 = %8, %arg13 = %9) -> (!tt.ptr, 1>, !tt.ptr, 1>, i32, i32) : i32 { + %24 = arith.divsi %arg9, %4 : i32 + %25 = arith.remsi %arg9, %4 : i32 + %26 = "triton_gpu.cmpi"(%arg9, %c114_i32) {predicate = 5 : i64} : (i32, i32) -> i1 + %27:2 = scf.if %26 -> (!tt.ptr, 1>, !tt.ptr, 1>) { + %43 = arith.subi %24, %arg12 : i32 + %44 = arith.muli %43, %c256_i32 : i32 + %45 = arith.subi %c0_i32, %6 : i32 + %46 = arith.muli %45, %c64_i32 : i32 + %47 = tt.advance %arg10, [%44, %46] : , 1> + %48 = arith.subi %25, %arg13 : i32 + %49 = arith.muli %48, %c128_i32 : i32 + %50 = tt.advance %arg11, [%46, %49] : , 1> + scf.yield %47, %50 : !tt.ptr, 1>, !tt.ptr, 1> + } else { + scf.yield %arg10, %arg11 : !tt.ptr, 1>, !tt.ptr, 1> + } + %28:3 = scf.for %arg14 = %c0_i32 to %arg5 step %c64_i32 iter_args(%arg15 = %cst, %arg16 = %27#0, %arg17 = %27#1) -> (tensor<256x128xf32, #mma>, !tt.ptr, 1>, !tt.ptr, 1>) : i32 { + %43 = tt.load %arg16 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<256x64xf16, #blocked3> + %44 = tt.load %arg17 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<64x128xf16, #blocked4> + %45 = triton_gpu.convert_layout %43 : (tensor<256x64xf16, #blocked3>) -> tensor<256x64xf16, #shared> + %46 = triton_gpu.convert_layout %44 : (tensor<64x128xf16, #blocked4>) -> tensor<64x128xf16, #shared1> + %47 = tt.dot %45, %46, %arg15 {allowTF32 = true} : tensor<256x64xf16, #shared> * tensor<64x128xf16, #shared1> -> tensor<256x128xf32, #mma> + %48 = tt.advance %arg16, [%c0_i32, %c64_i32] : , 1> + %49 = tt.advance %arg17, [%c64_i32, %c0_i32] : , 1> + scf.yield %47, %48, %49 : tensor<256x128xf32, #mma>, !tt.ptr, 1>, !tt.ptr, 1> + } + %29 = arith.muli %24, %c256_i32 : i32 + %30 = tt.splat %29 : (i32) -> tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %31 = arith.addi %19, %30 : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %32 = arith.muli %25, %c128_i32 : i32 + %33 = tt.splat %32 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %34 = arith.addi %20, %33 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %35 = tt.expand_dims %31 {axis = 1 : i32} : (tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<256x1xi32, #blocked2> + %36 = arith.muli %35, %21 : tensor<256x1xi32, #blocked2> + %37 = tt.addptr %22, %36 : tensor<256x1x!tt.ptr, #blocked2>, tensor<256x1xi32, #blocked2> + %38 = tt.expand_dims %34 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi32, #blocked2> + %39 = tt.broadcast %37 : (tensor<256x1x!tt.ptr, #blocked2>) -> tensor<256x128x!tt.ptr, #blocked2> + %40 = tt.broadcast %38 : (tensor<1x128xi32, #blocked2>) -> tensor<256x128xi32, #blocked2> + %41 = tt.addptr %39, %40 : tensor<256x128x!tt.ptr, #blocked2>, tensor<256x128xi32, #blocked2> + %42 = triton_gpu.convert_layout %28#0 : (tensor<256x128xf32, #mma>) -> tensor<256x128xf32, #blocked2> + tt.store %41, %42 {cache = 1 : i32, evict = 1 : i32} : tensor<256x128xf32, #blocked2> + scf.yield %28#1, %28#2, %24, %25 : !tt.ptr, 1>, !tt.ptr, 1>, i32, i32 + } + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked4 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 128, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK: "triton_gpu.enable-warp-specialization" = 1 : i32 + // CHECK-LABEL: @matched_load_type + tt.func public @matched_load_type( + %arg0: !tt.ptr, 1>, + %arg1: !tt.ptr, 1>, + %arg2: tensor<256x128x!tt.ptr, #blocked2> + ) { + %cst = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mma> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + scf.for %iv = %c0 to %c10 step %c1 iter_args() -> () { + %a = tt.load %arg0 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<256x64xf16, #blocked3> + %b = tt.load %arg1 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<64x128xf16, #blocked4> + %shm_a = triton_gpu.convert_layout %a : (tensor<256x64xf16, #blocked3>) -> tensor<256x64xf16, #shared> + %shm_b = triton_gpu.convert_layout %b : (tensor<64x128xf16, #blocked4>) -> tensor<64x128xf16, #shared1> + %d = tt.dot %shm_a, %shm_b, %cst {allowTF32 = true} : tensor<256x64xf16, #shared> * tensor<64x128xf16, #shared1> -> tensor<256x128xf32, #mma> + %out = triton_gpu.convert_layout %d : (tensor<256x128xf32, #mma>) -> tensor<256x128xf32, #blocked2> + tt.store %arg2, %out {cache = 1 : i32, evict = 1 : i32} : tensor<256x128xf32, #blocked2> + } + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked4 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 128, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK: "triton_gpu.enable-warp-specialization" = 0 : i32 + // CHECK-LABEL: @mismatch_load_type + tt.func public @mismatch_load_type( + %arg0: !tt.ptr, 1>, + %arg1: tensor<64x128x!tt.ptr, #blocked4>, + %arg2: tensor<256x128x!tt.ptr, #blocked2> + ) { + %cst = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mma> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + scf.for %iv = %c0 to %c10 step %c1 iter_args() -> () { + %a = tt.load %arg0 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<256x64xf16, #blocked3> + %b = tt.load %arg1 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x128x!tt.ptr, #blocked4> -> tensor<64x128xf16, #blocked4> + %shm_a = triton_gpu.convert_layout %a : (tensor<256x64xf16, #blocked3>) -> tensor<256x64xf16, #shared> + %shm_b = triton_gpu.convert_layout %b : (tensor<64x128xf16, #blocked4>) -> tensor<64x128xf16, #shared1> + %d = tt.dot %shm_a, %shm_b, %cst {allowTF32 = true} : tensor<256x64xf16, #shared> * tensor<64x128xf16, #shared1> -> tensor<256x128xf32, #mma> + %out = triton_gpu.convert_layout %d : (tensor<256x128xf32, #mma>) -> tensor<256x128xf32, #blocked2> + tt.store %arg2, %out {cache = 1 : i32, evict = 1 : i32} : tensor<256x128xf32, #blocked2> + } + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked4 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK: "triton_gpu.enable-warp-specialization" = 1 : i32 + // CHECK-LABEL: @epilogue_with_reduce + tt.func public @epilogue_with_reduce(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg7: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg8: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg9: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg10: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma> + %c15_i32 = arith.constant 15 : i32 + %c63_i32 = arith.constant 63 : i32 + %c132_i32 = arith.constant 132 : i32 + %c16_i32 = arith.constant 16 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i64 = arith.constant 1 : i64 + %c64_i32 = arith.constant 64 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg6, %c63_i32 : i32 + %2 = arith.divsi %1, %c64_i32 : i32 + %3 = arith.addi %arg5, %c63_i32 : i32 + %4 = arith.divsi %3, %c64_i32 : i32 + %5 = arith.muli %4, %2 : i32 + %6 = arith.muli %2, %c8_i32 : i32 + %7 = arith.divsi %0, %6 : i32 + %8 = arith.muli %7, %c8_i32 : i32 + %9 = arith.subi %4, %8 : i32 + %10 = "triton_gpu.cmpi"(%9, %c8_i32) {predicate = 2 : i64} : (i32, i32) -> i1 + %11 = arith.select %10, %9, %c8_i32 : i32 + %12 = arith.remsi %0, %6 : i32 + %13 = arith.remsi %12, %11 : i32 + %14 = arith.addi %8, %13 : i32 + %15 = arith.divsi %12, %11 : i32 + %16 = arith.muli %14, %c64_i32 : i32 + %17 = arith.muli %15, %c64_i32 : i32 + %18 = arith.extsi %arg5 : i32 to i64 + %19 = arith.extsi %arg7 : i32 to i64 + %20 = arith.extsi %arg8 : i32 to i64 + %21 = tt.make_tensor_ptr %arg0, [%18, %19], [%20, %c1_i64], [%16, %c0_i32] {order = array} : , 1> + %22 = arith.extsi %arg6 : i32 to i64 + %23 = arith.extsi %arg9 : i32 to i64 + %24 = tt.make_tensor_ptr %arg1, [%19, %22], [%23, %c1_i64], [%c0_i32, %17] {order = array} : , 1> + %25 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %26 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %27 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %28 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %29 = tt.splat %arg10 : (i32) -> tensor<64x1xi32, #blocked2> + %30 = tt.splat %arg4 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked2> + %31 = tt.splat %arg5 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %32 = tt.splat %arg6 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %33 = arith.addi %arg7, %c15_i32 : i32 + %34 = arith.divsi %33, %c16_i32 : i32 + %35 = arith.subi %c0_i32, %34 : i32 + %36 = arith.muli %35, %c16_i32 : i32 + %37:4 = scf.for %arg11 = %0 to %5 step %c132_i32 iter_args(%arg12 = %21, %arg13 = %24, %arg14 = %14, %arg15 = %15) -> (!tt.ptr, 1>, !tt.ptr, 1>, i32, i32) : i32 { + %38 = arith.divsi %arg11, %6 : i32 + %39 = arith.muli %38, %c8_i32 : i32 + %40 = arith.subi %4, %39 : i32 + %41 = "triton_gpu.cmpi"(%40, %c8_i32) {predicate = 2 : i64} : (i32, i32) -> i1 + %42 = arith.select %41, %40, %c8_i32 : i32 + %43 = arith.remsi %arg11, %6 : i32 + %44 = arith.remsi %43, %42 : i32 + %45 = arith.addi %39, %44 : i32 + %46 = arith.divsi %43, %42 : i32 + %47 = arith.muli %45, %c64_i32 : i32 + %48 = arith.muli %46, %c64_i32 : i32 + %49 = tt.splat %47 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %50 = tt.splat %47 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %51 = arith.addi %49, %26 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %52 = arith.addi %50, %28 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %53 = tt.splat %48 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %54 = tt.splat %48 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %55 = arith.addi %53, %25 : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %56 = arith.addi %54, %27 : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %57 = tt.expand_dims %51 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<64x1xi32, #blocked2> + %58 = arith.muli %57, %29 : tensor<64x1xi32, #blocked2> + %59 = tt.addptr %30, %58 : tensor<64x1x!tt.ptr, #blocked2>, tensor<64x1xi32, #blocked2> + %60 = tt.expand_dims %55 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x64xi32, #blocked2> + %61 = tt.broadcast %59 : (tensor<64x1x!tt.ptr, #blocked2>) -> tensor<64x64x!tt.ptr, #blocked2> + %62 = tt.broadcast %60 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2> + %63 = tt.addptr %61, %62 : tensor<64x64x!tt.ptr, #blocked2>, tensor<64x64xi32, #blocked2> + %64 = "triton_gpu.cmpi"(%52, %31) {predicate = 2 : i64} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>, tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<64xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %65 = tt.expand_dims %64 {axis = 1 : i32} : (tensor<64xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<64x1xi1, #blocked2> + %66 = "triton_gpu.cmpi"(%56, %32) {predicate = 2 : i64} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>, tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<64xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %67 = tt.expand_dims %66 {axis = 0 : i32} : (tensor<64xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x64xi1, #blocked2> + %68 = tt.broadcast %65 : (tensor<64x1xi1, #blocked2>) -> tensor<64x64xi1, #blocked2> + %69 = tt.broadcast %67 : (tensor<1x64xi1, #blocked2>) -> tensor<64x64xi1, #blocked2> + %70 = arith.andi %68, %69 : tensor<64x64xi1, #blocked2> + %71 = arith.subi %45, %arg14 : i32 + %72 = arith.muli %71, %c64_i32 : i32 + %73 = tt.advance %arg12, [%72, %c0_i32] : , 1> + %74 = arith.subi %46, %arg15 : i32 + %75 = arith.muli %74, %c64_i32 : i32 + %76 = tt.advance %arg13, [%c0_i32, %75] : , 1> + %77:3 = scf.for %arg16 = %c0_i32 to %arg7 step %c16_i32 iter_args(%arg17 = %cst, %arg18 = %73, %arg19 = %76) -> (tensor<64x64xf32, #mma>, !tt.ptr, 1>, !tt.ptr, 1>) : i32 { + %91 = tt.load %arg18 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<64x16xf16, #blocked3> + %92 = tt.load %arg19 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<16x64xf16, #blocked4> + %93 = triton_gpu.convert_layout %91 : (tensor<64x16xf16, #blocked3>) -> tensor<64x16xf16, #shared> + %94 = triton_gpu.convert_layout %92 : (tensor<16x64xf16, #blocked4>) -> tensor<16x64xf16, #shared1> + %95 = tt.dot %93, %94, %arg17 {allowTF32 = true} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> + %96 = tt.advance %arg18, [%c0_i32, %c16_i32] : , 1> + %97 = tt.advance %arg19, [%c16_i32, %c0_i32] : , 1> + scf.yield %95, %96, %97 : tensor<64x64xf32, #mma>, !tt.ptr, 1>, !tt.ptr, 1> + } + %78 = triton_gpu.convert_layout %77#0 : (tensor<64x64xf32, #mma>) -> tensor<64x64xf32, #blocked2> + %79 = triton_gpu.convert_layout %77#0 : (tensor<64x64xf32, #mma>) -> tensor<64x64xf32, #blocked2> + %80 = tt.advance %77#1, [%c0_i32, %36] : , 1> + %81 = tt.advance %77#2, [%36, %c0_i32] : , 1> + %82 = "tt.reduce"(%78) ({ + ^bb0(%arg16: f32, %arg17: f32): + %91 = "triton_gpu.cmpf"(%arg16, %arg17) {predicate = 2 : i64} : (f32, f32) -> i1 + %92 = arith.select %91, %arg16, %arg17 : f32 + tt.reduce.return %92 : f32 + }) {axis = 1 : i32} : (tensor<64x64xf32, #blocked2>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %83 = tt.expand_dims %82 {axis = 1 : i32} : (tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<64x1xf32, #blocked2> + %84 = tt.broadcast %83 : (tensor<64x1xf32, #blocked2>) -> tensor<64x64xf32, #blocked2> + %85 = arith.subf %79, %84 : tensor<64x64xf32, #blocked2> + %86 = math.exp %85 : tensor<64x64xf32, #blocked2> + %87 = "tt.reduce"(%86) ({ + ^bb0(%arg16: f32, %arg17: f32): + %91 = arith.addf %arg16, %arg17 : f32 + tt.reduce.return %91 : f32 + }) {axis = 1 : i32} : (tensor<64x64xf32, #blocked2>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %88 = tt.expand_dims %87 {axis = 1 : i32} : (tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<64x1xf32, #blocked2> + %89 = tt.broadcast %88 : (tensor<64x1xf32, #blocked2>) -> tensor<64x64xf32, #blocked2> + %90 = arith.divf %86, %89 : tensor<64x64xf32, #blocked2> + tt.store %63, %90, %70 {cache = 1 : i32, evict = 1 : i32} : tensor<64x64xf32, #blocked2> + scf.yield %80, %81, %45, %46 : !tt.ptr, 1>, !tt.ptr, 1>, i32, i32 + } + tt.return + } +} diff --git a/unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt b/unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt index 691cd8bfc679..c73e912ce265 100644 --- a/unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt +++ b/unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -3,8 +3,15 @@ add_triton_ut( SRCS PTXAsmFormatTest.cpp LIBS TritonGPUToLLVM ) + add_triton_ut( NAME TestGcnAsmFormat SRCS GcnAsmFormatTest.cpp LIBS TritonGPUToLLVM ) + +add_triton_ut( + NAME TestEmitIndices + SRCS EmitIndicesTest.cpp DumpLayout.cpp + LIBS TritonGPUIR TritonNvidiaGPUIR ${dialect_libs} ${conversion_libs} +) diff --git a/unittest/Conversion/TritonGPUToLLVM/DumpLayout.cpp b/unittest/Conversion/TritonGPUToLLVM/DumpLayout.cpp new file mode 100644 index 000000000000..a0669baa9f85 --- /dev/null +++ b/unittest/Conversion/TritonGPUToLLVM/DumpLayout.cpp @@ -0,0 +1,384 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "DumpLayout.h" + +#include "../../../lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h" + +namespace mlir { +namespace triton { +namespace gpu { + +namespace { + +//===----------------------------------------------------------------------===// +// IndexEmitter +//===----------------------------------------------------------------------===// + +class IndexEmitter { +public: + struct Cache { + llvm::DenseMap, + CacheKeyDenseMapInfo> + baseIndexCache; + llvm::DenseMap>, + CacheKeyDenseMapInfo> + indexCache; + OpBuilder::InsertPoint indexInsertPoint; + }; + + IndexEmitter(MLIRContext *context_) + : context(context_), option(context), typeConverter(context, option), + cacheInfo({&cache.baseIndexCache, &cache.indexCache, + &cache.indexInsertPoint}), + base(typeConverter, cacheInfo), rewriter(context), + loc(UnknownLoc::get(context)) { + rewriter.setInsertionPointToStart(&block); + cache.indexInsertPoint = rewriter.saveInsertionPoint(); + } + + llvm::SmallVector> + emitIndices(Attribute layout, llvm::ArrayRef shape, + bool withCTAOffset) { + auto type = RankedTensorType::get(shape, rewriter.getF16Type(), layout); + return base.emitIndices(loc, rewriter, layout, type, withCTAOffset); + } + + llvm::DenseMap + emitDistributedToShared(Attribute srcLayout, SharedEncodingAttr sharedLayout, + Type elemTy, llvm::ArrayRef shape, + bool withCTAOffset) { + auto srcTy = RankedTensorType::get(shape, elemTy, srcLayout); + SharedMemoryObject smemObj(getMockSmemBase(), shape, + sharedLayout.getOrder(), loc, rewriter); + return base.getSwizzledSharedPtrs(loc, /*inVec=*/1, srcTy, sharedLayout, + elemTy, smemObj, rewriter, + smemObj.offsets, smemObj.strides); + } + +private: + Value getMockSmemBase() { + Value mockSmemBase = + mlir::LLVM::getSRegValue(rewriter, loc, "%mock_smem_base"); + auto llPtrTy = LLVM::LLVMPointerType::get( + typeConverter.convertType(rewriter.getI8Type()), 3); + auto cast = rewriter.create( + loc, TypeRange{llPtrTy}, ValueRange{mockSmemBase}); + return cast.getResult(0); + } + + // Non-static members are initialized in declaration order + MLIRContext *context; + LowerToLLVMOptions option; + TritonGPUToLLVMTypeConverter typeConverter; + Cache cache; + ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo cacheInfo; + ConvertTritonGPUOpToLLVMPatternBase base; + Block block; + ConversionPatternRewriter rewriter; + Location loc; +}; + +//===----------------------------------------------------------------------===// +// MLIR expression evaluation +//===----------------------------------------------------------------------===// + +int eval(Value value, int ctaid, int tid); + +int evalThreadIdOp(mlir::gpu::ThreadIdOp threadIdOp, int ctaid, int tid) { + auto dim = threadIdOp.getDimension(); + if (dim == mlir::gpu::Dimension::x) + return tid; + else if (dim == mlir::gpu::Dimension::y) + return 0; + else if (dim == mlir::gpu::Dimension::z) + return 0; + else + assert(0 && "Invalid thread dim"); + return 0; +} + +int evalInlineAsmOp(mlir::LLVM::InlineAsmOp asmOp, int ctaid, int tid) { + std::string asmStr = asmOp.getAsmString().str(); + if (asmStr.find("%cluster_ctaid.x") != std::string::npos) + return ctaid; + else if (asmStr.find("%cluster_ctaid.y") != std::string::npos) + return 0; + else if (asmStr.find("%cluster_ctaid.z") != std::string::npos) + return 0; + else if (asmStr.find("%cluster_nctaid.x") != std::string::npos) + assert(0 && "%cluster_nctaid.x not supported"); + else if (asmStr.find("%cluster_nctaid.y") != std::string::npos) + return 1; + else if (asmStr.find("%cluster_nctaid.z") != std::string::npos) + return 1; + else if (asmStr.find("%mock_smem_base") != std::string::npos) + return 0; + else + assert(0 && "Unrecognized ASM string"); + return 0; +} + +int evalGEPOp(mlir::LLVM::GEPOp gepOp, int ctaid, int tid) { + assert(gepOp.getNumOperands() == 2 && "Unrecognized format of GEPOp"); + int base = eval(gepOp.getBase(), ctaid, tid); + int offset = eval(gepOp.getOperand(1), ctaid, tid); + auto llPtrTy = gepOp.getRes().getType().cast(); + int bytesPerElem = llPtrTy.getElementType().getIntOrFloatBitWidth() / 8; + return base + offset * bytesPerElem; +} + +int eval(Value value, int ctaid, int tid) { + Operation *op = value.getDefiningOp(); + assert(op && "Unrecognized source value in the index expression"); + if (auto constantOp = llvm::dyn_cast(op)) { + auto attr = constantOp.getValue(); + return attr.cast().getInt(); + } else if (auto addOp = llvm::dyn_cast(op)) { + return eval(addOp.getLhs(), ctaid, tid) + eval(addOp.getRhs(), ctaid, tid); + } else if (auto mulOp = llvm::dyn_cast(op)) { + return eval(mulOp.getLhs(), ctaid, tid) * eval(mulOp.getRhs(), ctaid, tid); + } else if (auto udivOp = llvm::dyn_cast(op)) { + return eval(udivOp.getLhs(), ctaid, tid) / + eval(udivOp.getRhs(), ctaid, tid); + } else if (auto uremOp = llvm::dyn_cast(op)) { + return eval(uremOp.getLhs(), ctaid, tid) % + eval(uremOp.getRhs(), ctaid, tid); + } else if (auto xorOp = llvm::dyn_cast(op)) { + return eval(xorOp.getLhs(), ctaid, tid) ^ eval(xorOp.getRhs(), ctaid, tid); + } else if (auto trunciOp = llvm::dyn_cast(op)) { + return eval(trunciOp.getIn(), ctaid, tid); + } else if (auto castOp = llvm::dyn_cast(op)) { + return eval(castOp.getOperand(0), ctaid, tid); + } else if (auto threadOp = llvm::dyn_cast(op)) { + return evalThreadIdOp(threadOp, ctaid, tid); + } else if (auto asmOp = llvm::dyn_cast(op)) { + return evalInlineAsmOp(asmOp, ctaid, tid); + } else if (auto gepOp = llvm::dyn_cast(op)) { + return evalGEPOp(gepOp, ctaid, tid); + } else { + assert(0 && "Unrecognized op type in the index expression"); + return 0; + } +} + +} // namespace + +//===----------------------------------------------------------------------===// +// Dump Distributed Layout +//===----------------------------------------------------------------------===// + +std::string dumpDistributedLayout(Attribute layout, + llvm::ArrayRef shape, + bool multiCTA) { + assert(isaDistributedLayout(layout) && + "Unsupported layout type for dumpDistributedLayout"); + + assert(shape.size() > 0 && "Empty shape"); + assert(shape.size() <= 2 && + "High order tensor is not supported in dumpLayout"); + + int numThreads = 32 * getNumWarpsPerCTA(layout); + int numCTAs = getNumCTAs(layout); + auto f16Ty = FloatType::getF16(layout.getContext()); + int numElems = getTotalElemsPerThread(layout, shape, f16Ty); + + if (!multiCTA) + assert(numCTAs == 1 && "numCTAs must be 1 when multiCTA is false"); + + IndexEmitter emitter(layout.getContext()); + auto indices = emitter.emitIndices(layout, shape, multiCTA); + assert(indices.size() == numElems && "Incorrect number of indices emitted"); + + auto genStr = [multiCTA](int ctaid, int tid, int idx) -> std::string { + std::ostringstream oss; + if (multiCTA) + oss << "CTA" << ctaid << ":"; + oss << "T" << tid << ":" << idx; + return oss.str(); + }; + + std::ostringstream oss; + + auto dumpLayout1d = [&]() { + for (int idx = 0; idx < numElems; ++idx) + assert(indices[idx].size() == 1 && "Incorrect rank of indices emitted"); + + int size = shape[0]; + std::vector mapping(size); + + for (int ctaid = 0; ctaid < numCTAs; ++ctaid) { + for (int tid = 0; tid < numThreads; ++tid) { + for (int idx = 0; idx < numElems; ++idx) { + int i = eval(indices[idx][0], ctaid, tid); + assert(i >= 0 && i < size && "Invalid index emitted"); + std::string &value = mapping[i]; + if (value.empty()) + value = genStr(ctaid, tid, idx); + else + value = value + "|" + genStr(ctaid, tid, idx); + } + } + } + + for (int i = 0; i < size; ++i) { + if (i > 0) + oss << ","; + oss << mapping[i]; + } + oss << "\n"; + }; + + auto dumpLayout2d = [&]() { + for (int idx = 0; idx < numElems; ++idx) + assert(indices[idx].size() == 2 && "Incorrect rank of indices emitted"); + + int row = shape[0], col = shape[1]; + std::vector> mapping( + row, std::vector(col)); + + for (int ctaid = 0; ctaid < numCTAs; ++ctaid) { + for (int tid = 0; tid < numThreads; ++tid) { + for (int idx = 0; idx < numElems; ++idx) { + int r = eval(indices[idx][0], ctaid, tid); + int c = eval(indices[idx][1], ctaid, tid); + assert(r >= 0 && r < row && c >= 0 && c < col && + "Invalid index emitted"); + std::string &value = mapping[r][c]; + if (value.empty()) + value = genStr(ctaid, tid, idx); + else + value = value + "|" + genStr(ctaid, tid, idx); + } + } + } + + for (int r = 0; r < row; ++r) { + for (int c = 0; c < col; ++c) { + if (c > 0) + oss << ","; + oss << mapping[r][c]; + } + oss << "\n"; + } + }; + + if (shape.size() == 1) + dumpLayout1d(); + else + dumpLayout2d(); + + return oss.str(); +} + +//===----------------------------------------------------------------------===// +// Dump Shared Layout +//===----------------------------------------------------------------------===// + +std::string dumpSharedLayout(Attribute layout, llvm::ArrayRef shape, + Type elemTy, bool multiCTA) { + assert(shape.size() == 2 && "Only 2d shape supported in dumpSharedLayout"); + int row = shape[0], col = shape[1]; + int size = row * col; + int bytesPerElem = elemTy.getIntOrFloatBitWidth() / 8; + int totalBytes = size * bytesPerElem; + + int numWarps = 1; + int numThreads = 32 * numWarps; + int numCTAs = getNumCTAs(layout); + + if (!multiCTA) + assert(numCTAs == 1 && "numCTAs must be 1 when multiCTA is false"); + + auto sharedLayout = layout.cast(); + auto blockedLayout = BlockedEncodingAttr::get( + /*context=*/layout.getContext(), /*shape=*/shape, + /*sizePerThread=*/{1, 1}, /*order=*/sharedLayout.getOrder(), + /*numWarps=*/numWarps, 32, /*CTALayout=*/sharedLayout.getCTALayout()); + + int numElems = getTotalElemsPerThread(blockedLayout, shape, elemTy); + + IndexEmitter emitter(layout.getContext()); + auto blockedIndices = emitter.emitIndices(blockedLayout, shape, multiCTA); + auto sharedPtrs = emitter.emitDistributedToShared(blockedLayout, sharedLayout, + elemTy, shape, multiCTA); + + assert(blockedIndices.size() == numElems && + "Incorrect number of indices emitted by blockedLayout"); + assert(sharedPtrs.size() == numElems && + "Incorrect number of pointers emitted by sharedLayout"); + + for (int idx = 0; idx < numElems; ++idx) + assert(blockedIndices[idx].size() == 2 && + "Incorrect rank of indices emitted by blockedLayout"); + + auto genStr = [](int r, int c) -> std::string { + std::ostringstream oss; + oss << "(" << r << ":" << c << ")"; + return oss.str(); + }; + + std::vector mapping(size); + for (int ctaid = 0; ctaid < numCTAs; ++ctaid) { + for (int tid = 0; tid < numThreads; ++tid) { + for (int idx = 0; idx < numElems; ++idx) { + int r = eval(blockedIndices[idx][0], ctaid, tid); + int c = eval(blockedIndices[idx][1], ctaid, tid); + assert(r >= 0 && r < row && c >= 0 && c < col && + "Invalid index emitted"); + int ptr = eval(sharedPtrs[idx], ctaid, tid); + assert(ptr % bytesPerElem == 0 && ptr < totalBytes && + "Invalid pointer emitted"); + std::string &value = mapping[ptr / bytesPerElem]; + if (value.empty()) + value = genStr(r, c); + else + value = value + "|" + genStr(r, c); + } + } + } + + const int bytesPerBank = 4; + const int totalBanks = 32; + const int bytesPerLine = + std::min(col * bytesPerElem, bytesPerBank * totalBanks); + int elemsPerLine = bytesPerLine / bytesPerElem; + + std::ostringstream oss; + + for (int i = 0; i < size; ++i) { + int r = i / elemsPerLine; + int c = i % elemsPerLine; + if (c > 0) + oss << ","; + oss << mapping[i]; + if (c == elemsPerLine - 1) + oss << "\n"; + } + + return oss.str(); +} + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/unittest/Conversion/TritonGPUToLLVM/DumpLayout.h b/unittest/Conversion/TritonGPUToLLVM/DumpLayout.h new file mode 100644 index 000000000000..6e600f6ac779 --- /dev/null +++ b/unittest/Conversion/TritonGPUToLLVM/DumpLayout.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#ifndef TRITON_UNITTEST_CONVERSION_TRITONGPU_TO_LLVM_DUMP_LAYOUT_H +#define TRITON_UNITTEST_CONVERSION_TRITONGPU_TO_LLVM_DUMP_LAYOUT_H + +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +namespace gpu { + +std::string dumpDistributedLayout(Attribute layout, + llvm::ArrayRef shape, bool multiCTA); + +std::string dumpSharedLayout(Attribute layout, llvm::ArrayRef shape, + Type elemTy, bool multiCTA); + +} // namespace gpu +} // namespace triton +} // namespace mlir + +#endif diff --git a/unittest/Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp b/unittest/Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp new file mode 100644 index 000000000000..90e6ef8c3d1d --- /dev/null +++ b/unittest/Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp @@ -0,0 +1,677 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "mlir/Dialect/GPU/IR/GPUDialect.h" + +#include "DumpLayout.h" + +#include +#include + +namespace mlir { +namespace triton { +namespace gpu { + +//===----------------------------------------------------------------------===// +// EmitIndicesTest +//===----------------------------------------------------------------------===// + +class EmitIndicesTest : public ::testing::Test { +public: + void SetUp() { + context.getOrLoadDialect(); + context.getOrLoadDialect(); + context.getOrLoadDialect(); + } + +protected: + void runBlocked1dSingleCTA(int size, unsigned sizePerThread, + unsigned warpsPerCTA, const std::string &refStr) { + // If we pass initializer lists to the constructor of BlockedEncodingAttr, + // there might be multiple constructors matching the same parameter list. + // For example, the initializer list "order = {0}" can also match the + // parameter "unsigned numWarps", which is not what we want + llvm::SmallVector sizePerThread_ = {sizePerThread}; + llvm::SmallVector threadsPerWarp = {32}; + llvm::SmallVector warpsPerCTA_ = {warpsPerCTA}; + llvm::SmallVector order = {0}; + auto layout = + BlockedEncodingAttr::get(&context, sizePerThread_, threadsPerWarp, + warpsPerCTA_, order, getSingleCTALayout1d()); + runDistributed1d(size, layout, /*multiCTA=*/false, refStr); + } + + void runBlocked2dSingleCTA(int row, int col, + llvm::ArrayRef sizePerThread, + llvm::ArrayRef threadsPerWarp, + llvm::ArrayRef warpsPerCTA, + llvm::ArrayRef order, + const std::string &refStr) { + auto layout = + BlockedEncodingAttr::get(&context, sizePerThread, threadsPerWarp, + warpsPerCTA, order, getSingleCTALayout2d()); + runDistributed2d(row, col, layout, /*multiCTA=*/false, refStr); + } + + void runBlocked2dMultiCTA( + int row, int col, llvm::ArrayRef sizePerThread, + llvm::ArrayRef threadsPerWarp, + llvm::ArrayRef warpsPerCTA, llvm::ArrayRef order, + llvm::ArrayRef CTAsPerCGA, llvm::ArrayRef CTASplitNum, + llvm::ArrayRef CTAOrder, const std::string &refStr) { + auto CTALayout = + CTALayoutAttr::get(&context, CTAsPerCGA, CTASplitNum, CTAOrder); + auto layout = BlockedEncodingAttr::get( + &context, sizePerThread, threadsPerWarp, warpsPerCTA, order, CTALayout); + runDistributed2d(row, col, layout, /*multiCTA=*/true, refStr); + } + + void runSliceBlockedSingleCTA(int size, + llvm::ArrayRef sizePerThread, + llvm::ArrayRef threadsPerWarp, + llvm::ArrayRef warpsPerCTA, + llvm::ArrayRef order, + unsigned sliceDim, const std::string &refStr) { + auto parent = + BlockedEncodingAttr::get(&context, sizePerThread, threadsPerWarp, + warpsPerCTA, order, getSingleCTALayout2d()); + auto layout = SliceEncodingAttr::get(&context, sliceDim, parent); + runDistributed1d(size, layout, /*multiCTA=*/false, refStr); + } + + void runSliceBlockedMultiCTA(int size, llvm::ArrayRef sizePerThread, + llvm::ArrayRef threadsPerWarp, + llvm::ArrayRef warpsPerCTA, + llvm::ArrayRef order, + llvm::ArrayRef CTAsPerCGA, + llvm::ArrayRef CTASplitNum, + llvm::ArrayRef CTAOrder, + unsigned sliceDim, const std::string &refStr) { + auto CTALayout = + CTALayoutAttr::get(&context, CTAsPerCGA, CTASplitNum, CTAOrder); + auto parent = BlockedEncodingAttr::get( + &context, sizePerThread, threadsPerWarp, warpsPerCTA, order, CTALayout); + auto layout = SliceEncodingAttr::get(&context, sliceDim, parent); + runDistributed1d(size, layout, /*multiCTA=*/true, refStr); + } + + void runMmaSingleCTA(int row, int col, unsigned versionMajor, + unsigned versionMinor, + llvm::ArrayRef warpsPerCTA, + llvm::ArrayRef instrShape, + const std::string &refStr) { + auto layout = + MmaEncodingAttr::get(&context, versionMajor, versionMinor, warpsPerCTA, + getSingleCTALayout2d(), instrShape); + runDistributed2d(row, col, layout, /*multiCTA=*/false, refStr); + } + + void runDotOpSingleCTA(int row, int col, unsigned versionMajor, + unsigned versionMinor, + llvm::ArrayRef warpsPerCTA, + llvm::ArrayRef instrShape, unsigned opIdx, + const std::string &refStr) { + auto parent = + MmaEncodingAttr::get(&context, versionMajor, versionMinor, warpsPerCTA, + getSingleCTALayout2d(), instrShape); + auto layout = DotOperandEncodingAttr::get(&context, opIdx, parent, 0); + runDistributed2d(row, col, layout, /*multiCTA=*/false, refStr); + } + + void runSharedSingleCTA(int row, int col, bool rowMajor, + const std::string &elemTyStr, + const std::string &refStr) { + auto elemTy = getElemTy(elemTyStr); + auto layout = + SharedEncodingAttr::get(&context, {row, col}, getMatrixOrder(rowMajor), + getSingleCTALayout2d(), elemTy); + llvm::outs() << layout << "\n"; + runShared(row, col, layout, elemTy, /*multiCTA=*/false, refStr); + } + +private: + std::string skipSpaces(const std::string &input) { + std::string output; + for (char c : input) + if (c != ' ') + output += c; + return output; + } + + void assertSameStr(const std::string &refStr, const std::string &output) { + if (refStr != output) { + llvm::outs() << "RefStr =\n" + << refStr << "\n" + << "\n" + << "Output =\n" + << output << "\n"; + FAIL() << "Incorrect output string"; + } + } + + void runDistributed1d(int size, Attribute layout, bool multiCTA, + const std::string &refStr) { + assertSameStr(skipSpaces(refStr), + dumpDistributedLayout(layout, {size}, multiCTA)); + } + + void runDistributed2d(int row, int col, Attribute layout, bool multiCTA, + const std::string &refStr) { + assertSameStr(skipSpaces(refStr), + dumpDistributedLayout(layout, {row, col}, multiCTA)); + } + + void runShared(int row, int col, const SharedEncodingAttr &layout, + Type elemTy, bool multiCTA, const std::string &refStr) { + assertSameStr(skipSpaces(refStr), + dumpSharedLayout(layout, {row, col}, elemTy, multiCTA)); + } + + CTALayoutAttr getSingleCTALayout1d() { + return CTALayoutAttr::get(/*context=*/&context, /*CTAsPerCGA=*/{1}, + /*CTASplitNum=*/{1}, /*CTAOrder=*/{0}); + } + + CTALayoutAttr getSingleCTALayout2d() { + return CTALayoutAttr::get(/*context=*/&context, /*CTAsPerCGA=*/{1, 1}, + /*CTASplitNum=*/{1, 1}, /*CTAOrder=*/{1, 0}); + } + + llvm::SmallVector getMatrixOrder(bool rowMajor) { + if (rowMajor) + return {1, 0}; + else + return {0, 1}; + } + + Type getElemTy(const std::string &elemTyStr) { + if (elemTyStr == "F16") + return FloatType::getF16(&context); + else + assert(0 && "getElemTy not implemented"); + } + +protected: + MLIRContext context; +}; + +//===----------------------------------------------------------------------===// +// Tests for BlockedEncodingAttr +//===----------------------------------------------------------------------===// + +TEST_F(EmitIndicesTest, BlockedLayout_SingleCTA_1D) { + // clang-format off + std::string refStr = + "T0:0,T1:0,T2:0,T3:0,T4:0,T5:0,T6:0,T7:0,T8:0,T9:0,T10:0,T11:0,T12:0,T13:0,T14:0,T15:0,T16:0,T17:0,T18:0,T19:0,T20:0,T21:0,T22:0,T23:0,T24:0,T25:0,T26:0,T27:0,T28:0,T29:0,T30:0,T31:0\n"; + // clang-format on + + runBlocked1dSingleCTA(/*size=*/32, /*sizePerThread*/ 1, /*warpsPerCTA*/ 1, + /*refStr=*/refStr); +} + +TEST_F(EmitIndicesTest, BlockedLayout_SingleCTA_Order_1_0) { + // clang-format off + std::string refStr = + " T0:0, T1:0, T2:0, T3:0, T32:0, T33:0, T34:0, T35:0\n" + " T4:0, T5:0, T6:0, T7:0, T36:0, T37:0, T38:0, T39:0\n" + " T8:0, T9:0,T10:0,T11:0, T40:0, T41:0, T42:0, T43:0\n" + "T12:0,T13:0,T14:0,T15:0, T44:0, T45:0, T46:0, T47:0\n" + "T16:0,T17:0,T18:0,T19:0, T48:0, T49:0, T50:0, T51:0\n" + "T20:0,T21:0,T22:0,T23:0, T52:0, T53:0, T54:0, T55:0\n" + "T24:0,T25:0,T26:0,T27:0, T56:0, T57:0, T58:0, T59:0\n" + "T28:0,T29:0,T30:0,T31:0, T60:0, T61:0, T62:0, T63:0\n" + + "T64:0,T65:0,T66:0,T67:0, T96:0, T97:0, T98:0, T99:0\n" + "T68:0,T69:0,T70:0,T71:0, T100:0,T101:0,T102:0,T103:0\n" + "T72:0,T73:0,T74:0,T75:0, T104:0,T105:0,T106:0,T107:0\n" + "T76:0,T77:0,T78:0,T79:0, T108:0,T109:0,T110:0,T111:0\n" + "T80:0,T81:0,T82:0,T83:0, T112:0,T113:0,T114:0,T115:0\n" + "T84:0,T85:0,T86:0,T87:0, T116:0,T117:0,T118:0,T119:0\n" + "T88:0,T89:0,T90:0,T91:0, T120:0,T121:0,T122:0,T123:0\n" + "T92:0,T93:0,T94:0,T95:0, T124:0,T125:0,T126:0,T127:0\n"; + // clang-format on + + runBlocked2dSingleCTA(/*row=*/16, /*col=*/8, /*sizePerThread=*/{1, 1}, + /*threadsPerWarp=*/{8, 4}, /*warpsPerCTA=*/{2, 2}, + /*order=*/{1, 0}, /*refStr=*/refStr); +} + +TEST_F(EmitIndicesTest, BlockedLayout_SingleCTA_Order_0_1) { + // clang-format off + std::string refStr = + " T0:0, T8:0,T16:0,T24:0, T64:0, T72:0, T80:0, T88:0\n" + " T1:0, T9:0,T17:0,T25:0, T65:0, T73:0, T81:0, T89:0\n" + " T2:0,T10:0,T18:0,T26:0, T66:0, T74:0, T82:0, T90:0\n" + " T3:0,T11:0,T19:0,T27:0, T67:0, T75:0, T83:0, T91:0\n" + " T4:0,T12:0,T20:0,T28:0, T68:0, T76:0, T84:0, T92:0\n" + " T5:0,T13:0,T21:0,T29:0, T69:0, T77:0, T85:0, T93:0\n" + " T6:0,T14:0,T22:0,T30:0, T70:0, T78:0, T86:0, T94:0\n" + " T7:0,T15:0,T23:0,T31:0, T71:0, T79:0, T87:0, T95:0\n" + + "T32:0,T40:0,T48:0,T56:0, T96:0,T104:0,T112:0,T120:0\n" + "T33:0,T41:0,T49:0,T57:0, T97:0,T105:0,T113:0,T121:0\n" + "T34:0,T42:0,T50:0,T58:0, T98:0,T106:0,T114:0,T122:0\n" + "T35:0,T43:0,T51:0,T59:0, T99:0,T107:0,T115:0,T123:0\n" + "T36:0,T44:0,T52:0,T60:0, T100:0,T108:0,T116:0,T124:0\n" + "T37:0,T45:0,T53:0,T61:0, T101:0,T109:0,T117:0,T125:0\n" + "T38:0,T46:0,T54:0,T62:0, T102:0,T110:0,T118:0,T126:0\n" + "T39:0,T47:0,T55:0,T63:0, T103:0,T111:0,T119:0,T127:0\n"; + // clang-format on + + runBlocked2dSingleCTA(/*row=*/16, /*col=*/8, /*sizePerThread=*/{1, 1}, + /*threadsPerWarp=*/{8, 4}, /*warpsPerCTA=*/{2, 2}, + /*order=*/{0, 1}, /*refStr=*/refStr); +} + +TEST_F(EmitIndicesTest, BlockedLayout_SingleCTA_Vectorize) { + // clang-format off + std::string refStr = + " T0:0, T0:1, T0:2, T0:3, T1:0, T1:1, T1:2, T1:3, T2:0, T2:1, T2:2, T2:3, T3:0, T3:1, T3:2, T3:3\n" + " T4:0, T4:1, T4:2, T4:3, T5:0, T5:1, T5:2, T5:3, T6:0, T6:1, T6:2, T6:3, T7:0, T7:1, T7:2, T7:3\n" + " T8:0, T8:1, T8:2, T8:3, T9:0, T9:1, T9:2, T9:3, T10:0,T10:1,T10:2,T10:3, T11:0,T11:1,T11:2,T11:3\n" + "T12:0,T12:1,T12:2,T12:3, T13:0,T13:1,T13:2,T13:3, T14:0,T14:1,T14:2,T14:3, T15:0,T15:1,T15:2,T15:3\n" + "T16:0,T16:1,T16:2,T16:3, T17:0,T17:1,T17:2,T17:3, T18:0,T18:1,T18:2,T18:3, T19:0,T19:1,T19:2,T19:3\n" + "T20:0,T20:1,T20:2,T20:3, T21:0,T21:1,T21:2,T21:3, T22:0,T22:1,T22:2,T22:3, T23:0,T23:1,T23:2,T23:3\n" + "T24:0,T24:1,T24:2,T24:3, T25:0,T25:1,T25:2,T25:3, T26:0,T26:1,T26:2,T26:3, T27:0,T27:1,T27:2,T27:3\n" + "T28:0,T28:1,T28:2,T28:3, T29:0,T29:1,T29:2,T29:3, T30:0,T30:1,T30:2,T30:3, T31:0,T31:1,T31:2,T31:3\n" + + "T32:0,T32:1,T32:2,T32:3, T33:0,T33:1,T33:2,T33:3, T34:0,T34:1,T34:2,T34:3, T35:0,T35:1,T35:2,T35:3\n" + "T36:0,T36:1,T36:2,T36:3, T37:0,T37:1,T37:2,T37:3, T38:0,T38:1,T38:2,T38:3, T39:0,T39:1,T39:2,T39:3\n" + "T40:0,T40:1,T40:2,T40:3, T41:0,T41:1,T41:2,T41:3, T42:0,T42:1,T42:2,T42:3, T43:0,T43:1,T43:2,T43:3\n" + "T44:0,T44:1,T44:2,T44:3, T45:0,T45:1,T45:2,T45:3, T46:0,T46:1,T46:2,T46:3, T47:0,T47:1,T47:2,T47:3\n" + "T48:0,T48:1,T48:2,T48:3, T49:0,T49:1,T49:2,T49:3, T50:0,T50:1,T50:2,T50:3, T51:0,T51:1,T51:2,T51:3\n" + "T52:0,T52:1,T52:2,T52:3, T53:0,T53:1,T53:2,T53:3, T54:0,T54:1,T54:2,T54:3, T55:0,T55:1,T55:2,T55:3\n" + "T56:0,T56:1,T56:2,T56:3, T57:0,T57:1,T57:2,T57:3, T58:0,T58:1,T58:2,T58:3, T59:0,T59:1,T59:2,T59:3\n" + "T60:0,T60:1,T60:2,T60:3, T61:0,T61:1,T61:2,T61:3, T62:0,T62:1,T62:2,T62:3, T63:0,T63:1,T63:2,T63:3\n"; + // clang-format on + + runBlocked2dSingleCTA(/*row=*/16, /*col=*/16, /*sizePerThread=*/{1, 4}, + /*threadsPerWarp=*/{8, 4}, /*warpsPerCTA=*/{2, 1}, + /*order=*/{1, 0}, /*refStr=*/refStr); +} + +TEST_F(EmitIndicesTest, BlockedLayout_MultiCTA_CTAOrder_1_0) { + // clang-format off + std::string refStr = + "CTA0: T0:0,CTA0: T1:0,CTA0: T2:0,CTA0: T3:0, CTA1: T0:0,CTA1: T1:0,CTA1: T2:0,CTA1: T3:0\n" + "CTA0: T4:0,CTA0: T5:0,CTA0: T6:0,CTA0: T7:0, CTA1: T4:0,CTA1: T5:0,CTA1: T6:0,CTA1: T7:0\n" + "CTA0: T8:0,CTA0: T9:0,CTA0:T10:0,CTA0:T11:0, CTA1: T8:0,CTA1: T9:0,CTA1:T10:0,CTA1:T11:0\n" + "CTA0:T12:0,CTA0:T13:0,CTA0:T14:0,CTA0:T15:0, CTA1:T12:0,CTA1:T13:0,CTA1:T14:0,CTA1:T15:0\n" + "CTA0:T16:0,CTA0:T17:0,CTA0:T18:0,CTA0:T19:0, CTA1:T16:0,CTA1:T17:0,CTA1:T18:0,CTA1:T19:0\n" + "CTA0:T20:0,CTA0:T21:0,CTA0:T22:0,CTA0:T23:0, CTA1:T20:0,CTA1:T21:0,CTA1:T22:0,CTA1:T23:0\n" + "CTA0:T24:0,CTA0:T25:0,CTA0:T26:0,CTA0:T27:0, CTA1:T24:0,CTA1:T25:0,CTA1:T26:0,CTA1:T27:0\n" + "CTA0:T28:0,CTA0:T29:0,CTA0:T30:0,CTA0:T31:0, CTA1:T28:0,CTA1:T29:0,CTA1:T30:0,CTA1:T31:0\n" + + "CTA2: T0:0,CTA2: T1:0,CTA2: T2:0,CTA2: T3:0, CTA3: T0:0,CTA3: T1:0,CTA3: T2:0,CTA3: T3:0\n" + "CTA2: T4:0,CTA2: T5:0,CTA2: T6:0,CTA2: T7:0, CTA3: T4:0,CTA3: T5:0,CTA3: T6:0,CTA3: T7:0\n" + "CTA2: T8:0,CTA2: T9:0,CTA2:T10:0,CTA2:T11:0, CTA3: T8:0,CTA3: T9:0,CTA3:T10:0,CTA3:T11:0\n" + "CTA2:T12:0,CTA2:T13:0,CTA2:T14:0,CTA2:T15:0, CTA3:T12:0,CTA3:T13:0,CTA3:T14:0,CTA3:T15:0\n" + "CTA2:T16:0,CTA2:T17:0,CTA2:T18:0,CTA2:T19:0, CTA3:T16:0,CTA3:T17:0,CTA3:T18:0,CTA3:T19:0\n" + "CTA2:T20:0,CTA2:T21:0,CTA2:T22:0,CTA2:T23:0, CTA3:T20:0,CTA3:T21:0,CTA3:T22:0,CTA3:T23:0\n" + "CTA2:T24:0,CTA2:T25:0,CTA2:T26:0,CTA2:T27:0, CTA3:T24:0,CTA3:T25:0,CTA3:T26:0,CTA3:T27:0\n" + "CTA2:T28:0,CTA2:T29:0,CTA2:T30:0,CTA2:T31:0, CTA3:T28:0,CTA3:T29:0,CTA3:T30:0,CTA3:T31:0\n"; + // clang-format on + + runBlocked2dMultiCTA(/*row=*/16, /*col=*/8, /*sizePerThread=*/{1, 1}, + /*threadsPerWarp=*/{8, 4}, /*warpsPerCTA=*/{1, 1}, + /*order=*/{1, 0}, /*CTAsPerCGA=*/{2, 2}, + /*CTASplitNum=*/{2, 2}, /*CTAOrder=*/{1, 0}, + /*refStr=*/refStr); +} + +TEST_F(EmitIndicesTest, BlockedLayout_MultiCTA_CTAOrder_0_1) { + // clang-format off + std::string refStr = + "CTA0: T0:0,CTA0: T1:0,CTA0: T2:0,CTA0: T3:0, CTA2: T0:0,CTA2: T1:0,CTA2: T2:0,CTA2: T3:0\n" + "CTA0: T4:0,CTA0: T5:0,CTA0: T6:0,CTA0: T7:0, CTA2: T4:0,CTA2: T5:0,CTA2: T6:0,CTA2: T7:0\n" + "CTA0: T8:0,CTA0: T9:0,CTA0:T10:0,CTA0:T11:0, CTA2: T8:0,CTA2: T9:0,CTA2:T10:0,CTA2:T11:0\n" + "CTA0:T12:0,CTA0:T13:0,CTA0:T14:0,CTA0:T15:0, CTA2:T12:0,CTA2:T13:0,CTA2:T14:0,CTA2:T15:0\n" + "CTA0:T16:0,CTA0:T17:0,CTA0:T18:0,CTA0:T19:0, CTA2:T16:0,CTA2:T17:0,CTA2:T18:0,CTA2:T19:0\n" + "CTA0:T20:0,CTA0:T21:0,CTA0:T22:0,CTA0:T23:0, CTA2:T20:0,CTA2:T21:0,CTA2:T22:0,CTA2:T23:0\n" + "CTA0:T24:0,CTA0:T25:0,CTA0:T26:0,CTA0:T27:0, CTA2:T24:0,CTA2:T25:0,CTA2:T26:0,CTA2:T27:0\n" + "CTA0:T28:0,CTA0:T29:0,CTA0:T30:0,CTA0:T31:0, CTA2:T28:0,CTA2:T29:0,CTA2:T30:0,CTA2:T31:0\n" + + "CTA1: T0:0,CTA1: T1:0,CTA1: T2:0,CTA1: T3:0, CTA3: T0:0,CTA3: T1:0,CTA3: T2:0,CTA3: T3:0\n" + "CTA1: T4:0,CTA1: T5:0,CTA1: T6:0,CTA1: T7:0, CTA3: T4:0,CTA3: T5:0,CTA3: T6:0,CTA3: T7:0\n" + "CTA1: T8:0,CTA1: T9:0,CTA1:T10:0,CTA1:T11:0, CTA3: T8:0,CTA3: T9:0,CTA3:T10:0,CTA3:T11:0\n" + "CTA1:T12:0,CTA1:T13:0,CTA1:T14:0,CTA1:T15:0, CTA3:T12:0,CTA3:T13:0,CTA3:T14:0,CTA3:T15:0\n" + "CTA1:T16:0,CTA1:T17:0,CTA1:T18:0,CTA1:T19:0, CTA3:T16:0,CTA3:T17:0,CTA3:T18:0,CTA3:T19:0\n" + "CTA1:T20:0,CTA1:T21:0,CTA1:T22:0,CTA1:T23:0, CTA3:T20:0,CTA3:T21:0,CTA3:T22:0,CTA3:T23:0\n" + "CTA1:T24:0,CTA1:T25:0,CTA1:T26:0,CTA1:T27:0, CTA3:T24:0,CTA3:T25:0,CTA3:T26:0,CTA3:T27:0\n" + "CTA1:T28:0,CTA1:T29:0,CTA1:T30:0,CTA1:T31:0, CTA3:T28:0,CTA3:T29:0,CTA3:T30:0,CTA3:T31:0\n"; + // clang-format on + + runBlocked2dMultiCTA(/*row=*/16, /*col=*/8, /*sizePerThread=*/{1, 1}, + /*threadsPerWarp=*/{8, 4}, /*warpsPerCTA=*/{1, 1}, + /*order=*/{1, 0}, /*CTAsPerCGA=*/{2, 2}, + /*CTASplitNum=*/{2, 2}, /*CTAOrder=*/{0, 1}, + /*refStr=*/refStr); +} + +TEST_F(EmitIndicesTest, BlockedLayout_MultiCTA_CTAWrap_Dim1) { + // clang-format off + std::string refStr = + "CTA0: T0:0|CTA1: T0:0, CTA0: T1:0|CTA1: T1:0, CTA0: T2:0|CTA1: T2:0, CTA0: T3:0|CTA1: T3:0\n" + "CTA0: T4:0|CTA1: T4:0, CTA0: T5:0|CTA1: T5:0, CTA0: T6:0|CTA1: T6:0, CTA0: T7:0|CTA1: T7:0\n" + "CTA0: T8:0|CTA1: T8:0, CTA0: T9:0|CTA1: T9:0, CTA0:T10:0|CTA1:T10:0, CTA0:T11:0|CTA1:T11:0\n" + "CTA0:T12:0|CTA1:T12:0, CTA0:T13:0|CTA1:T13:0, CTA0:T14:0|CTA1:T14:0, CTA0:T15:0|CTA1:T15:0\n" + "CTA0:T16:0|CTA1:T16:0, CTA0:T17:0|CTA1:T17:0, CTA0:T18:0|CTA1:T18:0, CTA0:T19:0|CTA1:T19:0\n" + "CTA0:T20:0|CTA1:T20:0, CTA0:T21:0|CTA1:T21:0, CTA0:T22:0|CTA1:T22:0, CTA0:T23:0|CTA1:T23:0\n" + "CTA0:T24:0|CTA1:T24:0, CTA0:T25:0|CTA1:T25:0, CTA0:T26:0|CTA1:T26:0, CTA0:T27:0|CTA1:T27:0\n" + "CTA0:T28:0|CTA1:T28:0, CTA0:T29:0|CTA1:T29:0, CTA0:T30:0|CTA1:T30:0, CTA0:T31:0|CTA1:T31:0\n" + + "CTA2: T0:0|CTA3: T0:0, CTA2: T1:0|CTA3: T1:0, CTA2: T2:0|CTA3: T2:0, CTA2: T3:0|CTA3: T3:0\n" + "CTA2: T4:0|CTA3: T4:0, CTA2: T5:0|CTA3: T5:0, CTA2: T6:0|CTA3: T6:0, CTA2: T7:0|CTA3: T7:0\n" + "CTA2: T8:0|CTA3: T8:0, CTA2: T9:0|CTA3: T9:0, CTA2:T10:0|CTA3:T10:0, CTA2:T11:0|CTA3:T11:0\n" + "CTA2:T12:0|CTA3:T12:0, CTA2:T13:0|CTA3:T13:0, CTA2:T14:0|CTA3:T14:0, CTA2:T15:0|CTA3:T15:0\n" + "CTA2:T16:0|CTA3:T16:0, CTA2:T17:0|CTA3:T17:0, CTA2:T18:0|CTA3:T18:0, CTA2:T19:0|CTA3:T19:0\n" + "CTA2:T20:0|CTA3:T20:0, CTA2:T21:0|CTA3:T21:0, CTA2:T22:0|CTA3:T22:0, CTA2:T23:0|CTA3:T23:0\n" + "CTA2:T24:0|CTA3:T24:0, CTA2:T25:0|CTA3:T25:0, CTA2:T26:0|CTA3:T26:0, CTA2:T27:0|CTA3:T27:0\n" + "CTA2:T28:0|CTA3:T28:0, CTA2:T29:0|CTA3:T29:0, CTA2:T30:0|CTA3:T30:0, CTA2:T31:0|CTA3:T31:0\n"; + // clang-format on + + runBlocked2dMultiCTA(/*row=*/16, /*col=*/4, /*sizePerThread=*/{1, 1}, + /*threadsPerWarp=*/{8, 4}, /*warpsPerCTA=*/{1, 1}, + /*order=*/{1, 0}, /*CTAsPerCGA=*/{2, 2}, + /*CTASplitNum=*/{2, 1}, /*CTAOrder=*/{1, 0}, + /*refStr=*/refStr); +} + +TEST_F(EmitIndicesTest, BlockedLayout_MultiCTA_CTAWrap_Dim0) { + // clang-format off + std::string refStr = + "CTA0: T0:0|CTA2: T0:0,CTA0: T1:0|CTA2: T1:0,CTA0: T2:0|CTA2: T2:0,CTA0: T3:0|CTA2: T3:0, CTA1: T0:0|CTA3: T0:0,CTA1: T1:0|CTA3: T1:0,CTA1: T2:0|CTA3: T2:0,CTA1: T3:0|CTA3: T3:0\n" + "CTA0: T4:0|CTA2: T4:0,CTA0: T5:0|CTA2: T5:0,CTA0: T6:0|CTA2: T6:0,CTA0: T7:0|CTA2: T7:0, CTA1: T4:0|CTA3: T4:0,CTA1: T5:0|CTA3: T5:0,CTA1: T6:0|CTA3: T6:0,CTA1: T7:0|CTA3: T7:0\n" + "CTA0: T8:0|CTA2: T8:0,CTA0: T9:0|CTA2: T9:0,CTA0:T10:0|CTA2:T10:0,CTA0:T11:0|CTA2:T11:0, CTA1: T8:0|CTA3: T8:0,CTA1: T9:0|CTA3: T9:0,CTA1:T10:0|CTA3:T10:0,CTA1:T11:0|CTA3:T11:0\n" + "CTA0:T12:0|CTA2:T12:0,CTA0:T13:0|CTA2:T13:0,CTA0:T14:0|CTA2:T14:0,CTA0:T15:0|CTA2:T15:0, CTA1:T12:0|CTA3:T12:0,CTA1:T13:0|CTA3:T13:0,CTA1:T14:0|CTA3:T14:0,CTA1:T15:0|CTA3:T15:0\n" + "CTA0:T16:0|CTA2:T16:0,CTA0:T17:0|CTA2:T17:0,CTA0:T18:0|CTA2:T18:0,CTA0:T19:0|CTA2:T19:0, CTA1:T16:0|CTA3:T16:0,CTA1:T17:0|CTA3:T17:0,CTA1:T18:0|CTA3:T18:0,CTA1:T19:0|CTA3:T19:0\n" + "CTA0:T20:0|CTA2:T20:0,CTA0:T21:0|CTA2:T21:0,CTA0:T22:0|CTA2:T22:0,CTA0:T23:0|CTA2:T23:0, CTA1:T20:0|CTA3:T20:0,CTA1:T21:0|CTA3:T21:0,CTA1:T22:0|CTA3:T22:0,CTA1:T23:0|CTA3:T23:0\n" + "CTA0:T24:0|CTA2:T24:0,CTA0:T25:0|CTA2:T25:0,CTA0:T26:0|CTA2:T26:0,CTA0:T27:0|CTA2:T27:0, CTA1:T24:0|CTA3:T24:0,CTA1:T25:0|CTA3:T25:0,CTA1:T26:0|CTA3:T26:0,CTA1:T27:0|CTA3:T27:0\n" + "CTA0:T28:0|CTA2:T28:0,CTA0:T29:0|CTA2:T29:0,CTA0:T30:0|CTA2:T30:0,CTA0:T31:0|CTA2:T31:0, CTA1:T28:0|CTA3:T28:0,CTA1:T29:0|CTA3:T29:0,CTA1:T30:0|CTA3:T30:0,CTA1:T31:0|CTA3:T31:0\n"; + // clang-format on + + runBlocked2dMultiCTA( + /*row=*/8, /*col=*/8, /*sizePerThread=*/{1, 1}, /*threadsPerWarp=*/{8, 4}, + /*warpsPerCTA=*/{1, 1}, /*order=*/{1, 0}, /*CTAsPerCGA=*/{2, 2}, + /*CTASplitNum=*/{1, 2}, /*CTAOrder=*/{1, 0}, /*refStr=*/refStr); +} + +TEST_F(EmitIndicesTest, BlockedLayout_MultiCTA_CTAWrapBeforeBroadcast_Dim1) { + // clang-format off + std::string refStr = + "CTA0: T0:0|CTA0: T1:0|CTA0: T2:0|CTA0: T3:0 | CTA1: T0:0|CTA1: T1:0|CTA1: T2:0|CTA1: T3:0\n" + "CTA0: T4:0|CTA0: T5:0|CTA0: T6:0|CTA0: T7:0 | CTA1: T4:0|CTA1: T5:0|CTA1: T6:0|CTA1: T7:0\n" + "CTA0: T8:0|CTA0: T9:0|CTA0:T10:0|CTA0:T11:0 | CTA1: T8:0|CTA1: T9:0|CTA1:T10:0|CTA1:T11:0\n" + "CTA0:T12:0|CTA0:T13:0|CTA0:T14:0|CTA0:T15:0 | CTA1:T12:0|CTA1:T13:0|CTA1:T14:0|CTA1:T15:0\n" + "CTA0:T16:0|CTA0:T17:0|CTA0:T18:0|CTA0:T19:0 | CTA1:T16:0|CTA1:T17:0|CTA1:T18:0|CTA1:T19:0\n" + "CTA0:T20:0|CTA0:T21:0|CTA0:T22:0|CTA0:T23:0 | CTA1:T20:0|CTA1:T21:0|CTA1:T22:0|CTA1:T23:0\n" + "CTA0:T24:0|CTA0:T25:0|CTA0:T26:0|CTA0:T27:0 | CTA1:T24:0|CTA1:T25:0|CTA1:T26:0|CTA1:T27:0\n" + "CTA0:T28:0|CTA0:T29:0|CTA0:T30:0|CTA0:T31:0 | CTA1:T28:0|CTA1:T29:0|CTA1:T30:0|CTA1:T31:0\n" + + "CTA2: T0:0|CTA2: T1:0|CTA2: T2:0|CTA2: T3:0 | CTA3: T0:0|CTA3: T1:0|CTA3: T2:0|CTA3: T3:0\n" + "CTA2: T4:0|CTA2: T5:0|CTA2: T6:0|CTA2: T7:0 | CTA3: T4:0|CTA3: T5:0|CTA3: T6:0|CTA3: T7:0\n" + "CTA2: T8:0|CTA2: T9:0|CTA2:T10:0|CTA2:T11:0 | CTA3: T8:0|CTA3: T9:0|CTA3:T10:0|CTA3:T11:0\n" + "CTA2:T12:0|CTA2:T13:0|CTA2:T14:0|CTA2:T15:0 | CTA3:T12:0|CTA3:T13:0|CTA3:T14:0|CTA3:T15:0\n" + "CTA2:T16:0|CTA2:T17:0|CTA2:T18:0|CTA2:T19:0 | CTA3:T16:0|CTA3:T17:0|CTA3:T18:0|CTA3:T19:0\n" + "CTA2:T20:0|CTA2:T21:0|CTA2:T22:0|CTA2:T23:0 | CTA3:T20:0|CTA3:T21:0|CTA3:T22:0|CTA3:T23:0\n" + "CTA2:T24:0|CTA2:T25:0|CTA2:T26:0|CTA2:T27:0 | CTA3:T24:0|CTA3:T25:0|CTA3:T26:0|CTA3:T27:0\n" + "CTA2:T28:0|CTA2:T29:0|CTA2:T30:0|CTA2:T31:0 | CTA3:T28:0|CTA3:T29:0|CTA3:T30:0|CTA3:T31:0\n"; + // clang-format on + + runBlocked2dMultiCTA(/*row=*/16, /*col=*/1, /*sizePerThread=*/{1, 1}, + /*threadsPerWarp=*/{8, 4}, /*warpsPerCTA=*/{1, 1}, + /*order=*/{1, 0}, /*CTAsPerCGA=*/{2, 2}, + /*CTASplitNum=*/{2, 2}, /*CTAOrder=*/{1, 0}, + /*refStr=*/refStr); +} + +TEST_F(EmitIndicesTest, BlockedLayout_MultiCTA_CTAWrapBeforeBroadcast_Dim0) { + // clang-format off + std::string refStr = + "CTA0:T0:0|CTA0: T8:0|CTA0:T16:0|CTA0:T24:0 | CTA2:T0:0|CTA2: T8:0|CTA2:T16:0|CTA2:T24:0," + "CTA0:T1:0|CTA0: T9:0|CTA0:T17:0|CTA0:T25:0 | CTA2:T1:0|CTA2: T9:0|CTA2:T17:0|CTA2:T25:0," + "CTA0:T2:0|CTA0:T10:0|CTA0:T18:0|CTA0:T26:0 | CTA2:T2:0|CTA2:T10:0|CTA2:T18:0|CTA2:T26:0," + "CTA0:T3:0|CTA0:T11:0|CTA0:T19:0|CTA0:T27:0 | CTA2:T3:0|CTA2:T11:0|CTA2:T19:0|CTA2:T27:0," + "CTA0:T4:0|CTA0:T12:0|CTA0:T20:0|CTA0:T28:0 | CTA2:T4:0|CTA2:T12:0|CTA2:T20:0|CTA2:T28:0," + "CTA0:T5:0|CTA0:T13:0|CTA0:T21:0|CTA0:T29:0 | CTA2:T5:0|CTA2:T13:0|CTA2:T21:0|CTA2:T29:0," + "CTA0:T6:0|CTA0:T14:0|CTA0:T22:0|CTA0:T30:0 | CTA2:T6:0|CTA2:T14:0|CTA2:T22:0|CTA2:T30:0," + "CTA0:T7:0|CTA0:T15:0|CTA0:T23:0|CTA0:T31:0 | CTA2:T7:0|CTA2:T15:0|CTA2:T23:0|CTA2:T31:0," + + "CTA1:T0:0|CTA1: T8:0|CTA1:T16:0|CTA1:T24:0 | CTA3:T0:0|CTA3: T8:0|CTA3:T16:0|CTA3:T24:0," + "CTA1:T1:0|CTA1: T9:0|CTA1:T17:0|CTA1:T25:0 | CTA3:T1:0|CTA3: T9:0|CTA3:T17:0|CTA3:T25:0," + "CTA1:T2:0|CTA1:T10:0|CTA1:T18:0|CTA1:T26:0 | CTA3:T2:0|CTA3:T10:0|CTA3:T18:0|CTA3:T26:0," + "CTA1:T3:0|CTA1:T11:0|CTA1:T19:0|CTA1:T27:0 | CTA3:T3:0|CTA3:T11:0|CTA3:T19:0|CTA3:T27:0," + "CTA1:T4:0|CTA1:T12:0|CTA1:T20:0|CTA1:T28:0 | CTA3:T4:0|CTA3:T12:0|CTA3:T20:0|CTA3:T28:0," + "CTA1:T5:0|CTA1:T13:0|CTA1:T21:0|CTA1:T29:0 | CTA3:T5:0|CTA3:T13:0|CTA3:T21:0|CTA3:T29:0," + "CTA1:T6:0|CTA1:T14:0|CTA1:T22:0|CTA1:T30:0 | CTA3:T6:0|CTA3:T14:0|CTA3:T22:0|CTA3:T30:0," + "CTA1:T7:0|CTA1:T15:0|CTA1:T23:0|CTA1:T31:0 | CTA3:T7:0|CTA3:T15:0|CTA3:T23:0|CTA3:T31:0\n"; + // clang-format on + + runBlocked2dMultiCTA(/*row=*/1, /*col=*/16, /*sizePerThread=*/{1, 1}, + /*threadsPerWarp=*/{4, 8}, /*warpsPerCTA=*/{1, 1}, + /*order=*/{1, 0}, /*CTAsPerCGA=*/{2, 2}, + /*CTASplitNum=*/{2, 2}, /*CTAOrder=*/{1, 0}, + /*refStr=*/refStr); +} + +//===----------------------------------------------------------------------===// +// Tests for SliceEncodingAttr +//===----------------------------------------------------------------------===// + +TEST_F(EmitIndicesTest, SliceLayout_SingleCTA_SliceDim1) { + // clang-format off + std::string refStr = + " T0:0| T1:0| T2:0| T3:0| T4:0| T5:0| T6:0| T7:0," + " T8:0| T9:0|T10:0|T11:0|T12:0|T13:0|T14:0|T15:0," + "T16:0|T17:0|T18:0|T19:0|T20:0|T21:0|T22:0|T23:0," + "T24:0|T25:0|T26:0|T27:0|T28:0|T29:0|T30:0|T31:0\n"; + // clang-format on + + runSliceBlockedSingleCTA(/*size=*/4, /*sizePerThread=*/{1, 1}, + /*threadsPerWarp=*/{4, 8}, /*warpsPerCTA=*/{1, 1}, + /*order=*/{1, 0}, /*sliceDim=*/1, /*refStr=*/refStr); +} + +TEST_F(EmitIndicesTest, SliceLayout_SingleCTA_SliceDim0) { + // clang-format off + std::string refStr = + "T0:0| T8:0|T16:0|T24:0," + "T1:0| T9:0|T17:0|T25:0," + "T2:0|T10:0|T18:0|T26:0," + "T3:0|T11:0|T19:0|T27:0," + "T4:0|T12:0|T20:0|T28:0," + "T5:0|T13:0|T21:0|T29:0," + "T6:0|T14:0|T22:0|T30:0," + "T7:0|T15:0|T23:0|T31:0\n"; + // clang-format on + + runSliceBlockedSingleCTA(/*size=*/8, /*sizePerThread=*/{1, 1}, + /*threadsPerWarp=*/{4, 8}, /*warpsPerCTA=*/{1, 1}, + /*order=*/{1, 0}, /*sliceDim=*/0, /*refStr=*/refStr); +} + +TEST_F(EmitIndicesTest, SliceLayout_MultiCTA) { + // clang-format off + std::string refStr = + "CTA0: T0:0|CTA0: T1:0|CTA0: T2:0|CTA0: T3:0 | CTA1: T0:0|CTA1: T1:0|CTA1: T2:0|CTA1: T3:0," + "CTA0: T4:0|CTA0: T5:0|CTA0: T6:0|CTA0: T7:0 | CTA1: T4:0|CTA1: T5:0|CTA1: T6:0|CTA1: T7:0," + "CTA0: T8:0|CTA0: T9:0|CTA0:T10:0|CTA0:T11:0 | CTA1: T8:0|CTA1: T9:0|CTA1:T10:0|CTA1:T11:0," + "CTA0:T12:0|CTA0:T13:0|CTA0:T14:0|CTA0:T15:0 | CTA1:T12:0|CTA1:T13:0|CTA1:T14:0|CTA1:T15:0," + "CTA0:T16:0|CTA0:T17:0|CTA0:T18:0|CTA0:T19:0 | CTA1:T16:0|CTA1:T17:0|CTA1:T18:0|CTA1:T19:0," + "CTA0:T20:0|CTA0:T21:0|CTA0:T22:0|CTA0:T23:0 | CTA1:T20:0|CTA1:T21:0|CTA1:T22:0|CTA1:T23:0," + "CTA0:T24:0|CTA0:T25:0|CTA0:T26:0|CTA0:T27:0 | CTA1:T24:0|CTA1:T25:0|CTA1:T26:0|CTA1:T27:0," + "CTA0:T28:0|CTA0:T29:0|CTA0:T30:0|CTA0:T31:0 | CTA1:T28:0|CTA1:T29:0|CTA1:T30:0|CTA1:T31:0," + + "CTA2: T0:0|CTA2: T1:0|CTA2: T2:0|CTA2: T3:0 | CTA3: T0:0|CTA3: T1:0|CTA3: T2:0|CTA3: T3:0," + "CTA2: T4:0|CTA2: T5:0|CTA2: T6:0|CTA2: T7:0 | CTA3: T4:0|CTA3: T5:0|CTA3: T6:0|CTA3: T7:0," + "CTA2: T8:0|CTA2: T9:0|CTA2:T10:0|CTA2:T11:0 | CTA3: T8:0|CTA3: T9:0|CTA3:T10:0|CTA3:T11:0," + "CTA2:T12:0|CTA2:T13:0|CTA2:T14:0|CTA2:T15:0 | CTA3:T12:0|CTA3:T13:0|CTA3:T14:0|CTA3:T15:0," + "CTA2:T16:0|CTA2:T17:0|CTA2:T18:0|CTA2:T19:0 | CTA3:T16:0|CTA3:T17:0|CTA3:T18:0|CTA3:T19:0," + "CTA2:T20:0|CTA2:T21:0|CTA2:T22:0|CTA2:T23:0 | CTA3:T20:0|CTA3:T21:0|CTA3:T22:0|CTA3:T23:0," + "CTA2:T24:0|CTA2:T25:0|CTA2:T26:0|CTA2:T27:0 | CTA3:T24:0|CTA3:T25:0|CTA3:T26:0|CTA3:T27:0," + "CTA2:T28:0|CTA2:T29:0|CTA2:T30:0|CTA2:T31:0 | CTA3:T28:0|CTA3:T29:0|CTA3:T30:0|CTA3:T31:0\n"; + // clang-format on + + runSliceBlockedMultiCTA(/*size=*/16, /*sizePerThread=*/{1, 1}, + /*threadsPerWarp=*/{8, 4}, /*warpsPerCTA=*/{1, 1}, + /*order=*/{1, 0}, /*CTAsPerCGA=*/{2, 2}, + /*CTASplitNum=*/{2, 2}, /*CTAOrder=*/{1, 0}, + /*sliceDim=*/1, /*refStr=*/refStr); +} + +//===----------------------------------------------------------------------===// +// Tests for MmaEncodingAttr +//===----------------------------------------------------------------------===// + +TEST_F(EmitIndicesTest, MmaLayout) { + // clang-format off + std::string refStr = + " T0:0, T0:1, T1:0, T1:1, T2:0, T2:1, T3:0, T3:1\n" + " T4:0, T4:1, T5:0, T5:1, T6:0, T6:1, T7:0, T7:1\n" + " T8:0, T8:1, T9:0, T9:1,T10:0,T10:1,T11:0,T11:1\n" + "T12:0,T12:1,T13:0,T13:1,T14:0,T14:1,T15:0,T15:1\n" + "T16:0,T16:1,T17:0,T17:1,T18:0,T18:1,T19:0,T19:1\n" + "T20:0,T20:1,T21:0,T21:1,T22:0,T22:1,T23:0,T23:1\n" + "T24:0,T24:1,T25:0,T25:1,T26:0,T26:1,T27:0,T27:1\n" + "T28:0,T28:1,T29:0,T29:1,T30:0,T30:1,T31:0,T31:1\n" + " T0:2, T0:3, T1:2, T1:3, T2:2, T2:3, T3:2, T3:3\n" + " T4:2, T4:3, T5:2, T5:3, T6:2, T6:3, T7:2, T7:3\n" + " T8:2, T8:3, T9:2, T9:3,T10:2,T10:3,T11:2,T11:3\n" + "T12:2,T12:3,T13:2,T13:3,T14:2,T14:3,T15:2,T15:3\n" + "T16:2,T16:3,T17:2,T17:3,T18:2,T18:3,T19:2,T19:3\n" + "T20:2,T20:3,T21:2,T21:3,T22:2,T22:3,T23:2,T23:3\n" + "T24:2,T24:3,T25:2,T25:3,T26:2,T26:3,T27:2,T27:3\n" + "T28:2,T28:3,T29:2,T29:3,T30:2,T30:3,T31:2,T31:3\n"; + // clang-format on + + runMmaSingleCTA(/*row=*/16, /*col=*/8, /*versionMajor=*/2, /*versionMinor=*/1, + /*warpsPerCTA=*/{1, 1}, /*instrShape=*/{16, 8}, + /*refStr=*/refStr); +} + +//===----------------------------------------------------------------------===// +// Tests for SharedEncodingAttr +//===----------------------------------------------------------------------===// + +TEST_F(EmitIndicesTest, SharedLayout) { + // clang-format off + std::string refStr = + "(0: 0),(0: 1),(0: 2),(0: 3),(0: 4),(0: 5),(0: 6),(0: 7),(0: 8),(0: 9),(0:10),(0:11),(0:12),(0:13),(0:14),(0:15),(0:16),(0:17),(0:18),(0:19),(0:20),(0:21),(0:22),(0:23),(0:24),(0:25),(0:26),(0:27),(0:28),(0:29),(0:30),(0:31)\n" + "(1: 0),(1: 1),(1: 2),(1: 3),(1: 4),(1: 5),(1: 6),(1: 7),(1: 8),(1: 9),(1:10),(1:11),(1:12),(1:13),(1:14),(1:15),(1:16),(1:17),(1:18),(1:19),(1:20),(1:21),(1:22),(1:23),(1:24),(1:25),(1:26),(1:27),(1:28),(1:29),(1:30),(1:31)\n" + "(2: 8),(2: 9),(2:10),(2:11),(2:12),(2:13),(2:14),(2:15),(2: 0),(2: 1),(2: 2),(2: 3),(2: 4),(2: 5),(2: 6),(2: 7),(2:24),(2:25),(2:26),(2:27),(2:28),(2:29),(2:30),(2:31),(2:16),(2:17),(2:18),(2:19),(2:20),(2:21),(2:22),(2:23)\n" + "(3: 8),(3: 9),(3:10),(3:11),(3:12),(3:13),(3:14),(3:15),(3: 0),(3: 1),(3: 2),(3: 3),(3: 4),(3: 5),(3: 6),(3: 7),(3:24),(3:25),(3:26),(3:27),(3:28),(3:29),(3:30),(3:31),(3:16),(3:17),(3:18),(3:19),(3:20),(3:21),(3:22),(3:23)\n" + "(4:16),(4:17),(4:18),(4:19),(4:20),(4:21),(4:22),(4:23),(4:24),(4:25),(4:26),(4:27),(4:28),(4:29),(4:30),(4:31),(4: 0),(4: 1),(4: 2),(4: 3),(4: 4),(4: 5),(4: 6),(4: 7),(4: 8),(4: 9),(4:10),(4:11),(4:12),(4:13),(4:14),(4:15)\n" + "(5:16),(5:17),(5:18),(5:19),(5:20),(5:21),(5:22),(5:23),(5:24),(5:25),(5:26),(5:27),(5:28),(5:29),(5:30),(5:31),(5: 0),(5: 1),(5: 2),(5: 3),(5: 4),(5: 5),(5: 6),(5: 7),(5: 8),(5: 9),(5:10),(5:11),(5:12),(5:13),(5:14),(5:15)\n" + "(6:24),(6:25),(6:26),(6:27),(6:28),(6:29),(6:30),(6:31),(6:16),(6:17),(6:18),(6:19),(6:20),(6:21),(6:22),(6:23),(6: 8),(6: 9),(6:10),(6:11),(6:12),(6:13),(6:14),(6:15),(6: 0),(6: 1),(6: 2),(6: 3),(6: 4),(6: 5),(6: 6),(6: 7)\n" + "(7:24),(7:25),(7:26),(7:27),(7:28),(7:29),(7:30),(7:31),(7:16),(7:17),(7:18),(7:19),(7:20),(7:21),(7:22),(7:23),(7: 8),(7: 9),(7:10),(7:11),(7:12),(7:13),(7:14),(7:15),(7: 0),(7: 1),(7: 2),(7: 3),(7: 4),(7: 5),(7: 6),(7: 7)\n"; + // clang-format on + + runSharedSingleCTA(/*row=*/8, /*col=*/32, /*rowMajor=*/true, + /*elemTyStr=*/"F16", /*refStr=*/refStr); +} + +//===----------------------------------------------------------------------===// +// The following unittests are tools for Triton developers to visualize layouts. +// You can modify parameters and shapes here to create your own layout and +// tensor. The output will be saved into a csv file which can be opened with +// Microsoft Excel. +//===----------------------------------------------------------------------===// + +TEST_F(EmitIndicesTest, LayoutVisualizer_Blocked) { + CTALayoutAttr CTALayout = + CTALayoutAttr::get(/*context=*/&context, /*CTAsPerCGA=*/{2, 2}, + /*CTASplitNum=*/{2, 2}, /*CTAOrder=*/{1, 0}); + + Attribute blockedLayout = BlockedEncodingAttr::get( + /*context=*/&context, /*sizePerThread=*/{1, 4}, + /*threadsPerWarp=*/{2, 16}, + /*warpsPerCTA=*/{4, 1}, /*order=*/{1, 0}, /*CTALayout=*/CTALayout); + + llvm::SmallVector shape = {/*row=*/128, /*col=*/128}; + + std::ofstream ofs("blockedLayout.csv"); + ofs << dumpDistributedLayout(blockedLayout, shape, /*multiCTA=*/true); +} + +TEST_F(EmitIndicesTest, LayoutVisualizer_Slice) { + CTALayoutAttr CTALayout = + CTALayoutAttr::get(/*context=*/&context, /*CTAsPerCGA=*/{1, 1}, + /*CTASplitNum=*/{1, 1}, /*CTAOrder=*/{1, 0}); + + Attribute blockedLayout = BlockedEncodingAttr::get( + /*context=*/&context, /*sizePerThread=*/{1, 1}, /*threadsPerWarp=*/{4, 8}, + /*warpsPerCTA=*/{1, 1}, /*order=*/{1, 0}, /*CTALayout=*/CTALayout); + + Attribute sliceLayout = SliceEncodingAttr::get( + /*context=*/&context, /*dim=*/1, /*parent=*/blockedLayout); + + llvm::SmallVector shape = {4}; + + std::ofstream ofs("sliceLayout.csv"); + ofs << dumpDistributedLayout(sliceLayout, shape, /*multiCTA=*/false); +} + +TEST_F(EmitIndicesTest, LayoutVisualizer_Mma) { + CTALayoutAttr CTALayout = + CTALayoutAttr::get(/*context=*/&context, /*CTAsPerCGA=*/{1, 1}, + /*CTASplitNum=*/{1, 1}, /*CTAOrder=*/{1, 0}); + + Attribute mmaLayout = MmaEncodingAttr::get( + /*context=*/&context, /*versionMajor=*/2, /*versionMinor=*/1, + /*warpsPerCTA=*/{1, 1}, /*CTALayout=*/CTALayout, /*instrShape=*/{16, 8}); + + llvm::SmallVector shape = {/*row=*/16, /*col=*/8}; + + std::ofstream ofs("mmaLayout.csv"); + ofs << dumpDistributedLayout(mmaLayout, shape, /*multiCTA=*/false); +} + +TEST_F(EmitIndicesTest, LayoutVisualizer_Shared) { + CTALayoutAttr CTALayout = + CTALayoutAttr::get(/*context=*/&context, /*CTAsPerCGA=*/{1, 1}, + /*CTASplitNum=*/{1, 1}, /*CTAOrder=*/{1, 0}); + + Attribute sharedLayout = SharedEncodingAttr::get( + /*context=*/&context, /*vec=*/1, /*perPhase=*/2, /*maxPhase=*/8, + /*order=*/{0, 1}, /*CTALayout=*/CTALayout); + + llvm::SmallVector shape = {/*row=*/16, /*col=*/16}; + Type elemTy = FloatType::getF16(&context); + + std::ofstream ofs("sharedLayout.csv"); + ofs << dumpSharedLayout(sharedLayout, shape, elemTy, /*multiCTA=*/false); +} + +} // namespace gpu +} // namespace triton +} // namespace mlir + +//===----------------------------------------------------------------------===// +// Main +//===----------------------------------------------------------------------===// + +int main(int argc, char *argv[]) { + testing::InitGoogleTest(&argc, argv); + // FIXME: These tests are temporarily disabled due to ctaid.x|y|z are swapped + // return RUN_ALL_TESTS(); +} diff --git a/unittest/Dialect/TritonGPU/CMakeLists.txt b/unittest/Dialect/TritonGPU/CMakeLists.txt index a2444cfa0a92..bbd4080be9f5 100644 --- a/unittest/Dialect/TritonGPU/CMakeLists.txt +++ b/unittest/Dialect/TritonGPU/CMakeLists.txt @@ -1,5 +1,5 @@ add_triton_ut( NAME TestSwizzling SRCS SwizzleTest.cpp - LIBS TritonGPUIR ${dialect_libs} ${conversion_libs} + LIBS TritonGPUIR TritonNvidiaGPUIR ${dialect_libs} ${conversion_libs} ) diff --git a/unittest/Dialect/TritonGPU/SwizzleTest.cpp b/unittest/Dialect/TritonGPU/SwizzleTest.cpp index dc4456e87acc..2b9faeaf405a 100644 --- a/unittest/Dialect/TritonGPU/SwizzleTest.cpp +++ b/unittest/Dialect/TritonGPU/SwizzleTest.cpp @@ -27,15 +27,20 @@ TEST_P(SwizzleDotOperandTestFixture, DotOperands) { // init context MLIRContext ctx; ctx.loadDialect(); + + auto CTALayout = + triton::gpu::CTALayoutAttr::get(&ctx, {1, 1}, {1, 1}, {0, 1}); + // create encoding - auto parent = triton::gpu::MmaEncodingAttr::get(&ctx, 2, 0, {1, 1}); + auto parent = triton::gpu::MmaEncodingAttr::get(&ctx, 2, 0, {1, 1}, CTALayout, + {16, 64, 16}); auto encoding = triton::gpu::DotOperandEncodingAttr::get( &ctx, params.opIdx, parent, 32 / params.typeWidth); // create element type Type eltType = IntegerType::get(&ctx, params.typeWidth); - auto layout = - SharedEncodingAttr::get(&ctx, encoding, params.shape, {1, 0}, eltType); + auto layout = SharedEncodingAttr::get(&ctx, encoding, params.shape, {1, 0}, + CTALayout, eltType); ASSERT_EQ(layout.getVec(), params.refSwizzle.vec); ASSERT_EQ(layout.getPerPhase(), params.refSwizzle.perPhase);