
Update triton to 310405647df51a909943bed71c5a6fd9a3e402b4 (#175)
+ Install pybind11 during setup
+ Implement `get_module_map`
+ Fix missing argument materialization hook in StructuredToMemref
+ Update signatures in UseAnalysis

Fixes #169
nhat-nguyen authored Sep 17, 2024
1 parent 6531ff6 commit 19b4dfc
Showing 7 changed files with 30 additions and 16 deletions.
14 changes: 7 additions & 7 deletions .github/workflows/test-plugin.yml
@@ -67,7 +67,7 @@ jobs:
working-directory: triton_shared/triton/python
run: |
python3 -m pip install --upgrade pip
- python3 -m pip install cmake==3.24 ninja pytest-xdist
+ python3 -m pip install cmake==3.24 ninja pytest-xdist pybind11
sudo apt-get update -y
sudo apt-get install -y ccache clang lld
export TRITON_PLUGIN_DIRS="${GITHUB_WORKSPACE}/triton_shared"
@@ -129,7 +129,7 @@ jobs:
- name: Update PATH
run: |
echo "PATH=${HOME}/.local/bin:${PATH}" >> "${GITHUB_ENV}"
- name: Check pre-commit
working-directory: triton_shared/triton
run: |
@@ -140,12 +140,12 @@
working-directory: triton_shared/triton/python
run: |
python3 -m pip install --upgrade pip
- python3 -m pip install cmake==3.24 ninja pytest-xdist
+ python3 -m pip install cmake==3.24 ninja pytest-xdist pybind11
sudo apt-get update -y
sudo apt-get install -y ccache clang lld
export TRITON_PLUGIN_DIRS="${GITHUB_WORKSPACE}/triton_shared"
TRITON_BUILD_WITH_CLANG_LLD=true TRITON_BUILD_WITH_CCACHE=true python3 -m pip install --no-build-isolation -vvv '.[tests]'
- name: Run shared middle-layer lit tests
working-directory: triton_shared/triton/python
run: |
@@ -155,18 +155,18 @@
echo "Coult not find '${LIT_TEST_DIR}'" ; exit -1
fi
lit -v "${LIT_TEST_DIR}"
- name: Install CPU backend example dependencies
run: |
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
python3 -m pip install pytest
- name: Prepare CPU backend environment
working-directory: triton_shared/triton/python
run: |
echo "TRITON_SHARED_OPT_PATH=$(pwd)/build/$(ls $(pwd)/build | grep -i cmake)/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt" >> "${GITHUB_ENV}"
echo "LLVM_BINARY_DIR=${HOME}/.triton/llvm/$(ls ${HOME}/.triton/llvm/ | grep -i llvm)/bin" >> "${GITHUB_ENV}"
- name: Run CPU backend examples
working-directory: triton_shared/python/examples
run: pytest .
7 changes: 6 additions & 1 deletion backend/compiler.py
@@ -1,7 +1,8 @@
from triton.backends.compiler import BaseBackend, GPUTarget
from triton._C.libtriton import ir, passes
from dataclasses import dataclass
- from typing import Any, Tuple
+ from typing import Any, Dict, Tuple
+ from types import ModuleType
import hashlib
import tempfile
import os
@@ -192,3 +193,7 @@ def add_stages(self, stages, options):
@functools.lru_cache()
def hash(self):
return self.target

+ # The CPU backend does not use any extra python modules, return an empty dictionary
+ def get_module_map(self) -> Dict[str, ModuleType]:
+ return {}
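For context, `get_module_map` lets a backend hand extra Python modules to the Triton frontend, keyed by module name, and the CPU backend here has none to offer, so an empty dictionary is the whole implementation. The sketch below shows what a backend that does ship an extra module might return; it is illustrative only, and both the `libfoo` module and its dotted key are hypothetical.

```python
# Illustrative sketch only: a backend exposing one extra module through
# get_module_map. "libfoo" and its dotted key are hypothetical names; the CPU
# backend in this commit intentionally returns an empty dictionary instead.
import types
from types import ModuleType
from typing import Dict


class HypotheticalBackend:
    def get_module_map(self) -> Dict[str, ModuleType]:
        # A real backend would import an actual helper library here.
        libfoo = types.ModuleType("libfoo")
        libfoo.answer = lambda: 42
        return {"triton.language.extra.libfoo": libfoo}


modules = HypotheticalBackend().get_module_map()
print(sorted(modules))                                    # ['triton.language.extra.libfoo']
print(modules["triton.language.extra.libfoo"].answer())   # 42
```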
4 changes: 2 additions & 2 deletions include/triton-shared/Analysis/UseAnalysis.h
@@ -80,8 +80,8 @@ struct UseInfo : public dataflow::AbstractSparseLattice {
class UseAnalysis : public dataflow::SparseBackwardDataFlowAnalysis<UseInfo> {
public:
using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
- void visitOperation(Operation *op, ArrayRef<UseInfo *> operands,
- ArrayRef<const UseInfo *> results) override;
+ LogicalResult visitOperation(Operation *op, ArrayRef<UseInfo *> operands,
+ ArrayRef<const UseInfo *> results) override;

void visitBranchOperand(OpOperand &operand) override { return; }

7 changes: 4 additions & 3 deletions lib/Analysis/UseAnalysis.cpp
@@ -25,9 +25,9 @@ using namespace dataflow;
// Use Analysis
// Note that logic below should evolve with triton-to-affine pass
//===----------------------------------------------------------------------===//
- void triton::UseAnalysis::visitOperation(Operation *op,
- ArrayRef<UseInfo *> operands,
- ArrayRef<const UseInfo *> results) {
+ LogicalResult
+ triton::UseAnalysis::visitOperation(Operation *op, ArrayRef<UseInfo *> operands,
+ ArrayRef<const UseInfo *> results) {
// If an op only produces pointer, all its operands are used as meta data.
// This accounts for scenarios such as addptr in a loop whose result is
// yielded. In this case, if the loop returns data tensors, addptr will be
@@ -85,6 +85,7 @@ void triton::UseAnalysis::visitOperation(Operation *op,
propagateResults(operand, results);
}
});
+ return success();
}

LogicalResult triton::runUseAnalysis(triton::FuncOp &funcOp) {
7 changes: 7 additions & 0 deletions lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp
@@ -74,6 +74,13 @@ class TritonFunctionSignatureConverter : public TypeConverter {
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
});

+ addArgumentMaterialization([&](OpBuilder &builder, Type resultType,
+ ValueRange inputs,
+ Location loc) -> std::optional<Value> {
+ return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+ .getResult(0);
+ });
}
};

5 changes: 3 additions & 2 deletions python/examples/conftest.py
@@ -19,6 +19,7 @@ def device(request):


tests_not_supported = {
"test_bin_op",
"test_split",
"test_split_to_scalar",
"test_interleave_scalars",
@@ -94,7 +95,7 @@ def pytest_collection_modifyitems(config, items):

for item in items:
test_func_name = item.originalname if item.originalname else item.name

if test_func_name in tests_not_supported:
item.add_marker(skip_marker)
continue
@@ -105,5 +106,5 @@ def pytest_collection_modifyitems(config, items):
item.add_marker(skip_marker_bfloat)
if param_name.startswith('input_precision') and param_value.startswith('tf32'):
item.add_marker(skip_marker_tf32)
- if param_name.endswith('dtype') and ('float8' in str(param_value)):
+ if (param_name.startswith('dtype') or param_name.endswith('dtype')) and ('float8' in str(param_value)):
item.add_marker(skip_marker_float8)
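A note on the widened float8 skip above: the old check only matched parameter names ending in `dtype` (for example `in_dtype`), while the new one also matches names beginning with `dtype` (for example `dtype_x`). A small standalone illustration, with made-up parameter names, of which names each condition catches:

```python
# Standalone illustration of the widened float8 skip condition; the parameter
# names below are made up, only the matching logic mirrors the diff above.
def old_condition(param_name: str) -> bool:
    return param_name.endswith('dtype')


def new_condition(param_name: str) -> bool:
    return param_name.startswith('dtype') or param_name.endswith('dtype')


for name in ('in_dtype', 'dtype_x', 'dtype_str'):
    print(f"{name}: old={old_condition(name)} new={new_condition(name)}")
# in_dtype matches both checks; dtype_x and dtype_str match only the new one,
# so float8 variants of such parametrizations are now skipped as well.
```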
2 changes: 1 addition & 1 deletion triton
Submodule triton updated 218 files
