Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial changes for updated triton-viz #31

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions examples/3dims.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import torch
import triton
import triton.language as tl



import triton_viz
BLOCK_SIZE_X = 32
BLOCK_SIZE_Y = 8
BLOCK_SIZE_Z = 4

@triton_viz.trace
@triton.jit
def add_3d_slices_kernel(
input_ptr1, input_ptr2, output_ptr,
stride_x, stride_y, stride_z,
slice_x, slice_y, slice_z,
BLOCK_SIZE_X: tl.constexpr, BLOCK_SIZE_Y: tl.constexpr, BLOCK_SIZE_Z: tl.constexpr
):
# Compute the 3D position in the output tensor
pid_x = tl.program_id(0)
pid_y = tl.program_id(1)
pid_z = tl.program_id(2)

# Compute the starting position for this block
x_start = pid_x * BLOCK_SIZE_X
y_start = pid_y * BLOCK_SIZE_Y
z_start = pid_z * BLOCK_SIZE_Z

# Compute offsets within the block
x_offsets = x_start + tl.arange(0, BLOCK_SIZE_X)
y_offsets = y_start + tl.arange(0, BLOCK_SIZE_Y)
z_offsets = z_start + tl.arange(0, BLOCK_SIZE_Z)

# Create a mask to handle boundary conditions
mask = (x_offsets < slice_x) & (y_offsets < slice_y)[:, None] & (z_offsets < slice_z)[:, None, None]

# Compute the input and output offsets
offsets = (
z_offsets[:, None, None] * stride_z +
y_offsets[:, None] * stride_y +
x_offsets * stride_x
)

# Load input slices
slice1 = tl.load(input_ptr1 + offsets, mask=mask)
slice2 = tl.load(input_ptr2 + offsets, mask=mask)

# Perform addition
result = slice1 + slice2

# Store the result
tl.store(output_ptr + offsets, result, mask=mask)

def add_3d_slices(input1, input2, output):
# Get tensor shapes
slice_z, slice_y, slice_x = input1.shape

# Compute strides
stride_z, stride_y, stride_x = input1.stride()

# Determine grid size
grid = (
triton.cdiv(slice_x, BLOCK_SIZE_X),
triton.cdiv(slice_y, BLOCK_SIZE_Y),
triton.cdiv(slice_z, BLOCK_SIZE_Z)
)

# Launch kernel
add_3d_slices_kernel[grid](
input1, input2, output,
stride_x, stride_y, stride_z,
slice_x, slice_y, slice_z,
BLOCK_SIZE_X=BLOCK_SIZE_X,
BLOCK_SIZE_Y=BLOCK_SIZE_Y,
BLOCK_SIZE_Z=BLOCK_SIZE_Z
)

if __name__ == "__main__":
# Set random seed for reproducibility
torch.manual_seed(0)

# Create example input tensor
input1 = torch.randn(16, 16, 32, device='cpu')
input2 = torch.randn(16, 16, 32, device='cpu')
output = torch.empty_like(input1)

# Call the kernel
add_3d_slices(input1, input2, output)
triton_viz.launch()

# Verify the result
expected_output = input1 + input2
assert torch.allclose(output, expected_output), "Kernel output does not match expected result"





7 changes: 4 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name="triton-viz",
version="0.1",
version="0.2",
packages=find_packages(),
description="A visualization tool for Triton",
author="Deep Learning Profiling Tools Team",
Expand All @@ -11,10 +11,11 @@
install_requires=[
"setuptools",
"triton",
"gradio",
"chalk-diagrams @ git+https://github.com/chalk-diagrams/chalk.git",
"flask",
"pyarrow",
"pre-commit",
"pytest",
"flask_cloudflared",
"requests",
],
)
4 changes: 2 additions & 2 deletions triton_viz/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .trace import trace, dump, sample
from .draw import collect_grid, draw_record
from .draw import collect_grid
from .interface import launch

__all__ = ["trace", "launch", "dump", "sample", "collect_grid", "draw_record"]
__all__ = ["trace", "launch", "dump", "sample", "collect_grid"]
14 changes: 13 additions & 1 deletion triton_viz/data.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from dataclasses import dataclass, field
from typing import List, Tuple, Any
from typing import List, Tuple, Any, Dict
import traceback
import numpy.typing as npt
import numpy as np

import torch


@dataclass
class Op:
Expand Down Expand Up @@ -73,6 +75,15 @@ class Dot(Op):
input_shape: Tuple
other_shape: Tuple
output_shape: Tuple
input_data: List[List[float]]
other_data: List[List[float]]
intermediate_results: Dict[Tuple[int, int], float] = field(
default_factory=dict
) # Only storing the result now

def update_intermediate(self, row: int, col: int, result: float):
# Store only the result as a float
self.intermediate_results[(row, col)] = result


@dataclass
Expand All @@ -91,6 +102,7 @@ class Tensor:
stride: Tuple
shape: Tuple
element_size: int
data: torch.Tensor


@dataclass
Expand Down
Loading