From 19b4dfc355e61ca7f01adddc4f1cbad78e790e06 Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Tue, 17 Sep 2024 17:40:23 -0400 Subject: [PATCH] Update triton to 310405647df51a909943bed71c5a6fd9a3e402b4 (#175) + Install pybind11 during setup + Implement `get_module_map` + Fix missing argument materialization hook in StructuredToMemref + Update signatures in UseAnalysis Fixes #169 --- .github/workflows/test-plugin.yml | 14 +++++++------- backend/compiler.py | 7 ++++++- include/triton-shared/Analysis/UseAnalysis.h | 4 ++-- lib/Analysis/UseAnalysis.cpp | 7 ++++--- .../StructuredToMemref/StructuredToMemrefPass.cpp | 7 +++++++ python/examples/conftest.py | 5 +++-- triton | 2 +- 7 files changed, 30 insertions(+), 16 deletions(-) diff --git a/.github/workflows/test-plugin.yml b/.github/workflows/test-plugin.yml index 265ab877..443405ca 100644 --- a/.github/workflows/test-plugin.yml +++ b/.github/workflows/test-plugin.yml @@ -67,7 +67,7 @@ jobs: working-directory: triton_shared/triton/python run: | python3 -m pip install --upgrade pip - python3 -m pip install cmake==3.24 ninja pytest-xdist + python3 -m pip install cmake==3.24 ninja pytest-xdist pybind11 sudo apt-get update -y sudo apt-get install -y ccache clang lld export TRITON_PLUGIN_DIRS="${GITHUB_WORKSPACE}/triton_shared" @@ -129,7 +129,7 @@ jobs: - name: Update PATH run: | echo "PATH=${HOME}/.local/bin:${PATH}" >> "${GITHUB_ENV}" - + - name: Check pre-commit working-directory: triton_shared/triton run: | @@ -140,12 +140,12 @@ jobs: working-directory: triton_shared/triton/python run: | python3 -m pip install --upgrade pip - python3 -m pip install cmake==3.24 ninja pytest-xdist + python3 -m pip install cmake==3.24 ninja pytest-xdist pybind11 sudo apt-get update -y sudo apt-get install -y ccache clang lld export TRITON_PLUGIN_DIRS="${GITHUB_WORKSPACE}/triton_shared" TRITON_BUILD_WITH_CLANG_LLD=true TRITON_BUILD_WITH_CCACHE=true python3 -m pip install --no-build-isolation -vvv '.[tests]' - + - name: Run shared middle-layer lit tests working-directory: triton_shared/triton/python run: | @@ -155,18 +155,18 @@ jobs: echo "Coult not find '${LIT_TEST_DIR}'" ; exit -1 fi lit -v "${LIT_TEST_DIR}" - + - name: Install CPU backend example dependencies run: | python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu python3 -m pip install pytest - + - name: Prepare CPU backend environment working-directory: triton_shared/triton/python run: | echo "TRITON_SHARED_OPT_PATH=$(pwd)/build/$(ls $(pwd)/build | grep -i cmake)/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt" >> "${GITHUB_ENV}" echo "LLVM_BINARY_DIR=${HOME}/.triton/llvm/$(ls ${HOME}/.triton/llvm/ | grep -i llvm)/bin" >> "${GITHUB_ENV}" - + - name: Run CPU backend examples working-directory: triton_shared/python/examples run: pytest . diff --git a/backend/compiler.py b/backend/compiler.py index 62dd4d59..94d42eaa 100644 --- a/backend/compiler.py +++ b/backend/compiler.py @@ -1,7 +1,8 @@ from triton.backends.compiler import BaseBackend, GPUTarget from triton._C.libtriton import ir, passes from dataclasses import dataclass -from typing import Any, Tuple +from typing import Any, Dict, Tuple +from types import ModuleType import hashlib import tempfile import os @@ -192,3 +193,7 @@ def add_stages(self, stages, options): @functools.lru_cache() def hash(self): return self.target + + # The CPU backend does not use any extra python modules, return an empty dictionary + def get_module_map(self) -> Dict[str, ModuleType]: + return {} diff --git a/include/triton-shared/Analysis/UseAnalysis.h b/include/triton-shared/Analysis/UseAnalysis.h index 634888b7..39c3055a 100644 --- a/include/triton-shared/Analysis/UseAnalysis.h +++ b/include/triton-shared/Analysis/UseAnalysis.h @@ -80,8 +80,8 @@ struct UseInfo : public dataflow::AbstractSparseLattice { class UseAnalysis : public dataflow::SparseBackwardDataFlowAnalysis { public: using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis; - void visitOperation(Operation *op, ArrayRef operands, - ArrayRef results) override; + LogicalResult visitOperation(Operation *op, ArrayRef operands, + ArrayRef results) override; void visitBranchOperand(OpOperand &operand) override { return; } diff --git a/lib/Analysis/UseAnalysis.cpp b/lib/Analysis/UseAnalysis.cpp index e8b53c9b..62e45080 100644 --- a/lib/Analysis/UseAnalysis.cpp +++ b/lib/Analysis/UseAnalysis.cpp @@ -25,9 +25,9 @@ using namespace dataflow; // Use Analysis // Note that logic below should evolve with triton-to-affine pass //===----------------------------------------------------------------------===// -void triton::UseAnalysis::visitOperation(Operation *op, - ArrayRef operands, - ArrayRef results) { +LogicalResult +triton::UseAnalysis::visitOperation(Operation *op, ArrayRef operands, + ArrayRef results) { // If an op only produces pointer, all its operands are used as meta data. // This accounts for scenarios such as addptr in a loop whose result is // yielded. In this case, if the loop returns data tensors, addptr will be @@ -85,6 +85,7 @@ void triton::UseAnalysis::visitOperation(Operation *op, propagateResults(operand, results); } }); + return success(); } LogicalResult triton::runUseAnalysis(triton::FuncOp &funcOp) { diff --git a/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp b/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp index fb84664a..2015903e 100644 --- a/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp +++ b/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp @@ -74,6 +74,13 @@ class TritonFunctionSignatureConverter : public TypeConverter { return builder.create(loc, resultType, inputs) .getResult(0); }); + + addArgumentMaterialization([&](OpBuilder &builder, Type resultType, + ValueRange inputs, + Location loc) -> std::optional { + return builder.create(loc, resultType, inputs) + .getResult(0); + }); } }; diff --git a/python/examples/conftest.py b/python/examples/conftest.py index 77b4d46b..441f0033 100644 --- a/python/examples/conftest.py +++ b/python/examples/conftest.py @@ -19,6 +19,7 @@ def device(request): tests_not_supported = { + "test_bin_op", "test_split", "test_split_to_scalar", "test_interleave_scalars", @@ -94,7 +95,7 @@ def pytest_collection_modifyitems(config, items): for item in items: test_func_name = item.originalname if item.originalname else item.name - + if test_func_name in tests_not_supported: item.add_marker(skip_marker) continue @@ -105,5 +106,5 @@ def pytest_collection_modifyitems(config, items): item.add_marker(skip_marker_bfloat) if param_name.startswith('input_precision') and param_value.startswith('tf32'): item.add_marker(skip_marker_tf32) - if param_name.endswith('dtype') and ('float8' in str(param_value)): + if (param_name.startswith('dtype') or param_name.endswith('dtype')) and ('float8' in str(param_value)): item.add_marker(skip_marker_float8) diff --git a/triton b/triton index a51de763..31040564 160000 --- a/triton +++ b/triton @@ -1 +1 @@ -Subproject commit a51de763012a4e3c23e9f9bf785a12d50eb490c0 +Subproject commit 310405647df51a909943bed71c5a6fd9a3e402b4