diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml new file mode 100644 index 0000000..095b8f5 --- /dev/null +++ b/.github/workflows/python-app.yml @@ -0,0 +1,61 @@ +# This workflow will install Python dependencies, run tests and lint with a single version of Python +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python + +name: Python application + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +permissions: + contents: read + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + with: + path: 'triton_viz' + + - name: Set up Python 3.10 + uses: actions/setup-python@v3 + with: + python-version: '3.10' + + - name: Cache Dependencies + uses: actions/cache@v3 + id: cache-pip + with: + path: /opt/hostedtoolcache/Python/3.10.13/x64 + key: ${{ runner.os }}-pip-3.10-${{ hashFiles('**/setup.py') }} + restore-keys: | + ${{ runner.os }}-pip- + + - name: Clone Triton and Install + run: | + git clone https://github.com/openai/triton.git + cd triton/python + pip install -e . + + - name: Install Dependencies if Cache Missed + if: steps.cache-pip.outputs.cache-hit != 'true' + run: | + cd triton_viz + pip install -e . + pre-commit install + pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 + pip uninstall pytorch-triton -y + + - name: Lint with pre-commit + run: | + cd triton_viz + pre-commit run --all-files + + - name: Test with pytest + run: | + cd triton_viz + python -m pytest Examples diff --git a/Examples/conftest.py b/Examples/conftest.py new file mode 100644 index 0000000..c326b6f --- /dev/null +++ b/Examples/conftest.py @@ -0,0 +1,8 @@ +import pytest +from triton_viz.interpreter import record_builder + + +@pytest.fixture(autouse=True, scope="function") +def clear_cache(): + yield + record_builder.reset() diff --git a/Examples/vec_add.py b/Examples/vec_add.py index b1695b1..5d5fe24 100644 --- a/Examples/vec_add.py +++ b/Examples/vec_add.py @@ -60,7 +60,7 @@ def perform_vec_add(device, size): torch.manual_seed(0) x = torch.rand(size, device=device) y = torch.rand(size, device=device) - output = add(x, y) # Assuming add() is your custom function + output, grid = add(x, y) # Assuming add() is your custom function return x, y, output diff --git a/setup.py b/setup.py index a469252..6898fed 100644 --- a/setup.py +++ b/setup.py @@ -13,5 +13,8 @@ "triton", "gradio", "chalk-diagrams @ git+https://github.com/chalk-diagrams/chalk.git", + "pyarrow", + "pre-commit", + "pytest", ], ) diff --git a/triton_viz/interface.py b/triton_viz/interface.py index 6f7a871..a60c1f7 100644 --- a/triton_viz/interface.py +++ b/triton_viz/interface.py @@ -3,7 +3,6 @@ import tempfile - def launch(): cache = {} program_records, tt = triton_viz.collect_grid() @@ -19,7 +18,7 @@ def launch(): s1 = gr.Slider(0, m[0], value=0, step=1, label="Program Id 0") s2 = gr.Slider(0, m[1], value=0, step=1, label="Program Id 1") s3 = gr.Slider(0, m[2], value=0, step=1, label="Program Id 2") - + def update(inp): a = inp[s1] b = inp[s2] diff --git a/triton_viz/interpreter.py b/triton_viz/interpreter.py index 1679e0e..f4f7c19 100644 --- a/triton_viz/interpreter.py +++ b/triton_viz/interpreter.py @@ -259,8 +259,9 @@ def wrapper(arg, axis): def _create_reduce(fn, op_name): @wraps(fn) - def wrapper(input, axis, keep_dims=False): - ret = fn(input, axis, keep_dims) + def wrapper(input, axis=None, **kwargs): + ret = fn(input, axis=axis, **kwargs) + keep_dims = kwargs.get("keep_dims", False) reduce_record = Reduce( input_shape=input.handle.data.shape, index=axis,