From fe9de84b1c5d75c53e9334bf252bfbffc8021d7a Mon Sep 17 00:00:00 2001 From: Daniyal Khan Date: Wed, 7 Aug 2024 09:36:37 -0400 Subject: [PATCH] Initial changes for updated triton-viz --- setup.py | 5 +- triton_viz/__init__.py | 4 +- triton_viz/data.py | 15 +- triton_viz/draw.py | 539 +++++++---------------------- triton_viz/interface.py | 120 ++----- triton_viz/interpreter.py | 55 ++- triton_viz/static/gridblock.js | 288 +++++++++++++++ triton_viz/static/load.js | 172 +++++++++ triton_viz/static/matmul.js | 209 +++++++++++ triton_viz/static/store.js | 171 +++++++++ triton_viz/static/three.js | 0 triton_viz/static/visualization.js | 357 +++++++++++++++++++ triton_viz/templates/index.html | 47 +++ triton_viz/tooltip.py | 5 +- triton_viz/trace.py | 5 + 15 files changed, 1465 insertions(+), 527 deletions(-) create mode 100644 triton_viz/static/gridblock.js create mode 100644 triton_viz/static/load.js create mode 100644 triton_viz/static/matmul.js create mode 100644 triton_viz/static/store.js create mode 100644 triton_viz/static/three.js create mode 100644 triton_viz/static/visualization.js create mode 100644 triton_viz/templates/index.html diff --git a/setup.py b/setup.py index 6898fed..264defc 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="triton-viz", - version="0.1", + version="0.2", packages=find_packages(), description="A visualization tool for Triton", author="Deep Learning Profiling Tools Team", @@ -11,8 +11,7 @@ install_requires=[ "setuptools", "triton", - "gradio", - "chalk-diagrams @ git+https://github.com/chalk-diagrams/chalk.git", + "flask", "pyarrow", "pre-commit", "pytest", diff --git a/triton_viz/__init__.py b/triton_viz/__init__.py index 737dd36..7c5bab8 100644 --- a/triton_viz/__init__.py +++ b/triton_viz/__init__.py @@ -1,5 +1,5 @@ from .trace import trace, dump, sample -from .draw import collect_grid, draw_record +from .draw import collect_grid from .interface import launch -__all__ = ["trace", "launch", "dump", "sample", "collect_grid", "draw_record"] +__all__ = ["trace", "launch", "dump", "sample", "collect_grid"] diff --git a/triton_viz/data.py b/triton_viz/data.py index 819a398..1d411bb 100644 --- a/triton_viz/data.py +++ b/triton_viz/data.py @@ -1,10 +1,10 @@ from dataclasses import dataclass, field -from typing import List, Tuple, Any +from typing import List, Tuple, Any, Dict import traceback import numpy.typing as npt import numpy as np - +import torch @dataclass class Op: call_path: List[traceback.StackSummary] = field(init=False, default_factory=list) @@ -68,11 +68,21 @@ class ExpandDims(Op): output_shape: Tuple + @dataclass class Dot(Op): input_shape: Tuple other_shape: Tuple output_shape: Tuple + input_data: List[List[float]] + other_data: List[List[float]] + intermediate_results: Dict[Tuple[int, int], float] = field(default_factory=dict) # Only storing the result now + + def update_intermediate(self, row: int, col: int, result: float): + # Store only the result as a float + self.intermediate_results[(row, col)] = result + + @dataclass @@ -91,6 +101,7 @@ class Tensor: stride: Tuple shape: Tuple element_size: int + data: torch.Tensor @dataclass diff --git a/triton_viz/draw.py b/triton_viz/draw.py index fda62c0..9817d8d 100644 --- a/triton_viz/draw.py +++ b/triton_viz/draw.py @@ -1,77 +1,60 @@ -from colour import Color from triton_viz.data import ( Tensor, Grid, Store, Load, - Op, - MakeRange, - Reduce, - Dot, + Dot ) +from .trace import Trace from .interpreter import record_builder import numpy as np -import numpy.typing as npt -import planar -import math -import chalk -from typing import Tuple, Union, Optional, List, Dict -from chalk import Diagram, rectangle, text, hcat, vcat, empty, Path, Trail, V2, concat -from dataclasses import dataclass -from numpy.typing import ArrayLike -import sys +import torch +from typing import Dict, Tuple, List +import ctypes -sys.setrecursionlimit(100000) +def get_tensor_slice_coordinates(x: Load, tensor_table: Dict[int, Tuple[Tensor, int]]) -> List[Tuple[int, int]]: + + if x.ptr not in tensor_table: + raise KeyError(f"Tensor with ptr {x.ptr} not found in tensor_table") -planar.EPSILON = 0.0 -chalk.set_svg_draw_height(500) -BG = Color("white") -WHITE = Color("white") -DEFAULT = Color("grey") -BLACK = Color("black") -GREY = Color("grey") -palette = [ - "#f29f05", - "#f25c05", - "#d6568c", - "#4d8584", - "#a62f03", - "#400d01", - "#274001", - "#828a00", -] -ACTIVE = [Color(p) for p in palette] + tensor, _ = tensor_table[x.ptr] + global_shape = tensor.shape -MRATIO = 1 / 3 + if len(global_shape) != 2 or len(x.shape) != 2: + raise ValueError("This function only supports 2D tensors") -# Generic render helpers + # Extract the row and column offsets + row_offsets = x.offsets[:, 0] + col_offsets = x.offsets[0, :] + # Find the start and end coordinates + start_y, end_y = np.min(row_offsets), np.max(row_offsets) + x.shape[0] + start_x, end_x = np.min(col_offsets), np.max(col_offsets) + x.shape[1] -def box(d: Diagram, width: float, height: float, outer=0.2) -> Diagram: - "Put diagram in a box of shape height, width" - h, w = d.get_envelope().height, d.get_envelope().width - m = max(h, w) - back = rectangle(outer + m, outer + m).line_width(0).fill_color(BG).center_xy() - d = (back + d.center_xy()).with_envelope(back) - return d.scale_x(width / m).scale_y(height / m) -def reshape(d: Diagram) -> Diagram: - "Use log-scale if ratio is too sharp" - h, w = d.get_envelope().height, d.get_envelope().width - if (h / w > MRATIO) or (w / h > MRATIO): - d = d.scale_y(math.log(h + 1, 2) / h).scale_x(math.log(w + 1, 2) / w) - return d + # Ensure coordinates are within the global tensor bounds + start_y = max(0, min(start_y, global_shape[0] - 1)) + start_x = max(0, min(start_x, global_shape[1] - 1)) + end_y = max(start_y + 1, min(end_y, global_shape[0])) + end_x = max(start_x + 1, min(end_x, global_shape[1])) + return [ + (int(start_y), int(start_x)), + (int(start_y), int(end_x) - 1), + (int(end_y) - 1, int(start_x)), + (int(end_y) - 1, int(end_x) - 1) + ] + + def collect_grid(): for launch in record_builder.launches[-1:]: records, tensor_table, failures = collect_launch(launch) return records, tensor_table, failures - def collect_launch(launch): tensor_table = {} for i, t in enumerate(launch.tensors): @@ -96,375 +79,89 @@ def collect_launch(launch): return all_grids, tensor_table, failures -def draw_record(program_record, tensor_table, output): - return draw_launch(program_record, tensor_table, output) - - -def draw_launch(program_records, tensor_table, base) -> Diagram: - def draw(x): - "Dispatch" - if isinstance(x, Tensor): - return draw_tensor(x) - if isinstance(x, Grid): - return draw_grid(x) - if isinstance(x, Store): - return draw_store(x, tensor_table) - if isinstance(x, Load): - return draw_load(x, tensor_table) - if isinstance(x, Op): - return draw_op(x) - if isinstance(x, MakeRange): - return draw_make_range(x) - if isinstance(x, Reduce): - return draw_reduce(x) - if isinstance(x, Dot): - return None # draw_dot(x) - - def draw_record(x): - "Render one record" - y = draw(x) - if y is None: - return empty() - - return (chalk.vstrut(0.2) / y).center_xy() - - records = [] - for r in program_records: - dr = draw_record(r) - # env = dr.get_envelope() - # dr = dr.center_xy().with_envelope(rectangle(env.width, env.height).center_xy()) - records.append(dr) - - dr = vcat(records) - dr = dr.center_xy() - env = dr.get_envelope() - dr = rectangle(env.width + 1, env.height + 1).fill_color(BG).center_xy() + dr - dr.render_svg(base, 2500) - return env.width, env.height - - -def delinearize(shape: Tuple, x: npt.NDArray, dtype, mask) -> List[npt.NDArray]: - if len(shape) == 1: - shape = (1, 1, shape[0]) - x = x.copy() // (dtype.element_ty.primitive_bitwidth // 8) - vals = [] - for s in list(reversed(shape[1:])) + [10000]: - vals.append(((x % s) * mask - (1 - mask)).ravel()) - x = x // s - return vals - - -trail = Trail.from_offsets([V2(0, 1), V2(1, 0), V2(0, -1), V2(-1, 0)], closed=True) - - -def cover( - shape: Tuple, dtype, load: Tensor, mask: npt.NDArray, color: Color -) -> Diagram: - shape = make_3d(shape) - "Draw the values from load on top of the loading tensor" - x, y, z = delinearize(shape, load, dtype, mask) - return draw_tensor_3d(shape, z, y, x, color) - - -def pair_draw(x: Diagram, y: Diagram, command: str) -> Diagram: - "Draw two diagrams next to each other with a command in the middle." - return hcat([box(x, 3, 2.5), box(y, 3, 2.5)], 1).center_xy() + text( - command, 0.2 - ).fill_color(BLACK).line_width(0).translate(0, -1) - - -# Individual renderers - - -def draw_tensor(x: Tensor) -> Optional[Diagram]: - return None - - -def draw_grid(x: Grid) -> Optional[Diagram]: - return None - - -def draw_make_range(x: MakeRange) -> Optional[Diagram]: - return None - - -def draw_reduce(x: Reduce) -> Optional[Diagram]: - color = ACTIVE[0] - inp = draw_tensor_3d(make_3d(x.input_shape), None, None, None, color) - if x.index == 0 and len(x.input_shape) == 2: - inp = hcat( - [ - rectangle(0.1, inp.get_envelope().height) - .align_t() - .line_width(0) - .fill_color(BLACK), - inp, - ], - 0.5, - ) - else: - inp = vcat( - [ - rectangle(inp.get_envelope().width, 0.1) - .align_l() - .line_width(0) - .fill_color(BLACK), - inp, - ], - 0.5, - ) - out = draw_tensor_3d(x.output_shape, None, None, None, color) - return pair_draw(reshape(inp), reshape(out), x.op) - - -def draw_load(x, tensor_table) -> Optional[Diagram]: - inp, out = store_load(x, tensor_table) - out = reshape(out) - return pair_draw(inp, out, "load") - - -def draw_store(x, tensor_table) -> Optional[Diagram]: - inp, out = store_load(x, tensor_table) - out = reshape(out) - return pair_draw(out, inp, "store") - - -def make_3d(shape): - "Make a 3d shape" - if len(shape) == 1: - return (1, 1, shape[0]) - if len(shape) == 2: - return (1, shape[0], shape[1]) - return shape - - -def store_load( - x: Union[Store, Load], tensor_table: Dict[int, Tuple[Tensor, int]] -) -> Tuple[Diagram, Diagram]: - tensor, tensor_id = tensor_table[x.ptr] - # inp = base_tensor(tensor.shape, DEFAULT) - color = ACTIVE[tensor_id] - invalid = x.invalid_access_masks.any() - if invalid: - color = Color("red") - inp = cover(tensor.shape, tensor.dtype, x.original_offsets, x.original_masks, color) - inp = reshape(inp) - s = make_3d(x.original_offsets.shape) - a, b, c = x.original_masks.reshape(*s).nonzero() - out = draw_tensor_3d(s, a, b, c, color) - return inp, out - - -def draw_op(x: Op) -> Optional[Diagram]: - return None - - -def draw_dot(x: Dot) -> Optional[Diagram]: - if x.input_shape == (1,): - return None - inp = draw_tensor_3d(x.input_shape[0], None, None, None) - # inp = reshape(base_tensor(x.input_shape[0], color=ACTIVE)) - # inp = add_whiskers(inp, x.input_shape[0]) - inp2 = draw_tensor_3d(x.input_shape[1], None, None, None) - # inp2 = reshape(base_tensor(x.input_shape[0], color=ACTIVE)) - # inp2 = add_whiskers(inp2, x.input_shape[0]) - out = draw_tensor_3d(x.output_shape, None, None, None) - # out = reshape(base_tensor(x.output_shape, color=ACTIVE)) - # out = add_whiskers(out, x.output_shape) - return hcat( - [box(inp, 1.5, 2), box(inp2, 1.5, 2), box(out, 1.5, 2)], 1 - ).center_xy() + text("dot", 0.2).fill_color(BLACK).line_width(0).translate(0, -1) - - -# For 3d - - -def lookAt(eye: ArrayLike, center: ArrayLike, up: ArrayLike): - "Python version of the haskell lookAt function in linear.projections" - f = (center - eye) / np.linalg.norm(center - eye) - s = np.cross(f, up) / np.linalg.norm(np.cross(f, up)) - u = np.cross(s, f) - return np.array([[*s, 0], [*u, 0], [*-f, 0], [0, 0, 0, 1]]) - - -def scale3(x, y, z): - return np.array([[x, 0, 0, 0], [0, y, 0, 0], [0, 0, z, 0], [0, 0, 0, 1]]) - - -@dataclass -class D3: - x: float - y: float - z: float - - def to_np(self): - return np.array([self.x, self.y, self.z]) - - -V3 = D3 - - -def homogeneous(trails: List[List[D3]]): - "Convert list of directions to a np.array of homogeneous coordinates" - return np.array([[[*o.to_np(), 1] for o in offsets] for offsets in trails]) - - -def cube(): - "3 faces of a cube drawn as offsets from the origin." - return homogeneous( - [ - [D3(*v) for v in offset] - for offset in [ - [(1, 0, 0), (0, 1, 0), (-1, 0, 0), (0, -1, 0)], - [(1, 0, 0), (0, 0, 1), (-1, 0, 0), (0, 0, -1)], - [(0, 0, 1), (0, 1, 0), (0, 0, -1), (0, -1, 0)], - ] - ] - ) - - -def to_trail(trail: ArrayLike, locations: ArrayLike): - return [ - ( - Path( - [ - Trail.from_offsets([V2(*v[:2]) for v in trail]) - .close() - .at(V2(*loc[:2])) - ] - ), - loc[2], - ) - for loc in locations - ] - - -def project(projection, shape3, positions): - p = homogeneous([positions for _ in range(shape3.shape[0])]) - locations = p @ projection.T - trails = shape3 @ projection.T - return [out for t, loc in zip(trails, locations) for out in to_trail(t, loc)] - - -def draw_tensor_3d(shape, a, b, c, color=WHITE): - shape = make_3d(shape) - - # Big Cube - s = scale3(*shape) - big_cube = cube() @ s.T - back = scale3(0, shape[1], shape[2]) - back_cube = cube() @ back.T - - # Isometric projection of tensor - projection = lookAt( - V3(-1.0, -0.3, -0.15).to_np(), - V3(0, 0, 0).to_np(), - V3(0, 1, 0).to_np(), - ) - outer = project(projection, big_cube, [V3(0, 0, 0)]) - outer2 = project(projection, back_cube, [V3(shape[0], 0, 0)]) - d = ( - concat([p.stroke().fill_color(GREY).fill_opacity(0.1) for p, _ in outer2]) - .line_width(0.005) - .line_color(GREY) - ) - d += ( - concat([p.stroke().fill_color(GREY).fill_opacity(0.05) for p, _ in outer]) - .line_width(0.01) - .line_color(BLACK) - ) - if a is not None: - out = group(a, b, c) - d2 = [ - (b, loc) - for i in range(len(out)) - for b, loc in make_cube(projection, out[i][0], out[i][1], color) - ] - d2.sort(key=lambda x: x[1], reverse=True) - d2 = concat([b.with_envelope(empty()) for b, _ in d2]) - d = d2.with_envelope(d) + d - return d - - -def lines(s): - "Draw lines to mimic a cube of cubes" - bs = [np.array([1, 0, 0]), np.array([0, 1, 0]), np.array([0, 0, 1])] - return [ - homogeneous( - [ - [D3(*p) for _ in range(s[i]) for p in [a / s[i], b, -b]] - for b in bs - if not np.all(a == b) - ] - ) - for i, a in enumerate(bs) - ] - - -def make_cube(projection, start, end, color): - "Draws a cube from start position to end position." - start = np.array(start).astype(int) - end = np.array(end).astype(int) - s2 = end - start + 1 - s = scale3(*s2) - small_cube = cube() @ s.T - loc = [ - project(projection, l2 @ s.T, [V3(*start)]) - for l2 in lines(s2) - if l2.shape[1] > 0 - ] - outer2 = project(projection, small_cube, [V3(*start)]) - ls = loc - box = [ - (p.stroke().fill_color(color).fill_opacity(0.4).line_width(0), loc) - for p, loc in outer2 - ] - line = [ - (p.stroke().line_width(0.001).line_color(BLACK), l_) - for loc in ls - for p, l_ in loc - ] - return [(b, loc) for b, loc in box + line] - - -def group( - x: ArrayLike, y: ArrayLike, z: ArrayLike -) -> List[Tuple[Tuple[float, float, float], Tuple[float, float, float]]]: - "Groups together cubes into bigger cubes" - x = list(zip(zip(x, y, z), zip(x, y, z))) - x = [(a, b) for a, b in x if not (a[0] == -1 and a[1] == -1 and a[2] == -1)] - - start = x - - def remove_dups(ls): - "Remove duplicates" - out = [] - for y in ls: - if not out or y != out[-1]: - out.append(y) - return out - - for j in range(2, -1, -1): - x = remove_dups(start) - start = [] - while True: - if len(x) <= 1: - break - _, _, rest = x[0], x[1], x[2:] - m = 0 - for k in range(2): - a = x[0][k] - b = x[1][k] - if ( - (k == 0 or a[j % 3] == b[j % 3] - 1) - and a[(j + 1) % 3] == b[(j + 1) % 3] - and a[(j + 2) % 3] == b[(j + 2) % 3] - ): - m += 1 - if m == 2: - x = [[x[0][0], x[1][1]]] + rest - else: - start.append(x[0]) - x = [x[1]] + rest - start += x - return start +def get_tensor_data(tensor: Tensor): + # Create a NumPy array from the memory address + np_array = np.frombuffer( + (ctypes.c_float * np.prod(tensor.shape)).from_address(tensor.ptr), + dtype=np.float32 + ).reshape(tensor.shape) + + # Convert NumPy array to PyTorch tensor + return torch.from_numpy(np_array) + + + + +def prepare_visualization_data(program_records, tensor_table, getValues=False): + """Prepare visualization data for the frontend.""" + visualization_data = [] + + for record in program_records: + if isinstance(record, Dot): + intermediate_results = {} + if getValues: + intermediate_results = { + f"{row},{col}": { + 'result': result + } + for (row, col), result in record.intermediate_results.items() + } + + visualization_data.append({ + 'type': 'Dot', + 'input_shape': record.input_shape, + 'input_data': record.input_data, + 'other_data': record.other_data, + 'other_shape': record.other_shape, + 'output_shape': record.output_shape, + 'intermediate_results': intermediate_results + }) + + elif isinstance(record, (Load, Store)): + global_tensor, _ = tensor_table[record.ptr] + + slice_coords = get_tensor_slice_coordinates(record, tensor_table) + slice_tensor = torch.zeros(record.shape) + + global_values = [] + if getValues: + global_values = global_tensor.data.tolist() + + visualization_data.append({ + 'type': 'Load' if isinstance(record, Load) else 'Store', + 'global_values': global_values, + 'global_shape': global_tensor.shape, + 'slice_shape': slice_tensor.shape, + 'slice_coords': slice_coords + }) + + return visualization_data + +def get_visualization_data(getValues=False): + """Return the visualization data in a format suitable for JSON serialization.""" + records, tensor_table, failures = collect_grid() + visualization_data = {} + for grid_idx, program_records in records.items(): + visualization_data[str(grid_idx)] = prepare_visualization_data(program_records, tensor_table) + + # Get the kernel source code + kernel_src = "" + if record_builder.launches and isinstance(record_builder.launches[0], Trace): + kernel_src = record_builder.launches[0].get_src() + + return { + "visualization_data": visualization_data, + "failures": failures, + "kernel_src": kernel_src + } + +def serialize_for_json(obj): + if isinstance(obj, np.ndarray): + return obj.tolist() + elif isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + elif isinstance(obj, torch.Tensor): + return obj.cpu().numpy().tolist() + raise TypeError(f'Object of type {obj.__class__.__name__} is not JSON serializable') \ No newline at end of file diff --git a/triton_viz/interface.py b/triton_viz/interface.py index 472f01d..36c9183 100644 --- a/triton_viz/interface.py +++ b/triton_viz/interface.py @@ -1,99 +1,41 @@ -import gradio as gr -import triton_viz -import tempfile +import threading +import webbrowser +import json +from flask import Flask, render_template, jsonify from .analysis import analyze_records -from .tooltip import create_tooltip +from .draw import collect_grid,get_visualization_data +from .tooltip import get_tooltip_data import pandas as pd +import os -def launch(share=True): - cache = {} - analysis_data = analyze_records() - program_records, tt, failures = triton_viz.collect_grid() - m = [0, 0, 0] - size = [0, 0] - for k in program_records.keys(): - m[0] = max(k[0] + 1, m[0]) - m[1] = max(k[1] + 1, m[1]) - m[2] = max(k[2] + 1, m[2]) - w, h = triton_viz.draw_record(program_records[(0, 0, 0)], tt, "tmp.svg") - size[0] = w - size[1] = h - height = 600 * size[1] / size[0] - with gr.Blocks( - css=".gradio-container button {overflow: auto} img.with-caption {height: %fpx !important; } .thumbnails { display: none; } " - % height - ) as demo: - with gr.Row(): - with gr.Column(scale=3, min_width=500): - img = gr.Gallery( - height=500, - min_width=500, - show_label=False, - selected_index=0, - preview=True, - object_fit="cover", - ) - with gr.Column(scale=1): - s1 = gr.Slider(0, m[0] - 1, value=0, step=1, label="Program Id 0") - s2 = gr.Slider(0, m[1] - 1, value=0, step=1, label="Program Id 1") - s3 = gr.Slider(0, m[2] - 1, value=0, step=1, label="Program Id 2") - b1 = gr.Button("Precompute") - gr.Markdown("## Analysis") - df = pd.DataFrame(analysis_data, columns=["Metric", "Value"]) - analysis_with_tooltip = create_tooltip(df) - gr.HTML(analysis_with_tooltip) - if failures: - gr.Markdown( - show_label=False, - value="## Invalid memory access in " - + "\n * " - + "\n* ".join(list(map(str, failures.keys()))), - ) +app = Flask(__name__, + template_folder=os.path.join(os.path.dirname(__file__), 'templates'), + static_folder=os.path.join(os.path.dirname(__file__), 'static')) - def cache_block(idx): - name = tempfile.NamedTemporaryFile(suffix=".svg") - w, h = triton_viz.draw_record(program_records[idx], tt, name.name) - size[0] = w - size[1] = h - cache[idx] = (name, len(cache)) +@app.route('/') +def index(): + return render_template('index.html') - def update(inp): - a = inp[s1] - b = inp[s2] - c = inp[s3] - idx = (a, b, c) - if idx not in cache: - cache_block(idx) - return gr.Gallery( - value=[(cache[k][0].name, str(k)) for k in cache.keys()], - selected_index=cache[idx][1], - height=700, - ), gr.Slider() - # * size[1]/size[0] - return gr.Gallery(selected_index=cache[idx][1]), gr.Slider() +@app.route('/api/data') +def get_data(): + analysis_data = analyze_records() - def precompute(inp): - a = inp[s1] - b = inp[s2] - c = inp[s3] - idx = (a, b, c) - for i in range(m[0]): - for j in range(m[1]): - for k in range(m[2]): - if (i, j, k) not in cache: - cache_block((i, j, k)) - return gr.Gallery( - value=[(cache[k][0].name, str(k)) for k in cache.keys()], - selected_index=cache[idx][1], - ) + df = pd.DataFrame(analysis_data, columns=["Metric", "Value"]) + analysis_with_tooltip = get_tooltip_data(df) - s1.change(update, inputs={s1, s2, s3}, outputs=[img, b1], show_progress=False) - s2.change(update, inputs={s1, s2, s3}, outputs=[img, b1], show_progress=False) - s3.change(update, inputs={s1, s2, s3}, outputs=[img, b1], show_progress=False) - b1.click(precompute, inputs={s1, s2, s3}, outputs=img, show_progress=True) - demo.load(update, inputs={s1, s2, s3}, outputs=[img, b1]) + return jsonify({"ops":get_visualization_data()}) - demo.launch(share=share, debug=False, height=800, quiet=True, show_api=False) - return failures +def run_flask(): + app.run(port=5000,host='127.0.0.1') + +def launch(share=True): + print("Launching Triton viz tool") + flask_thread = threading.Thread(target=run_flask) + flask_thread.start() + return flask_thread + +def stop_server(flask_thread): + # Implement a way to stop the Flask server + pass \ No newline at end of file diff --git a/triton_viz/interpreter.py b/triton_viz/interpreter.py index f447bbf..aea1609 100644 --- a/triton_viz/interpreter.py +++ b/triton_viz/interpreter.py @@ -44,6 +44,20 @@ def _unpatch_lang(): class RecordBuilder: + + def to_dict(self): + """Convert the recorded data to a dictionary format for JSON serialization.""" + return { + "launches": [ + { + "grid": launch.grid, + "tensors": [tensor.__dict__ for tensor in launch.tensors], + "records": [record.__dict__ for record in launch.records] + } + for launch in self._launches + ] + } + def __init__(self) -> None: self.reset() @@ -152,6 +166,7 @@ def _grid_executor_call(self, *args_dev, **kwargs): else: ret = _implicit_cvt(arg) if hasattr(arg, "data_ptr"): + assert _check_storage_contiguous( arg ), "triton-viz only supports contiguouly stored tensors for now" @@ -162,6 +177,7 @@ def _grid_executor_call(self, *args_dev, **kwargs): arg.stride(), arg.shape, arg.element_size(), + arg ) ) call_args[name] = ret @@ -291,18 +307,33 @@ def wrapper(lhs, rhs, op): def _create_dot(fn): @wraps(fn) - def wrapper(a, b, d, allow_tf32, maxNumImpreciseAcc): - ret = fn(a, b, d, allow_tf32, maxNumImpreciseAcc) + def wrapper(a, b, c, allow_tf32, max_num_imprecise_acc): dot_record = Dot( - input_shape=(a.data.shape, b.data.shape), - other_shape=d.data.shape, - output_shape=ret.data.shape, + input_shape=a.data.shape, + other_shape=b.data.shape, + output_shape=c.data.shape, + input_data=a.data.tolist(), + other_data=b.data.tolist() ) record_builder.add_record(dot_record) - return ret - return wrapper + def capture_intermediate(row, col, result): + dot_record.update_intermediate(row, col, float(result)) + # Modify the original function to call capture_intermediate at each step + def modified_fn(a, b, c, allow_tf32, max_num_imprecise_acc): + for i in range(a.data.shape[0]): + for j in range(b.data.shape[1]): + A_row = a.data[i, :] + B_column = b.data[:, j] + result = np.dot(A_row, B_column) + capture_intermediate(i, j, result) + c.data[i, j] = result + + modified_fn(a, b, c, allow_tf32, max_num_imprecise_acc) + return c + + return wrapper def _create_expand_dims(fn): @wraps(fn) @@ -344,7 +375,7 @@ def patch(): old_grid_executor_call = GridExecutor.__call__ old_jit_function_call = JITFunction.__call__ # XXX(Keren): Temporarily disable rewriting of AST - old_rewrite_ast = InterpretedFunction._rewrite_ast + # old_rewrite_ast = InterpretedFunction._rewrite_ast old_create_make_range = interpreter_builder.create_make_range old_create_masked_load = interpreter_builder.create_masked_load old_create_expand_dims = interpreter_builder.create_expand_dims @@ -373,10 +404,16 @@ def patch(): finally: GridExecutor.__call__ = old_grid_executor_call JITFunction.__call__ = old_jit_function_call - InterpretedFunction._rewrite_ast = old_rewrite_ast + # InterpretedFunction._rewrite_ast = old_rewrite_ast interpreter_builder.create_make_range = old_create_make_range interpreter_builder.create_masked_load = old_create_masked_load interpreter_builder.create_expand_dims = old_create_expand_dims interpreter_builder.binary_op = old_binary_op interpreter_builder.create_dot = old_create_dot interpreter_builder.create_masked_store = old_create_masked_store + + + +def get_recorded_data(): + """Return the recorded data in a format suitable for JSON serialization.""" + return record_builder.to_dict() \ No newline at end of file diff --git a/triton_viz/static/gridblock.js b/triton_viz/static/gridblock.js new file mode 100644 index 0000000..8d00fc9 --- /dev/null +++ b/triton_viz/static/gridblock.js @@ -0,0 +1,288 @@ +import {createMatMulVisualization} from './matmul.js' +import {createLoadVisualization} from './load.js'; +import {createStoreVisualization} from './store.js'; + +export class GridBlock { + constructor(x, y, width, height, gridX, gridY, gridZ, blockData) { + this.rect = { x, y, width, height }; + this.gridPosition = { x: gridX, y: gridY, z: gridZ }; + this.blockData = blockData; + this.isHovered = false; + this.visualizationContainer = null; + this.cleanupFunction = null; + + // Three.js properties + this.scene = null; + this.camera = null; + this.renderer = null; + this.matMulVisualization = null; + + // 2D canvas for additional information + this.canvas2D = null; + this.ctx2D = null; + + // Animation frame ID + this.animationFrameId = null; + } + + draw(ctx) { + // Draw background + ctx.fillStyle = this.isHovered ? '#4a4a4a' : '#323232'; + ctx.fillRect(this.rect.x, this.rect.y, this.rect.width, this.rect.height); + + // Draw border + ctx.strokeStyle = '#000000'; + ctx.lineWidth = 1; + ctx.strokeRect(this.rect.x, this.rect.y, this.rect.width, this.rect.height); + + // Draw grid position + ctx.fillStyle = '#c8c8c8'; + ctx.font = '12px Arial'; + ctx.textAlign = 'center'; + ctx.textBaseline = 'top'; + const posText = `${this.gridPosition.x},${this.gridPosition.y},${this.gridPosition.z}`; + ctx.fillText(posText, this.rect.x + this.rect.width / 2, this.rect.y + 2); + + // Draw operation types + ctx.font = '10px Arial'; + ctx.textAlign = 'left'; + ctx.textBaseline = 'top'; + this.blockData.forEach((op, index) => { + ctx.fillText(op.type, this.rect.x + 5, this.rect.y + 20 + index * 15); + }); + } + + isPointInside(x, y) { + return x >= this.rect.x && x <= this.rect.x + this.rect.width && + y >= this.rect.y && y <= this.rect.y + this.rect.height; + } + + handleMouseMove(x, y) { + const wasHovered = this.isHovered; + this.isHovered = this.isPointInside(x, y); + return wasHovered !== this.isHovered; // Return true if hover state changed + } + + showDetailedView(containerElement) { + if (this.isDetailedViewVisible) return; + + this.visualizationContainer = containerElement; + this.visualizationContainer.innerHTML = ''; + this.visualizationContainer.style.pointerEvents = 'auto'; + + const wrapper = document.createElement('div'); + wrapper.style.padding = '10px'; + wrapper.style.backgroundColor = '#1e1e28'; + wrapper.style.color = '#ffffff'; + wrapper.style.height = '100%'; + wrapper.style.overflowY = 'auto'; + wrapper.style.display = 'flex'; + wrapper.style.flexDirection = 'column'; + wrapper.style.alignItems = 'stretch'; + wrapper.style.width = '100%'; + + const title = document.createElement('h2'); + title.textContent = `Operations at (${this.gridPosition.x}, ${this.gridPosition.y}, ${this.gridPosition.z})`; + title.style.textAlign = 'center'; + title.style.margin = '10px 0'; + wrapper.appendChild(title); + + const columnsContainer = document.createElement('div'); + columnsContainer.style.display = 'flex'; + columnsContainer.style.justifyContent = 'space-between'; + columnsContainer.style.gap = '10px'; + columnsContainer.style.flex = '1'; + + const leftColumn = document.createElement('div'); + leftColumn.style.flex = '1'; + leftColumn.style.display = 'flex'; + leftColumn.style.flexDirection = 'column'; + leftColumn.style.gap = '10px'; + + const rightColumn = document.createElement('div'); + rightColumn.style.flex = '1'; + rightColumn.style.display = 'flex'; + rightColumn.style.flexDirection = 'column'; + rightColumn.style.gap = '10px'; + + this.blockData.forEach((op, index) => { + const opContainer = document.createElement('div'); + opContainer.style.padding = '10px'; + opContainer.style.border = '1px solid #444'; + opContainer.style.borderRadius = '5px'; + opContainer.style.display = 'flex'; + opContainer.style.flexDirection = 'column'; + opContainer.style.flex = '1'; + + const opTitle = document.createElement('h3'); + opTitle.textContent = `Operation ${index + 1}: ${op.type}`; + opTitle.style.margin = '0 0 10px 0'; + opContainer.appendChild(opTitle); + + const visualizationDiv = document.createElement('div'); + visualizationDiv.style.flex = '1'; + visualizationDiv.style.minHeight = '300px'; + opContainer.appendChild(visualizationDiv); + + switch(op.type) { + case 'Dot': + createMatMulVisualization(visualizationDiv, op); + break; + case 'Load': + createLoadVisualization(visualizationDiv, op); + break; + case 'Store': + createStoreVisualization(visualizationDiv, op); + break; + default: + const unsupportedMsg = document.createElement('p'); + unsupportedMsg.textContent = `Unsupported operation type: ${op.type}`; + visualizationDiv.appendChild(unsupportedMsg); + } + + if (index % 2 === 0) { + leftColumn.appendChild(opContainer); + } else { + rightColumn.appendChild(opContainer); + } + }); + + columnsContainer.appendChild(leftColumn); + columnsContainer.appendChild(rightColumn); + wrapper.appendChild(columnsContainer); + + if (this.blockData.length === 0) { + const noDataMsg = document.createElement('p'); + noDataMsg.textContent = 'No operation data available'; + noDataMsg.style.textAlign = 'center'; + wrapper.appendChild(noDataMsg); + } + + const closeButton = document.createElement('button'); + closeButton.textContent = 'Close'; + closeButton.style.position = 'fixed'; + closeButton.style.top = '10px'; + closeButton.style.right = '10px'; + closeButton.addEventListener('click', () => this.hideDetailedView()); + + this.visualizationContainer.appendChild(wrapper); + this.visualizationContainer.appendChild(closeButton); + + this.visualizationContainer.tabIndex = 0; + this.visualizationContainer.focus(); + + this.isDetailedViewVisible = true; + this.cleanupFunction = this.hideDetailedView.bind(this); + + // Trigger a resize event to ensure visualizations render correctly + setTimeout(() => { + window.dispatchEvent(new Event('resize')); + }, 100); + } + + hideDetailedView() { + if (!this.isDetailedViewVisible) return; + + if (this.visualizationContainer) { + this.visualizationContainer.innerHTML = ''; // Clear the container + this.visualizationContainer.style.pointerEvents = 'none'; // Disable interaction + } + + this.isDetailedViewVisible = false; + this.cleanupFunction = null; + } + + + animate() { + this.animationFrameId = requestAnimationFrame(this.animate.bind(this)); + + if (this.matMulVisualization && this.matMulVisualization.update) { + this.matMulVisualization.update(); + } + + if (this.renderer && this.scene && this.camera) { + this.renderer.render(this.scene, this.camera); + } + + this.drawTensorView(this.ctx2D, this.canvas2D); + } + + onWindowResize() { + if (this.camera && this.renderer && this.visualizationContainer) { + const width = this.visualizationContainer.clientWidth; + const height = this.visualizationContainer.clientHeight; + this.camera.aspect = width / height; + this.camera.updateProjectionMatrix(); + this.renderer.setSize(width, height); + if (this.canvas2D) { + this.canvas2D.width = width; + this.canvas2D.height = height; + } + } + } + + drawTensorView(ctx, canvas) { + if (!ctx || !canvas) return; + + // Clear the 2D canvas + ctx.clearRect(0, 0, canvas.width, canvas.height); + + // Draw tensor view content + ctx.fillStyle = 'rgba(255, 255, 255, 0.7)'; + ctx.font = '24px Arial'; + ctx.textAlign = 'center'; + ctx.textBaseline = 'top'; + ctx.fillText('Matrix Multiplication Visualization', canvas.width / 2, 10); + + // Display block information + ctx.font = '18px Arial'; + ctx.fillText(`Block Position: (${this.gridPosition.x}, ${this.gridPosition.y}, ${this.gridPosition.z})`, canvas.width / 2, 40); + + const dotOperation = this.blockData.find(op => op.type === 'Dot'); + if (dotOperation) { + // Display matrix dimensions + ctx.font = '16px Arial'; + ctx.textAlign = 'left'; + ctx.fillText(`Matrix A: ${dotOperation.input_shape.join('x')}`, 10, 70); + ctx.fillText(`Matrix B: ${dotOperation.other_shape.join('x')}`, 10, 90); + ctx.fillText(`Result Matrix: ${dotOperation.output_shape.join('x')}`, 10, 110); + } + + // Add a back button + ctx.fillStyle = 'rgba(200, 200, 200, 0.8)'; + ctx.fillRect(10, 10, 60, 30); + ctx.fillStyle = 'black'; + ctx.font = '16px Arial'; + ctx.textAlign = 'center'; + ctx.fillText('Back', 40, 25); + } + + hideDetailedView() { + + if (this.matMulVisualization && this.matMulVisualization.cleanup) { + this.matMulVisualization.cleanup(); + } + if (this.renderer) { + this.renderer.dispose(); + } + if (this.animationFrameId) { + cancelAnimationFrame(this.animationFrameId); + } + if (this.visualizationContainer && this.visualizationContainer.parentNode) { + this.visualizationContainer.parentNode.removeChild(this.visualizationContainer); + } + window.removeEventListener('resize', this.onWindowResize.bind(this)); + + this.scene = null; + this.camera = null; + this.renderer = null; + this.matMulVisualization = null; + this.canvas2D = null; + this.ctx2D = null; + this.visualizationContainer = null; + this.animationFrameId = null; + this.cleanupFunction = null; + + + } +} \ No newline at end of file diff --git a/triton_viz/static/load.js b/triton_viz/static/load.js new file mode 100644 index 0000000..b3f1d34 --- /dev/null +++ b/triton_viz/static/load.js @@ -0,0 +1,172 @@ +import * as THREE from 'https://cdn.jsdelivr.net/npm/three@0.155.0/build/three.module.js'; + +export function createLoadVisualization(containerElement, op) { + const CUBE_SIZE = 0.2; + const GAP = 0.05; + + // Colors + const COLOR_GLOBAL = new THREE.Color(0.2, 0.2, 0.2); // Dark Gray + const COLOR_SLICE = new THREE.Color(0.0, 0.7, 1.0); // Cyan (starting color for global slice) + const COLOR_LEFT_SLICE = new THREE.Color(1.0, 0.0, 1.0); // Magenta (starting color for left slice) + const COLOR_LOADED = new THREE.Color(1.0, 0.8, 0.0); // Gold (final color for both slices) + const COLOR_BACKGROUND = new THREE.Color(0.0, 0.0, 0.0); // Black + const COLOR_EDGE = new THREE.Color(0.5, 0.5, 0.5); // Gray (for cube edges) + + // Scene setup + const scene = new THREE.Scene(); + scene.background = COLOR_BACKGROUND; + const camera = new THREE.PerspectiveCamera(45, containerElement.clientWidth / containerElement.clientHeight, 0.1, 1000); + camera.position.set(15, -15, 30); + camera.lookAt(8, -5, 0); + + const renderer = new THREE.WebGLRenderer({ antialias: true }); + renderer.setSize(containerElement.clientWidth, containerElement.clientHeight); + containerElement.appendChild(renderer.domElement); + + // Lighting + const ambientLight = new THREE.AmbientLight(0xffffff, 0.3); + scene.add(ambientLight); + const directionalLight = new THREE.DirectionalLight(0xffffff, 0.7); + directionalLight.position.set(10, 10, 10); + scene.add(directionalLight); + + // Create cube geometry and materials + const cubeGeometry = new THREE.BoxGeometry(CUBE_SIZE, CUBE_SIZE, CUBE_SIZE); + const edgesGeometry = new THREE.EdgesGeometry(cubeGeometry); + const lineMaterial = new THREE.LineBasicMaterial({ color: COLOR_EDGE }); + + function createCube(color) { + const cubeMaterial = new THREE.MeshPhongMaterial({ color: color }); + const cube = new THREE.Mesh(cubeGeometry, cubeMaterial); + const edges = new THREE.LineSegments(edgesGeometry, lineMaterial); + cube.add(edges); + return cube; + } + + function createTensorPositions(dimensions, offset) { + const positions = []; + for (let i = 0; i < dimensions[0]; i++) { + for (let j = 0; j < dimensions[1]; j++) { + const x = (j + offset[1]) * (CUBE_SIZE + GAP); + const y = -(i + offset[0]) * (CUBE_SIZE + GAP); + const z = offset[2] * (CUBE_SIZE + GAP); + positions.push(new THREE.Vector3(x, y, z)); + } + } + return positions; + } + + // Create tensors + const globalPositions = createTensorPositions(op['global_shape'], [0, 0, 0]); + const leftSlicePositions = createTensorPositions(op['slice_shape'], [-20, 32, 1]); + + const globalTensor = new THREE.Group(); + const leftSliceTensor = new THREE.Group(); + + globalPositions.forEach((position, index) => { + const i = Math.floor(index / op['global_shape'][1]); + const j = index % op['global_shape'][1]; + const isInSlice = i >= op['slice_coords'][0][0] && i <= op['slice_coords'][2][0] && + j >= op['slice_coords'][0][1] && j <= op['slice_coords'][1][1]; + const color = isInSlice ? COLOR_SLICE : COLOR_GLOBAL; + const cube = createCube(color); + cube.position.copy(position); + globalTensor.add(cube); + }); + + leftSlicePositions.forEach((position) => { + const cube = createCube(COLOR_LEFT_SLICE); + cube.position.copy(position); + leftSliceTensor.add(cube); + }); + + scene.add(globalTensor); + scene.add(leftSliceTensor); + + // Animation + let isPaused = false; + let frame = 0; + const totalFrames = op['slice_shape'][0] * op['slice_shape'][1] * 2 + 30; + + function interpolateColor(color1, color2, factor) { + return new THREE.Color().lerpColors(color1, color2, factor); + } + + function animate() { + requestAnimationFrame(animate); + + if (!isPaused && frame < totalFrames) { + const [i, j] = [Math.floor(frame / 2 / op['slice_shape'][1]), Math.floor(frame / 2) % op['slice_shape'][1]]; + const factor = (frame % 2) / 1.0; + + const global_i = op['slice_coords'][0][0] + i; + const global_j = op['slice_coords'][0][1] + j; + if (global_i < op['global_shape'][0] && global_j < op['global_shape'][1]) { + const globalIndex = global_i * op['global_shape'][1] + global_j; + globalTensor.children[globalIndex].material.color.copy( + interpolateColor(COLOR_SLICE, COLOR_LOADED, factor) + ); + } + + const leftSliceIndex = i * op['slice_shape'][1] + j; + leftSliceTensor.children[leftSliceIndex].material.color.copy( + interpolateColor(COLOR_LEFT_SLICE, COLOR_LOADED, factor) + ); + + frame++; + } + + renderer.render(scene, camera); + } + + // Handle window resize + function onWindowResize() { + camera.aspect = containerElement.clientWidth / containerElement.clientHeight; + camera.updateProjectionMatrix(); + renderer.setSize(containerElement.clientWidth, containerElement.clientHeight); + } + + // Camera controls + const cameraSpeed = 0.1; + const zoomSpeed = 0.5; + + function onKeyDown(event) { + switch(event.key) { + case ' ': + isPaused = !isPaused; + break; + case 'ArrowLeft': + camera.position.x -= cameraSpeed; + break; + case 'ArrowRight': + camera.position.x += cameraSpeed; + break; + case 'ArrowUp': + camera.position.y += cameraSpeed; + break; + case 'ArrowDown': + camera.position.y -= cameraSpeed; + break; + case 'o': + camera.position.z -= zoomSpeed; + break; + case 'p': + camera.position.z += zoomSpeed; + break; + } + camera.lookAt(8, -5, 0); + } + + containerElement.addEventListener('keydown', onKeyDown); + window.addEventListener('resize', onWindowResize); + + // Start animation + animate(); + + // Return cleanup function + return function cleanup() { + window.removeEventListener('resize', onWindowResize); + containerElement.removeEventListener('keydown', onKeyDown); + containerElement.removeChild(renderer.domElement); + }; +} \ No newline at end of file diff --git a/triton_viz/static/matmul.js b/triton_viz/static/matmul.js new file mode 100644 index 0000000..8a126a1 --- /dev/null +++ b/triton_viz/static/matmul.js @@ -0,0 +1,209 @@ +import * as THREE from 'https://cdn.jsdelivr.net/npm/three@0.155.0/build/three.module.js'; + + + +export function createMatMulVisualization(containerElement, op) { + const { input_shape, other_shape, output_shape } = op; + + const CUBE_SIZE = 0.2; + const GAP = 0.05; + + // Colors + const COLOR_A = new THREE.Color(0.53, 0.81, 0.98); // Light Blue + const COLOR_B = new THREE.Color(1.0, 0.65, 0.0); // Orange + const COLOR_C = new THREE.Color(1.0, 1.0, 1.0); // White + const COLOR_HIGHLIGHT = new THREE.Color(0.0, 0.0, 1.0); // Blue (for highlighting) + const COLOR_FILLED = new THREE.Color(0.0, 0.0, 1.0); // Blue (for filled elements in C) + const COLOR_BACKGROUND = new THREE.Color(0, 0, 0); // Black + const COLOR_EDGE = new THREE.Color(0.3, 0.3, 0.3); // Light Gray (for cube edges) + + // Scene setup + const scene = new THREE.Scene(); + scene.background = COLOR_BACKGROUND; + const camera = new THREE.PerspectiveCamera(45, containerElement.clientWidth / containerElement.clientHeight, 0.1, 1000); + + const renderer = new THREE.WebGLRenderer({ antialias: true }); + renderer.setSize(containerElement.clientWidth, containerElement.clientHeight); + containerElement.appendChild(renderer.domElement); + + // Lighting + const ambientLight = new THREE.AmbientLight(0xffffff, 0.5); + scene.add(ambientLight); + const directionalLight = new THREE.DirectionalLight(0xffffff, 0.5); + directionalLight.position.set(10, 10, 10); + scene.add(directionalLight); + + // Create cube geometry and materials + const cubeGeometry = new THREE.BoxGeometry(CUBE_SIZE, CUBE_SIZE, CUBE_SIZE); + const edgesGeometry = new THREE.EdgesGeometry(cubeGeometry); + const lineMaterial = new THREE.LineBasicMaterial({ color: COLOR_EDGE }); + + function createCube(color) { + const cubeMaterial = new THREE.MeshPhongMaterial({ color: color }); + const cube = new THREE.Mesh(cubeGeometry, cubeMaterial); + const edges = new THREE.LineSegments(edgesGeometry, lineMaterial); + cube.add(edges); + return cube; + } + + function createMatrix(dimensions, position, color) { + const matrix = new THREE.Group(); + matrix.userData.dimensions = dimensions; + for (let i = 0; i < dimensions[0]; i++) { + for (let j = 0; j < dimensions[1]; j++) { + const cube = createCube(color); + cube.position.set( + position.x + j * (CUBE_SIZE + GAP), + position.y - i * (CUBE_SIZE + GAP), + position.z + ); + matrix.add(cube); + } + } + return matrix; + } + + // Create matrices + const matrixA = createMatrix(input_shape, new THREE.Vector3(-10, 10, 0), COLOR_A); + const matrixB = createMatrix(other_shape, new THREE.Vector3(0, 10, 0), COLOR_B); + const matrixC = createMatrix(output_shape, new THREE.Vector3(-5, -4, 0), COLOR_C); + + scene.add(matrixA); + scene.add(matrixB); + scene.add(matrixC); + + // Center camera + const center = new THREE.Vector3(); + const size = new THREE.Vector3(); + const box = new THREE.Box3().setFromObject(scene); + box.getCenter(center); + box.getSize(size); + const maxDim = Math.max(size.x, size.y, size.z); + const fov = camera.fov * (Math.PI / 180); + let cameraZ = Math.abs(maxDim / 2 / Math.tan(fov / 2)); + cameraZ *= 1.5; // Zoom out a little so objects don't fill the screen + + camera.position.set(center.x, center.y, center.z + cameraZ); + camera.lookAt(center); + + // Animation control + let isPaused = false; + let frame = 0; + const totalFrames = input_shape[0] * other_shape[1]; + + function highlightCubes(matrix, indices, highlightColor) { + indices.forEach(([i, j]) => { + if (i >= 0 && i < matrix.userData.dimensions[0] && j >= 0 && j < matrix.userData.dimensions[1]) { + const index = i * matrix.userData.dimensions[1] + j; + if (index < matrix.children.length) { + matrix.children[index].material.color.copy(highlightColor); + } + } + }); + } + + function resetColors() { + matrixA.children.forEach(cube => cube.material.color.copy(COLOR_A)); + matrixB.children.forEach(cube => cube.material.color.copy(COLOR_B)); + } + + function animate() { + requestAnimationFrame(animate); + + if (!isPaused && frame < totalFrames) { + resetColors(); + + const row = Math.floor(frame / other_shape[1]); + const col = frame % other_shape[1]; + + // Highlight entire row in A and entire column in B + const highlightA = Array.from({ length: input_shape[1] }, (_, i) => [row, i]); + const highlightB = Array.from({ length: other_shape[0] }, (_, i) => [i, col]); + const highlightC = [[row, col]]; + + highlightCubes(matrixA, highlightA, COLOR_HIGHLIGHT); + highlightCubes(matrixB, highlightB, COLOR_HIGHLIGHT); + highlightCubes(matrixC, highlightC, COLOR_FILLED); + + frame++; + } + + renderer.render(scene, camera); + } + + // Handle container resize + function onResize() { + camera.aspect = containerElement.clientWidth / containerElement.clientHeight; + camera.updateProjectionMatrix(); + renderer.setSize(containerElement.clientWidth, containerElement.clientHeight); + } + + // Create control panel + const controlPanel = document.createElement('div'); + controlPanel.style.position = 'absolute'; + controlPanel.style.bottom = '10px'; + controlPanel.style.left = '10px'; + controlPanel.style.display = 'flex'; + controlPanel.style.gap = '10px'; + + // Add controls + const playPauseButton = document.createElement('button'); + playPauseButton.textContent = 'Play/Pause'; + playPauseButton.addEventListener('click', () => { + isPaused = !isPaused; + }); + + const resetButton = document.createElement('button'); + resetButton.textContent = 'Reset'; + resetButton.addEventListener('click', () => { + frame = 0; + resetColors(); + }); + + const closeButton = document.createElement('button'); + closeButton.textContent = 'Close'; + closeButton.addEventListener('click', () => { + cleanup(); + }); + + controlPanel.appendChild(playPauseButton); + controlPanel.appendChild(resetButton); + controlPanel.appendChild(closeButton); + + containerElement.appendChild(controlPanel); + + // Keyboard controls + function onKeyDown(event) { + switch (event.key) { + case 'ArrowLeft': + camera.position.x -= 0.5; + break; + case 'ArrowRight': + camera.position.x += 0.5; + break; + case 'ArrowUp': + camera.position.y += 0.5; + break; + case 'ArrowDown': + camera.position.y -= 0.5; + break; + } + camera.lookAt(center); + } + + window.addEventListener('resize', onResize); + window.addEventListener('keydown', onKeyDown); + + // Start animation + animate(); + + // Cleanup function + function cleanup() { + window.removeEventListener('resize', onResize); + window.removeEventListener('keydown', onKeyDown); + containerElement.innerHTML = ''; + } + + // Return cleanup function + return cleanup; +} \ No newline at end of file diff --git a/triton_viz/static/store.js b/triton_viz/static/store.js new file mode 100644 index 0000000..1f23b95 --- /dev/null +++ b/triton_viz/static/store.js @@ -0,0 +1,171 @@ +import * as THREE from 'https://cdn.jsdelivr.net/npm/three@0.155.0/build/three.module.js'; + +export function createStoreVisualization(containerElement, op) { + const CUBE_SIZE = 0.2; + const GAP = 0.05; + + // Colors + const COLOR_GLOBAL = new THREE.Color(0.2, 0.2, 0.2); // Dark Gray + const COLOR_SLICE = new THREE.Color(0.0, 0.7, 1.0); // Cyan (starting color for global slice) + const COLOR_LEFT_SLICE = new THREE.Color(0.0, 1.0, 0.0); // Bright Green (starting color for left slice) + const COLOR_STORED = new THREE.Color(1.0, 0.0, 1.0); // Magenta (final color for left slice) + const COLOR_LOADED = new THREE.Color(1.0, 0.8, 0.0); // Gold (final color for global tensor) + const COLOR_BACKGROUND = new THREE.Color(0.0, 0.0, 0.0); // Black + const COLOR_EDGE = new THREE.Color(0.5, 0.5, 0.5); // Gray (for cube edges) + + // Scene setup + const scene = new THREE.Scene(); + scene.background = COLOR_BACKGROUND; + const camera = new THREE.PerspectiveCamera(45, containerElement.clientWidth / containerElement.clientHeight, 0.1, 1000); + camera.position.set(15, -15, 30); + camera.lookAt(8, -5, 0); + + const renderer = new THREE.WebGLRenderer({ antialias: true }); + renderer.setSize(containerElement.clientWidth, containerElement.clientHeight); + containerElement.appendChild(renderer.domElement); + + // Lighting + const ambientLight = new THREE.AmbientLight(0xffffff, 0.3); + scene.add(ambientLight); + const directionalLight = new THREE.DirectionalLight(0xffffff, 0.7); + directionalLight.position.set(10, 10, 10); + scene.add(directionalLight); + + // Create cube geometry and materials + const cubeGeometry = new THREE.BoxGeometry(CUBE_SIZE, CUBE_SIZE, CUBE_SIZE); + const edgesGeometry = new THREE.EdgesGeometry(cubeGeometry); + const lineMaterial = new THREE.LineBasicMaterial({ color: COLOR_EDGE }); + + function createCube(color) { + const cubeMaterial = new THREE.MeshPhongMaterial({ color: color }); + const cube = new THREE.Mesh(cubeGeometry, cubeMaterial); + const edges = new THREE.LineSegments(edgesGeometry, lineMaterial); + cube.add(edges); + return cube; + } + + function createTensorPositions(dimensions, offset) { + const positions = []; + for (let i = 0; i < dimensions[0]; i++) { + for (let j = 0; j < dimensions[1]; j++) { + const x = (j + offset[1]) * (CUBE_SIZE + GAP); + const y = -(i + offset[0]) * (CUBE_SIZE + GAP); + const z = offset[2] * (CUBE_SIZE + GAP); + positions.push(new THREE.Vector3(x, y, z)); + } + } + return positions; + } + + // Create tensors + const globalPositions = createTensorPositions(op['global_shape'], [0, 0, 0]); + const leftSlicePositions = createTensorPositions(op['slice_shape'], [-20, 32, 1]); + + const globalTensor = new THREE.Group(); + const leftSliceTensor = new THREE.Group(); + + globalPositions.forEach((position, index) => { + const i = Math.floor(index / op['global_shape'][1]); + const j = index % op['global_shape'][1]; + const isInSlice = i >= op['slice_coords'][0][0] && i <= op['slice_coords'][2][0] && + j >= op['slice_coords'][0][1] && j <= op['slice_coords'][1][1]; + const color = isInSlice ? COLOR_SLICE : COLOR_GLOBAL; + const cube = createCube(color); + cube.position.copy(position); + globalTensor.add(cube); + }); + + leftSlicePositions.forEach((position) => { + const cube = createCube(COLOR_LEFT_SLICE); + cube.position.copy(position); + leftSliceTensor.add(cube); + }); + + scene.add(globalTensor); + scene.add(leftSliceTensor); + + // Animation + let frame = 0; + const totalFrames = op['slice_shape'][0] * op['slice_shape'][1] * 2 + 30; + + function interpolateColor(color1, color2, factor) { + return new THREE.Color().lerpColors(color1, color2, factor); + } + + function animate() { + requestAnimationFrame(animate); + + if (frame < op['slice_shape'][0] * op['slice_shape'][1] * 2) { + const i = Math.floor(frame / 2 / op['slice_shape'][1]); + const j = Math.floor(frame / 2) % op['slice_shape'][1]; + const factor = (frame % 2) / 1.0; + + const global_i = op['slice_coords'][0][0] + i; + const global_j = op['slice_coords'][0][1] + j; + if (global_i < op['global_shape'][0] && global_j < op['global_shape'][1]) { + const globalIndex = global_i * op['global_shape'][1] + global_j; + globalTensor.children[globalIndex].material.color.copy( + interpolateColor(COLOR_SLICE, COLOR_LOADED, factor) + ); + } + + const leftSliceIndex = i * op['slice_shape'][1] + j; + leftSliceTensor.children[leftSliceIndex].material.color.copy( + interpolateColor(COLOR_LEFT_SLICE, COLOR_STORED, factor) + ); + + frame++; + } + + renderer.render(scene, camera); + } + + // Handle window resize + function onWindowResize() { + camera.aspect = containerElement.clientWidth / containerElement.clientHeight; + camera.updateProjectionMatrix(); + renderer.setSize(containerElement.clientWidth, containerElement.clientHeight); + } + + window.addEventListener('resize', onWindowResize); + + // Camera controls + const cameraSpeed = 0.1; + const zoomSpeed = 0.5; + + function onKeyDown(event) { + switch(event.key) { + case 'ArrowLeft': + camera.position.x -= cameraSpeed; + break; + case 'ArrowRight': + camera.position.x += cameraSpeed; + break; + case 'ArrowUp': + camera.position.y += cameraSpeed; + break; + case 'ArrowDown': + camera.position.y -= cameraSpeed; + break; + case 'o': + camera.position.z -= zoomSpeed; + break; + case 'p': + camera.position.z += zoomSpeed; + break; + } + camera.lookAt(8, -5, 0); + } + + containerElement.addEventListener('keydown', onKeyDown); + + // Start animation + animate(); + + // Return cleanup function + return function cleanup() { + window.removeEventListener('resize', onWindowResize); + containerElement.removeEventListener('keydown', onKeyDown); + containerElement.removeChild(renderer.domElement); + }; +} \ No newline at end of file diff --git a/triton_viz/static/three.js b/triton_viz/static/three.js new file mode 100644 index 0000000..e69de29 diff --git a/triton_viz/static/visualization.js b/triton_viz/static/visualization.js new file mode 100644 index 0000000..05ebb25 --- /dev/null +++ b/triton_viz/static/visualization.js @@ -0,0 +1,357 @@ +import * as THREE from 'https://cdn.jsdelivr.net/npm/three@0.155.0/build/three.module.js'; +import { GridBlock } from './gridblock.js'; +import { createLoadVisualization } from './load.js'; +import { createStoreVisualization } from './store.js'; +import { createMatMulVisualization } from './matmul.js'; + +let globalData; +let currentView = 'main'; +let canvas, ctx; +let maxX = 0, maxY = 0, maxZ = 0; +let sliders = [], zSlider, precomputeButton, kernelGrid; +let backButton; +let currentBlockData = null; +let isInitialized = false; +let containerElement; + +// Initialize the application +function switchToMainView() { + currentView = 'main'; + if (currentBlockData) { + currentBlockData.hideDetailedView(); + currentBlockData = null; + } + containerElement.style.pointerEvents = 'none'; + containerElement.style.display = 'none'; // Hide the container + containerElement.innerHTML = ''; // Clear the container + + // Show the main canvas + canvas.style.display = 'block'; + draw(); +} + + +function initializeApp() { + canvas = document.getElementById('canvas'); + if (!canvas) { + console.error('Canvas element not found'); + return; + } + ctx = canvas.getContext('2d'); + if (!ctx) { + console.error('Unable to get 2D context from canvas'); + return; + } + canvas.width = window.innerWidth; + canvas.height = window.innerHeight; + + // Add event listeners + canvas.addEventListener('mousedown', handleMouseEvent); + canvas.addEventListener('mouseup', handleMouseEvent); + canvas.addEventListener('mousemove', handleMouseEvent); + + containerElement = document.getElementById('visualization-container'); + if (!containerElement) { + console.error('Visualization container element not found'); + return; + } + + // Set up the visualization container + containerElement.style.pointerEvents = 'none'; + containerElement.style.display = 'none'; // Initially hide the container + + fetchData(); +} + +// Slider class +class Slider { + constructor(x, y, width, height, label, min_value = 0, max_value = 100) { + this.rect = { x, y, width, height }; + this.label = label; + this.min = min_value; + this.max = max_value; + this.value = min_value; + this.grabbed = false; + this.enabled = true; + } + + draw(ctx) { + if (!this.enabled) return; + ctx.fillStyle = '#3c3c46'; + ctx.fillRect(this.rect.x, this.rect.y, this.rect.width, this.rect.height); + const buttonX = this.rect.x + (this.value - this.min) / (this.max - this.min) * this.rect.width; + ctx.fillStyle = '#c8c8c8'; + ctx.fillRect(buttonX - 5, this.rect.y - 2, 10, this.rect.height + 4); + + ctx.fillStyle = '#c8c8c8'; + ctx.font = '18px Arial'; + ctx.fillText(this.label, this.rect.x, this.rect.y - 10); + ctx.fillText(this.value.toString(), this.rect.x + this.rect.width + 10, this.rect.y + this.rect.height / 2 + 5); + } + + handleEvent(event) { + if (!this.enabled) return; + if (event.type === 'mousedown') { + if (this.isPointInside(event.offsetX, event.offsetY)) { + this.grabbed = true; + } + } else if (event.type === 'mouseup') { + this.grabbed = false; + } else if (event.type === 'mousemove' && this.grabbed) { + const mouseX = event.offsetX; + this.value = Math.round((mouseX - this.rect.x) / this.rect.width * (this.max - this.min) + this.min); + this.value = Math.max(this.min, Math.min(this.max, this.value)); + } + } + + isPointInside(x, y) { + return x >= this.rect.x && x <= this.rect.x + this.rect.width && + y >= this.rect.y && y <= this.rect.y + this.rect.height; + } +} + +// Button class +class Button { + constructor(x, y, width, height, text) { + this.rect = { x, y, width, height }; + this.text = text; + this.color = '#3c3c46'; + this.hoverColor = '#50505a'; + this.clickColor = '#64646e'; + this.isHovered = false; + this.isClicked = false; + this.clickTime = 0; + } + + draw(ctx) { + let color = this.color; + if (this.isClicked && Date.now() - this.clickTime < 100) { + color = this.clickColor; + } else if (this.isHovered) { + color = this.hoverColor; + } + + ctx.fillStyle = color; + ctx.fillRect(this.rect.x, this.rect.y, this.rect.width, this.rect.height); + ctx.strokeStyle = '#c8c8c8'; + ctx.strokeRect(this.rect.x, this.rect.y, this.rect.width, this.rect.height); + + ctx.fillStyle = '#c8c8c8'; + ctx.font = '18px Arial'; + ctx.textAlign = 'center'; + ctx.textBaseline = 'middle'; + ctx.fillText(this.text, this.rect.x + this.rect.width / 2, this.rect.y + this.rect.height / 2); + ctx.textAlign = 'left'; + ctx.textBaseline = 'alphabetic'; + } + + handleEvent(event) { + const { offsetX, offsetY } = event; + this.isHovered = this.isPointInside(offsetX, offsetY); + if (event.type === 'mousedown' && this.isHovered) { + this.isClicked = true; + this.clickTime = Date.now(); + console.log(`Button '${this.text}' clicked!`); + } else if (event.type === 'mouseup') { + this.isClicked = false; + } + } + + isPointInside(x, y) { + return x >= this.rect.x && x <= this.rect.x + this.rect.width && + y >= this.rect.y && y <= this.rect.y + this.rect.height; + } +} + +// KernelGrid class +class KernelGrid { + constructor(x, y, width, height, gridSize, visualizationData) { + this.rect = { x, y, width, height }; + this.gridSize = gridSize; + this.visualizationData = visualizationData; + this.currentZ = 0; + this.blocks = []; + this.calculateBlockSize(); + this.createBlocks(); + this.selectedBlock = null; + } + + calculateBlockSize() { + this.blockWidth = Math.floor(this.rect.width / this.gridSize[0]) - 1; + this.blockHeight = Math.floor(this.rect.height / this.gridSize[1]) - 1; + } + + createBlocks() { + this.blocks = []; + for (let y = 0; y < this.gridSize[1]; y++) { + for (let x = 0; x < this.gridSize[0]; x++) { + const blockX = this.rect.x + x * (this.blockWidth + 1); + const blockY = this.rect.y + y * (this.blockHeight + 1); + const gridKey = `(${x}, ${y}, ${this.currentZ})`; + const blockData = this.visualizationData[gridKey] || []; + const block = new GridBlock(blockX, blockY, this.blockWidth, this.blockHeight, x, y, this.currentZ, blockData); + this.blocks.push(block); + } + } + } + + draw(ctx) { + ctx.fillStyle = '#F0F0F0'; + ctx.fillRect(this.rect.x, this.rect.y, this.rect.width, this.rect.height); + this.blocks.forEach(block => block.draw(ctx)); + } + + updateZ(z) { + this.currentZ = z; + this.blocks.forEach(block => { + block.gridPosition.z = z; + const gridKey = `(${block.gridPosition.x}, ${block.gridPosition.y}, ${z})`; + block.blockData = this.visualizationData[gridKey] || []; + }); + } + + handleClick(x, y, containerElement) { + const clickedBlock = this.blocks.find(block => block.isPointInside(x, y)); + if (clickedBlock) { + console.log(`Clicked block at (${clickedBlock.gridPosition.x}, ${clickedBlock.gridPosition.y}, ${clickedBlock.gridPosition.z})`); + if (this.selectedBlock) { + this.selectedBlock.hideDetailedView(); + } + this.selectedBlock = clickedBlock; + clickedBlock.showDetailedView(containerElement); + return clickedBlock; + } + return null; + } + + handleMouseMove(x, y) { + this.blocks.forEach(block => block.handleMouseMove(x, y)); + } +} + +// Function to determine the max values for sliders based on visualization data +function determineMaxValues(visualizationData) { + maxX = 0; + maxY = 0; + maxZ = 0; + const keys = Object.keys(visualizationData); + keys.forEach(key => { + const [x, y, z] = key.replace(/[()]/g, '').split(', ').map(Number); + if (x > maxX) maxX = x; + if (y > maxY) maxY = y; + if (z > maxZ) maxZ = z; + }); +} + +// Function to initialize UI elements +function initializeUIElements() { + sliders = [ + new Slider(1300, 50, 250, 20, "Program Id 0", 0, maxX), + new Slider(1300, 120, 250, 20, "Program Id 1", 0, maxY), + new Slider(1300, 190, 250, 20, "Program Id 2", 0, maxZ) + ]; + + zSlider = new Slider(50, 860, 1200, 20, "Z-axis", 0, maxZ); + zSlider.enabled = maxZ > 0; + + precomputeButton = new Button(1300, 260, 250, 40, "Precompute"); + kernelGrid = new KernelGrid(50, 50, 1200, 800, [maxX + 1, maxY + 1, maxZ + 1], globalData.ops.visualization_data); + backButton = new Button(50, 50, 100, 40, "Back"); + isInitialized = true; +} +function switchToTensorView(clickedBlock) { + currentView = 'tensor'; + currentBlockData = clickedBlock; + console.log("Switched to tensor view. Block data:", clickedBlock); + + containerElement.style.pointerEvents = 'auto'; + containerElement.style.display = 'block'; // Make sure the container is visible + clickedBlock.showDetailedView(containerElement); + + // Hide the main canvas + canvas.style.display = 'none'; +} + + + + +function handleMouseEvent(event) { + if (!isInitialized) { + console.warn('UI elements not initialized yet'); + return; + } + + const { offsetX, offsetY } = event; + if (currentView === 'main') { + sliders.forEach(slider => slider.handleEvent(event)); + if (zSlider && zSlider.enabled) { + zSlider.handleEvent(event); + if (kernelGrid) { + kernelGrid.updateZ(zSlider.value); + } + } + if (precomputeButton) { + precomputeButton.handleEvent(event); + } + if (kernelGrid) { + kernelGrid.handleMouseMove(offsetX, offsetY); + if (event.type === 'mousedown') { + const clickedBlock = kernelGrid.handleClick(offsetX, offsetY, containerElement); + if (clickedBlock) { + switchToTensorView(clickedBlock); + } + } + } + } else if (currentView === 'tensor') { + if (backButton) { + backButton.handleEvent(event); + if (event.type === 'mousedown' && backButton.isHovered) { + switchToMainView(); + } + } + } + draw(); +} + +function draw() { + if (!ctx) { + console.error('Canvas context not available'); + return; + } + + ctx.fillStyle = '#1e1e28'; + ctx.fillRect(0, 0, canvas.width, canvas.height); + + if (currentView === 'main') { + if (kernelGrid) kernelGrid.draw(ctx); + sliders.forEach(slider => slider.draw(ctx)); + if (zSlider && zSlider.enabled) { + zSlider.draw(ctx); + } + if (precomputeButton) precomputeButton.draw(ctx); + } + // Remove the else clause for 'tensor' view, as it's now handled by Three.js +} + + + + +// Fetch data from the API and initialize the visualization +async function fetchData() { + + const response = await fetch('/api/data'); + globalData = await response.json(); + + // Determine max values for sliders based on fetched data + determineMaxValues(globalData.ops.visualization_data); + + // Initialize UI elements + initializeUIElements(); + + // Initial draw + draw(); + +} + +// Start the application +initializeApp(); \ No newline at end of file diff --git a/triton_viz/templates/index.html b/triton_viz/templates/index.html new file mode 100644 index 0000000..3a6db82 --- /dev/null +++ b/triton_viz/templates/index.html @@ -0,0 +1,47 @@ + + + + + + Kernel Launch Configuration Visualization + + + + +
+ + + + \ No newline at end of file diff --git a/triton_viz/tooltip.py b/triton_viz/tooltip.py index 2a3f069..5328c0c 100644 --- a/triton_viz/tooltip.py +++ b/triton_viz/tooltip.py @@ -7,12 +7,15 @@ "ExpandDims": "Shows the total number of expand_dims operations performed in this kernel.", "Dot": "Shows the total number of dot operations performed in this kernel.", "Reduce": "Shows the total number of reduce operations performed in this kernel.", - "Total number of bytes loaded": "Shows the total number of bytes loaded (mask=True).Note:On GPUs, this metric does not equate to the total number of bytes loaded from global memory (DRAM), as some data accesses may be handled through GPU caches.", + "Total number of bytes loaded": "Shows the total number of bytes loaded (mask=True). Note: On GPUs, this metric does not equate to the total number of bytes loaded from global memory (DRAM), as some data accesses may be handled through GPU caches.", "Masked Load Ratio": "Ratio of total number of bytes loaded (mask=True)/total number of bytes loaded (mask=True) + (mask=False).", "Total number of bytes stored": "Shows the total number of bytes stored (mask=True).", "Masked Store Ratio": "Ratio of total number of bytes stored (mask=True)/total number of bytes stored (mask=True) + (mask=False).", } +def get_tooltip_data(df): + """Return the tooltip data in a format suitable for JSON serialization.""" + return df.to_dict() def create_tooltip(df): styles = """ diff --git a/triton_viz/trace.py b/triton_viz/trace.py index bf51d37..7ed69b5 100644 --- a/triton_viz/trace.py +++ b/triton_viz/trace.py @@ -8,12 +8,17 @@ class Trace(KernelInterface): def __init__(self, kernel: JITFunction) -> None: + self.src =kernel.src assert isinstance(kernel, JITFunction), "Kernel must be a JITFunction" + self._fn = InterpretedFunction(kernel.fn) + def run(self, *args, **kwargs): with patch(): return self._fn.run(*args, **kwargs) + def get_src(self): + return self.src def trace(kernel):