Skip to content

Commit

Permalink
Initial changes for updated triton-viz
Browse files Browse the repository at this point in the history
  • Loading branch information
danikhan632 committed Aug 7, 2024
1 parent 589b115 commit fe9de84
Show file tree
Hide file tree
Showing 15 changed files with 1,465 additions and 527 deletions.
5 changes: 2 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name="triton-viz",
version="0.1",
version="0.2",
packages=find_packages(),
description="A visualization tool for Triton",
author="Deep Learning Profiling Tools Team",
Expand All @@ -11,8 +11,7 @@
install_requires=[
"setuptools",
"triton",
"gradio",
"chalk-diagrams @ git+https://github.com/chalk-diagrams/chalk.git",
"flask",
"pyarrow",
"pre-commit",
"pytest",
Expand Down
4 changes: 2 additions & 2 deletions triton_viz/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .trace import trace, dump, sample
from .draw import collect_grid, draw_record
from .draw import collect_grid
from .interface import launch

__all__ = ["trace", "launch", "dump", "sample", "collect_grid", "draw_record"]
__all__ = ["trace", "launch", "dump", "sample", "collect_grid"]
15 changes: 13 additions & 2 deletions triton_viz/data.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from dataclasses import dataclass, field
from typing import List, Tuple, Any
from typing import List, Tuple, Any, Dict
import traceback
import numpy.typing as npt
import numpy as np


import torch
@dataclass
class Op:
call_path: List[traceback.StackSummary] = field(init=False, default_factory=list)
Expand Down Expand Up @@ -68,11 +68,21 @@ class ExpandDims(Op):
output_shape: Tuple



@dataclass
class Dot(Op):
input_shape: Tuple
other_shape: Tuple
output_shape: Tuple
input_data: List[List[float]]
other_data: List[List[float]]
intermediate_results: Dict[Tuple[int, int], float] = field(default_factory=dict) # Only storing the result now

def update_intermediate(self, row: int, col: int, result: float):
# Store only the result as a float
self.intermediate_results[(row, col)] = result




@dataclass
Expand All @@ -91,6 +101,7 @@ class Tensor:
stride: Tuple
shape: Tuple
element_size: int
data: torch.Tensor


@dataclass
Expand Down
Loading

0 comments on commit fe9de84

Please sign in to comment.