Skip to content

Commit

Permalink
Initial changes for updated triton-viz
Browse files Browse the repository at this point in the history
  • Loading branch information
danikhan632 committed Sep 25, 2024
1 parent 434fa20 commit 025a4f2
Show file tree
Hide file tree
Showing 20 changed files with 2,761 additions and 530 deletions.
99 changes: 99 additions & 0 deletions examples/3dims.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import torch
import triton
import triton.language as tl



import triton_viz
# Tile extents along the x, y and z axes; used both to size the launch grid
# and as the kernel's constexpr block sizes.
BLOCK_SIZE_X, BLOCK_SIZE_Y, BLOCK_SIZE_Z = 32, 8, 4

@triton_viz.trace
@triton.jit
def add_3d_slices_kernel(
    input_ptr1, input_ptr2, output_ptr,
    stride_x, stride_y, stride_z,
    slice_x, slice_y, slice_z,
    BLOCK_SIZE_X: tl.constexpr, BLOCK_SIZE_Y: tl.constexpr, BLOCK_SIZE_Z: tl.constexpr
):
    """Add one (BLOCK_SIZE_Z, BLOCK_SIZE_Y, BLOCK_SIZE_X) tile of two inputs.

    The tile handled by a program is selected by its three program ids. The
    same element offsets, built from stride_x / stride_y / stride_z, are
    applied to input_ptr1, input_ptr2 and output_ptr. Loads and the store are
    masked so that indices at or beyond slice_x / slice_y / slice_z are
    skipped.
    """
    # Tile coordinates of this program along each axis
    tile_x = tl.program_id(0)
    tile_y = tl.program_id(1)
    tile_z = tl.program_id(2)

    # Absolute indices covered by the tile: one 1-D range per axis
    xs = tile_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)
    ys = tile_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)
    zs = tile_z * BLOCK_SIZE_Z + tl.arange(0, BLOCK_SIZE_Z)

    # Per-axis bounds checks, expanded so they broadcast to (Z, Y, X)
    x_ok = xs < slice_x
    y_ok = (ys < slice_y)[:, None]
    z_ok = (zs < slice_z)[:, None, None]
    in_bounds = x_ok & y_ok & z_ok

    # Element offsets shared by both inputs and the output
    z_term = zs[:, None, None] * stride_z
    y_term = ys[:, None] * stride_y
    x_term = xs * stride_x
    elem_offsets = z_term + y_term + x_term

    # Masked loads of the two operands (first input, then second)
    lhs = tl.load(input_ptr1 + elem_offsets, mask=in_bounds)
    rhs = tl.load(input_ptr2 + elem_offsets, mask=in_bounds)

    # Element-wise sum, written back under the same mask
    total = lhs + rhs
    tl.store(output_ptr + elem_offsets, total, mask=in_bounds)

def add_3d_slices(input1, input2, output):
    """Add two 3-D tensors element-wise into ``output`` via the Triton kernel.

    Only ``input1``'s shape and strides are passed to the kernel, which
    applies the same element offsets to all three tensors. ``input2`` and
    ``output`` must therefore match ``input1`` in shape and memory layout;
    this is checked up front instead of silently reading or writing the
    wrong elements.

    Args:
        input1: First operand; a 3-D tensor whose shape is read as (z, y, x).
        input2: Second operand, same shape and strides as ``input1``.
        output: Destination tensor, same shape and strides as ``input1``;
            it is written in place.

    Raises:
        ValueError: If ``input1`` is not 3-D, or if ``input2`` / ``output``
            differ from ``input1`` in shape or in the stride of any
            dimension longer than one element.
    """
    if input1.dim() != 3:
        raise ValueError(
            f"add_3d_slices expects 3-D tensors, got {input1.dim()}-D input1"
        )

    shape = tuple(input1.shape)
    strides = tuple(input1.stride())
    # The stride of a dimension with at most one element never affects
    # addressing, so it is excluded from the layout comparison.
    addressed_dims = [axis for axis, extent in enumerate(shape) if extent > 1]

    for label, tensor in (("input2", input2), ("output", output)):
        if tuple(tensor.shape) != shape:
            raise ValueError(
                f"{label} has shape {tuple(tensor.shape)}, expected {shape}"
            )
        other_strides = tuple(tensor.stride())
        if any(other_strides[axis] != strides[axis] for axis in addressed_dims):
            raise ValueError(
                f"{label} has strides {other_strides}, expected {strides}"
            )

    slice_z, slice_y, slice_x = shape
    stride_z, stride_y, stride_x = strides

    # One program per tile; cdiv rounds up so partial tiles are covered
    grid = (
        triton.cdiv(slice_x, BLOCK_SIZE_X),
        triton.cdiv(slice_y, BLOCK_SIZE_Y),
        triton.cdiv(slice_z, BLOCK_SIZE_Z),
    )

    add_3d_slices_kernel[grid](
        input1, input2, output,
        stride_x, stride_y, stride_z,
        slice_x, slice_y, slice_z,
        BLOCK_SIZE_X=BLOCK_SIZE_X,
        BLOCK_SIZE_Y=BLOCK_SIZE_Y,
        BLOCK_SIZE_Z=BLOCK_SIZE_Z,
    )

if __name__ == "__main__":
    # Fixed seed so the two random operands are the same on every run
    torch.manual_seed(0)

    # Two example operands on the CPU and a result buffer shaped like them
    input1 = torch.randn((16, 16, 32), device='cpu')
    input2 = torch.randn((16, 16, 32), device='cpu')
    output = torch.empty_like(input1)

    # Run the traced kernel, then start the triton-viz interface
    add_3d_slices(input1, input2, output)
    # NOTE(review): the check below runs only after launch() returns; if
    # launch() blocks while serving, verification is delayed until then.
    triton_viz.launch()

    # Compare the kernel result against a plain PyTorch addition
    expected_output = torch.add(input1, input2)
    assert torch.allclose(output, expected_output), "Kernel output does not match expected result"





7 changes: 4 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name="triton-viz",
version="0.1",
version="0.2",
packages=find_packages(),
description="A visualization tool for Triton",
author="Deep Learning Profiling Tools Team",
Expand All @@ -11,10 +11,11 @@
install_requires=[
"setuptools",
"triton",
"gradio",
"chalk-diagrams @ git+https://github.com/chalk-diagrams/chalk.git",
"flask",
"pyarrow",
"pre-commit",
"pytest",
"flask_cloudflared",
"requests",
],
)
4 changes: 2 additions & 2 deletions triton_viz/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .trace import trace, dump, sample
from .draw import collect_grid, draw_record
from .draw import collect_grid
from .interface import launch

__all__ = ["trace", "launch", "dump", "sample", "collect_grid", "draw_record"]
__all__ = ["trace", "launch", "dump", "sample", "collect_grid"]
14 changes: 13 additions & 1 deletion triton_viz/data.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from dataclasses import dataclass, field
from typing import List, Tuple, Any
from typing import List, Tuple, Any, Dict
import traceback
import numpy.typing as npt
import numpy as np

import torch


@dataclass
class Op:
Expand Down Expand Up @@ -73,6 +75,15 @@ class Dot(Op):
input_shape: Tuple
other_shape: Tuple
output_shape: Tuple
input_data: List[List[float]]
other_data: List[List[float]]
intermediate_results: Dict[Tuple[int, int], float] = field(
default_factory=dict
) # Only storing the result now

def update_intermediate(self, row: int, col: int, result: float):
    """Record ``result`` under the ``(row, col)`` key of ``intermediate_results``.

    The value is stored as given (no conversion); an existing entry for the
    same key is overwritten.
    """
    key = (row, col)
    self.intermediate_results[key] = result


@dataclass
Expand All @@ -91,6 +102,7 @@ class Tensor:
stride: Tuple
shape: Tuple
element_size: int
data: torch.Tensor


@dataclass
Expand Down
Loading

0 comments on commit 025a4f2

Please sign in to comment.