diff --git a/examples/3dims.py b/examples/3dims.py new file mode 100644 index 0000000..bfe766c --- /dev/null +++ b/examples/3dims.py @@ -0,0 +1,99 @@ +import torch +import triton +import triton.language as tl + + + +import triton_viz +BLOCK_SIZE_X = 32 +BLOCK_SIZE_Y = 8 +BLOCK_SIZE_Z = 4 + +@triton_viz.trace +@triton.jit +def add_3d_slices_kernel( + input_ptr1, input_ptr2, output_ptr, + stride_x, stride_y, stride_z, + slice_x, slice_y, slice_z, + BLOCK_SIZE_X: tl.constexpr, BLOCK_SIZE_Y: tl.constexpr, BLOCK_SIZE_Z: tl.constexpr +): + # Compute the 3D position in the output tensor + pid_x = tl.program_id(0) + pid_y = tl.program_id(1) + pid_z = tl.program_id(2) + + # Compute the starting position for this block + x_start = pid_x * BLOCK_SIZE_X + y_start = pid_y * BLOCK_SIZE_Y + z_start = pid_z * BLOCK_SIZE_Z + + # Compute offsets within the block + x_offsets = x_start + tl.arange(0, BLOCK_SIZE_X) + y_offsets = y_start + tl.arange(0, BLOCK_SIZE_Y) + z_offsets = z_start + tl.arange(0, BLOCK_SIZE_Z) + + # Create a mask to handle boundary conditions + mask = (x_offsets < slice_x) & (y_offsets < slice_y)[:, None] & (z_offsets < slice_z)[:, None, None] + + # Compute the input and output offsets + offsets = ( + z_offsets[:, None, None] * stride_z + + y_offsets[:, None] * stride_y + + x_offsets * stride_x + ) + + # Load input slices + slice1 = tl.load(input_ptr1 + offsets, mask=mask) + slice2 = tl.load(input_ptr2 + offsets, mask=mask) + + # Perform addition + result = slice1 + slice2 + + # Store the result + tl.store(output_ptr + offsets, result, mask=mask) + +def add_3d_slices(input1, input2, output): + # Get tensor shapes + slice_z, slice_y, slice_x = input1.shape + + # Compute strides + stride_z, stride_y, stride_x = input1.stride() + + # Determine grid size + grid = ( + triton.cdiv(slice_x, BLOCK_SIZE_X), + triton.cdiv(slice_y, BLOCK_SIZE_Y), + triton.cdiv(slice_z, BLOCK_SIZE_Z) + ) + + # Launch kernel + add_3d_slices_kernel[grid]( + input1, input2, output, + stride_x, stride_y, stride_z, + slice_x, slice_y, slice_z, + BLOCK_SIZE_X=BLOCK_SIZE_X, + BLOCK_SIZE_Y=BLOCK_SIZE_Y, + BLOCK_SIZE_Z=BLOCK_SIZE_Z + ) + +if __name__ == "__main__": + # Set random seed for reproducibility + torch.manual_seed(0) + + # Create example input tensor + input1 = torch.randn(16, 16, 32, device='cpu') + input2 = torch.randn(16, 16, 32, device='cpu') + output = torch.empty_like(input1) + + # Call the kernel + add_3d_slices(input1, input2, output) + triton_viz.launch() + + # Verify the result + expected_output = input1 + input2 + assert torch.allclose(output, expected_output), "Kernel output does not match expected result" + + + + + diff --git a/setup.py b/setup.py index 6898fed..0ea16c3 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="triton-viz", - version="0.1", + version="0.2", packages=find_packages(), description="A visualization tool for Triton", author="Deep Learning Profiling Tools Team", @@ -11,10 +11,11 @@ install_requires=[ "setuptools", "triton", - "gradio", - "chalk-diagrams @ git+https://github.com/chalk-diagrams/chalk.git", + "flask", "pyarrow", "pre-commit", "pytest", + "flask_cloudflared", + "requests", ], ) diff --git a/triton_viz/__init__.py b/triton_viz/__init__.py index 737dd36..7c5bab8 100644 --- a/triton_viz/__init__.py +++ b/triton_viz/__init__.py @@ -1,5 +1,5 @@ from .trace import trace, dump, sample -from .draw import collect_grid, draw_record +from .draw import collect_grid from .interface import launch -__all__ = ["trace", "launch", "dump", "sample", "collect_grid", "draw_record"] +__all__ = ["trace", "launch", "dump", "sample", "collect_grid"] diff --git a/triton_viz/data.py b/triton_viz/data.py index 819a398..ad9678c 100644 --- a/triton_viz/data.py +++ b/triton_viz/data.py @@ -1,9 +1,11 @@ from dataclasses import dataclass, field -from typing import List, Tuple, Any +from typing import List, Tuple, Any, Dict import traceback import numpy.typing as npt import numpy as np +import torch + @dataclass class Op: @@ -73,6 +75,15 @@ class Dot(Op): input_shape: Tuple other_shape: Tuple output_shape: Tuple + input_data: List[List[float]] + other_data: List[List[float]] + intermediate_results: Dict[Tuple[int, int], float] = field( + default_factory=dict + ) # Only storing the result now + + def update_intermediate(self, row: int, col: int, result: float): + # Store only the result as a float + self.intermediate_results[(row, col)] = result @dataclass @@ -91,6 +102,7 @@ class Tensor: stride: Tuple shape: Tuple element_size: int + data: torch.Tensor @dataclass diff --git a/triton_viz/draw.py b/triton_viz/draw.py index fda62c0..2de3c52 100644 --- a/triton_viz/draw.py +++ b/triton_viz/draw.py @@ -1,69 +1,10 @@ -from colour import Color -from triton_viz.data import ( - Tensor, - Grid, - Store, - Load, - Op, - MakeRange, - Reduce, - Dot, -) -from .interpreter import record_builder +from triton_viz.data import Tensor, Grid, Store, Load, Dot, ExpandDims +import uuid import numpy as np -import numpy.typing as npt -import planar -import math -import chalk -from typing import Tuple, Union, Optional, List, Dict -from chalk import Diagram, rectangle, text, hcat, vcat, empty, Path, Trail, V2, concat -from dataclasses import dataclass -from numpy.typing import ArrayLike -import sys - -sys.setrecursionlimit(100000) - - -planar.EPSILON = 0.0 -chalk.set_svg_draw_height(500) -BG = Color("white") -WHITE = Color("white") -DEFAULT = Color("grey") -BLACK = Color("black") -GREY = Color("grey") -palette = [ - "#f29f05", - "#f25c05", - "#d6568c", - "#4d8584", - "#a62f03", - "#400d01", - "#274001", - "#828a00", -] -ACTIVE = [Color(p) for p in palette] - -MRATIO = 1 / 3 - - -# Generic render helpers - - -def box(d: Diagram, width: float, height: float, outer=0.2) -> Diagram: - "Put diagram in a box of shape height, width" - h, w = d.get_envelope().height, d.get_envelope().width - m = max(h, w) - back = rectangle(outer + m, outer + m).line_width(0).fill_color(BG).center_xy() - d = (back + d.center_xy()).with_envelope(back) - return d.scale_x(width / m).scale_y(height / m) - - -def reshape(d: Diagram) -> Diagram: - "Use log-scale if ratio is too sharp" - h, w = d.get_envelope().height, d.get_envelope().width - if (h / w > MRATIO) or (w / h > MRATIO): - d = d.scale_y(math.log(h + 1, 2) / h).scale_x(math.log(w + 1, 2) / w) - return d +import torch +from .interpreter import record_builder + +from typing import Tuple, List def collect_grid(): @@ -96,141 +37,37 @@ def collect_launch(launch): return all_grids, tensor_table, failures -def draw_record(program_record, tensor_table, output): - return draw_launch(program_record, tensor_table, output) - - -def draw_launch(program_records, tensor_table, base) -> Diagram: - def draw(x): - "Dispatch" - if isinstance(x, Tensor): - return draw_tensor(x) - if isinstance(x, Grid): - return draw_grid(x) - if isinstance(x, Store): - return draw_store(x, tensor_table) - if isinstance(x, Load): - return draw_load(x, tensor_table) - if isinstance(x, Op): - return draw_op(x) - if isinstance(x, MakeRange): - return draw_make_range(x) - if isinstance(x, Reduce): - return draw_reduce(x) - if isinstance(x, Dot): - return None # draw_dot(x) - - def draw_record(x): - "Render one record" - y = draw(x) - if y is None: - return empty() - - return (chalk.vstrut(0.2) / y).center_xy() - - records = [] - for r in program_records: - dr = draw_record(r) - # env = dr.get_envelope() - # dr = dr.center_xy().with_envelope(rectangle(env.width, env.height).center_xy()) - records.append(dr) - - dr = vcat(records) - dr = dr.center_xy() - env = dr.get_envelope() - dr = rectangle(env.width + 1, env.height + 1).fill_color(BG).center_xy() + dr - dr.render_svg(base, 2500) - return env.width, env.height - - -def delinearize(shape: Tuple, x: npt.NDArray, dtype, mask) -> List[npt.NDArray]: - if len(shape) == 1: - shape = (1, 1, shape[0]) - x = x.copy() // (dtype.element_ty.primitive_bitwidth // 8) - vals = [] - for s in list(reversed(shape[1:])) + [10000]: - vals.append(((x % s) * mask - (1 - mask)).ravel()) - x = x // s - return vals - - -trail = Trail.from_offsets([V2(0, 1), V2(1, 0), V2(0, -1), V2(-1, 0)], closed=True) - - -def cover( - shape: Tuple, dtype, load: Tensor, mask: npt.NDArray, color: Color -) -> Diagram: - shape = make_3d(shape) - "Draw the values from load on top of the loading tensor" - x, y, z = delinearize(shape, load, dtype, mask) - return draw_tensor_3d(shape, z, y, x, color) - - -def pair_draw(x: Diagram, y: Diagram, command: str) -> Diagram: - "Draw two diagrams next to each other with a command in the middle." - return hcat([box(x, 3, 2.5), box(y, 3, 2.5)], 1).center_xy() + text( - command, 0.2 - ).fill_color(BLACK).line_width(0).translate(0, -1) - - -# Individual renderers - - -def draw_tensor(x: Tensor) -> Optional[Diagram]: - return None - - -def draw_grid(x: Grid) -> Optional[Diagram]: - return None - - -def draw_make_range(x: MakeRange) -> Optional[Diagram]: - return None - - -def draw_reduce(x: Reduce) -> Optional[Diagram]: - color = ACTIVE[0] - inp = draw_tensor_3d(make_3d(x.input_shape), None, None, None, color) - if x.index == 0 and len(x.input_shape) == 2: - inp = hcat( - [ - rectangle(0.1, inp.get_envelope().height) - .align_t() - .line_width(0) - .fill_color(BLACK), - inp, - ], - 0.5, - ) - else: - inp = vcat( - [ - rectangle(inp.get_envelope().width, 0.1) - .align_l() - .line_width(0) - .fill_color(BLACK), - inp, - ], - 0.5, - ) - out = draw_tensor_3d(x.output_shape, None, None, None, color) - return pair_draw(reshape(inp), reshape(out), x.op) +def extract_load_coords( + record: Load, global_tensor: Tensor +) -> Tuple[List[Tuple[float, float, float]], List[Tuple[float, float, float]]]: + # Extract coordinates for the global tensor + global_shape = make_3d(global_tensor.shape) + global_z, global_y, global_x = delinearized( + global_shape, + record.original_offsets, + global_tensor.dtype, + record.original_masks, + ) + global_coords = [ + (float(xi), float(yi), float(zi)) + for xi, yi, zi in zip(global_z, global_y, global_x) + if xi != -1 and yi != -1 and zi != -1 + ] -def draw_load(x, tensor_table) -> Optional[Diagram]: - inp, out = store_load(x, tensor_table) - out = reshape(out) - return pair_draw(inp, out, "load") + # Extract coordinates for the slice tensor + slice_shape = make_3d(record.shape) + slice_z, slice_y, slice_x = record.original_masks.reshape(*slice_shape).nonzero() + slice_coords = [ + (float(xi), float(yi), float(zi)) + for xi, yi, zi in zip(slice_x, slice_y, slice_z) + ] -def draw_store(x, tensor_table) -> Optional[Diagram]: - inp, out = store_load(x, tensor_table) - out = reshape(out) - return pair_draw(out, inp, "store") + return global_coords, slice_coords -def make_3d(shape): - "Make a 3d shape" +def make_3d(shape: Tuple[int, ...]): if len(shape) == 1: return (1, 1, shape[0]) if len(shape) == 2: @@ -238,233 +75,114 @@ def make_3d(shape): return shape -def store_load( - x: Union[Store, Load], tensor_table: Dict[int, Tuple[Tensor, int]] -) -> Tuple[Diagram, Diagram]: - tensor, tensor_id = tensor_table[x.ptr] - # inp = base_tensor(tensor.shape, DEFAULT) - color = ACTIVE[tensor_id] - invalid = x.invalid_access_masks.any() - if invalid: - color = Color("red") - inp = cover(tensor.shape, tensor.dtype, x.original_offsets, x.original_masks, color) - inp = reshape(inp) - s = make_3d(x.original_offsets.shape) - a, b, c = x.original_masks.reshape(*s).nonzero() - out = draw_tensor_3d(s, a, b, c, color) - return inp, out - - -def draw_op(x: Op) -> Optional[Diagram]: - return None - - -def draw_dot(x: Dot) -> Optional[Diagram]: - if x.input_shape == (1,): - return None - inp = draw_tensor_3d(x.input_shape[0], None, None, None) - # inp = reshape(base_tensor(x.input_shape[0], color=ACTIVE)) - # inp = add_whiskers(inp, x.input_shape[0]) - inp2 = draw_tensor_3d(x.input_shape[1], None, None, None) - # inp2 = reshape(base_tensor(x.input_shape[0], color=ACTIVE)) - # inp2 = add_whiskers(inp2, x.input_shape[0]) - out = draw_tensor_3d(x.output_shape, None, None, None) - # out = reshape(base_tensor(x.output_shape, color=ACTIVE)) - # out = add_whiskers(out, x.output_shape) - return hcat( - [box(inp, 1.5, 2), box(inp2, 1.5, 2), box(out, 1.5, 2)], 1 - ).center_xy() + text("dot", 0.2).fill_color(BLACK).line_width(0).translate(0, -1) - - -# For 3d - - -def lookAt(eye: ArrayLike, center: ArrayLike, up: ArrayLike): - "Python version of the haskell lookAt function in linear.projections" - f = (center - eye) / np.linalg.norm(center - eye) - s = np.cross(f, up) / np.linalg.norm(np.cross(f, up)) - u = np.cross(s, f) - return np.array([[*s, 0], [*u, 0], [*-f, 0], [0, 0, 0, 1]]) - - -def scale3(x, y, z): - return np.array([[x, 0, 0, 0], [0, y, 0, 0], [0, 0, z, 0], [0, 0, 0, 1]]) - - -@dataclass -class D3: - x: float - y: float - z: float - - def to_np(self): - return np.array([self.x, self.y, self.z]) - - -V3 = D3 - - -def homogeneous(trails: List[List[D3]]): - "Convert list of directions to a np.array of homogeneous coordinates" - return np.array([[[*o.to_np(), 1] for o in offsets] for offsets in trails]) - - -def cube(): - "3 faces of a cube drawn as offsets from the origin." - return homogeneous( - [ - [D3(*v) for v in offset] - for offset in [ - [(1, 0, 0), (0, 1, 0), (-1, 0, 0), (0, -1, 0)], - [(1, 0, 0), (0, 0, 1), (-1, 0, 0), (0, 0, -1)], - [(0, 0, 1), (0, 1, 0), (0, 0, -1), (0, -1, 0)], - ] - ] - ) - - -def to_trail(trail: ArrayLike, locations: ArrayLike): - return [ - ( - Path( - [ - Trail.from_offsets([V2(*v[:2]) for v in trail]) - .close() - .at(V2(*loc[:2])) - ] - ), - loc[2], - ) - for loc in locations - ] - - -def project(projection, shape3, positions): - p = homogeneous([positions for _ in range(shape3.shape[0])]) - locations = p @ projection.T - trails = shape3 @ projection.T - return [out for t, loc in zip(trails, locations) for out in to_trail(t, loc)] - - -def draw_tensor_3d(shape, a, b, c, color=WHITE): - shape = make_3d(shape) - - # Big Cube - s = scale3(*shape) - big_cube = cube() @ s.T - back = scale3(0, shape[1], shape[2]) - back_cube = cube() @ back.T - - # Isometric projection of tensor - projection = lookAt( - V3(-1.0, -0.3, -0.15).to_np(), - V3(0, 0, 0).to_np(), - V3(0, 1, 0).to_np(), - ) - outer = project(projection, big_cube, [V3(0, 0, 0)]) - outer2 = project(projection, back_cube, [V3(shape[0], 0, 0)]) - d = ( - concat([p.stroke().fill_color(GREY).fill_opacity(0.1) for p, _ in outer2]) - .line_width(0.005) - .line_color(GREY) - ) - d += ( - concat([p.stroke().fill_color(GREY).fill_opacity(0.05) for p, _ in outer]) - .line_width(0.01) - .line_color(BLACK) - ) - if a is not None: - out = group(a, b, c) - d2 = [ - (b, loc) - for i in range(len(out)) - for b, loc in make_cube(projection, out[i][0], out[i][1], color) - ] - d2.sort(key=lambda x: x[1], reverse=True) - d2 = concat([b.with_envelope(empty()) for b, _ in d2]) - d = d2.with_envelope(d) + d - return d - - -def lines(s): - "Draw lines to mimic a cube of cubes" - bs = [np.array([1, 0, 0]), np.array([0, 1, 0]), np.array([0, 0, 1])] - return [ - homogeneous( - [ - [D3(*p) for _ in range(s[i]) for p in [a / s[i], b, -b]] - for b in bs - if not np.all(a == b) - ] +def delinearized( + shape: Tuple[int, int, int], x: np.ndarray, dtype, mask +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + x = x.copy() // (dtype.element_ty.primitive_bitwidth // 8) + z = ((x // (shape[1] * shape[2])) * mask - (1 - mask)).ravel() + y = (((x // shape[2]) % shape[1]) * mask - (1 - mask)).ravel() + x = ((x % shape[2]) * mask - (1 - mask)).ravel() + return z, y, x + + +def prepare_visualization_data(program_records, tensor_table): + """Prepare visualization data for the frntend and raw tensor data for the server.""" + # global idx + visualization_data = [] + raw_tensor_data = {} + for record in program_records: + record_uuid = str(uuid.uuid4())[:8] + + if isinstance(record, ExpandDims): + print(record.input_shape, record.output_shape, record.index) + if isinstance(record, Dot): + visualization_data.append( + { + "type": "Dot", + "input_shape": record.input_shape, + "other_shape": record.other_shape, + "output_shape": record.output_shape, + "uuid": record_uuid, + } + ) + + raw_tensor_data[record_uuid] = { + "input_data": torch.tensor(record.input_data), + "other_data": torch.tensor(record.other_data), + "intermediate_results": record.intermediate_results, + } + + elif isinstance(record, Load): + global_tensor, slice_tensor = tensor_table[record.ptr] + print(global_tensor) + global_coords, slice_coords = extract_load_coords(record, global_tensor) + + visualization_data.append( + { + "type": "Load", + "global_shape": global_tensor.shape, + "slice_shape": record.shape, + "global_coords": global_coords, + "slice_coords": slice_coords, + "uuid": record_uuid, + } + ) + + raw_tensor_data[record_uuid] = { + "global_tensor": global_tensor.data.cpu(), # Ensure it's on CPU + "dims": len(global_tensor.data.cpu().shape), + } + print(record.shape) + + elif isinstance(record, Store): + global_tensor, slice_tensor = tensor_table[record.ptr] + + global_coords, slice_coords = extract_load_coords(record, global_tensor) + + visualization_data.append( + { + "type": "Store", + "global_shape": global_tensor.shape, + "slice_shape": record.shape, + "global_coords": global_coords, + "slice_coords": slice_coords, + "uuid": record_uuid, + } + ) + + return visualization_data, raw_tensor_data, "" + + +def get_visualization_data(): + """Return the visualization data and raw tensor data.""" + records, tensor_table, failures = collect_grid() + visualization_data = {} + raw_tensor_data = {} + + for grid_idx, program_records in records.items(): + viz_data, raw_data, kernel_src = prepare_visualization_data( + program_records, tensor_table ) - for i, a in enumerate(bs) - ] - - -def make_cube(projection, start, end, color): - "Draws a cube from start position to end position." - start = np.array(start).astype(int) - end = np.array(end).astype(int) - s2 = end - start + 1 - s = scale3(*s2) - small_cube = cube() @ s.T - loc = [ - project(projection, l2 @ s.T, [V3(*start)]) - for l2 in lines(s2) - if l2.shape[1] > 0 - ] - outer2 = project(projection, small_cube, [V3(*start)]) - ls = loc - box = [ - (p.stroke().fill_color(color).fill_opacity(0.4).line_width(0), loc) - for p, loc in outer2 - ] - line = [ - (p.stroke().line_width(0.001).line_color(BLACK), l_) - for loc in ls - for p, l_ in loc - ] - return [(b, loc) for b, loc in box + line] - - -def group( - x: ArrayLike, y: ArrayLike, z: ArrayLike -) -> List[Tuple[Tuple[float, float, float], Tuple[float, float, float]]]: - "Groups together cubes into bigger cubes" - x = list(zip(zip(x, y, z), zip(x, y, z))) - x = [(a, b) for a, b in x if not (a[0] == -1 and a[1] == -1 and a[2] == -1)] - - start = x - - def remove_dups(ls): - "Remove duplicates" - out = [] - for y in ls: - if not out or y != out[-1]: - out.append(y) - return out - - for j in range(2, -1, -1): - x = remove_dups(start) - start = [] - while True: - if len(x) <= 1: - break - _, _, rest = x[0], x[1], x[2:] - m = 0 - for k in range(2): - a = x[0][k] - b = x[1][k] - if ( - (k == 0 or a[j % 3] == b[j % 3] - 1) - and a[(j + 1) % 3] == b[(j + 1) % 3] - and a[(j + 2) % 3] == b[(j + 2) % 3] - ): - m += 1 - if m == 2: - x = [[x[0][0], x[1][1]]] + rest - else: - start.append(x[0]) - x = [x[1]] + rest - start += x - return start + visualization_data[str(grid_idx)] = viz_data + raw_tensor_data.update(raw_data) + + # Get the kernel source code + + return { + "visualization_data": visualization_data, + "raw_tensor_data": raw_tensor_data, + "failures": failures, + "kernel_src": kernel_src, + } + + +def serialize_for_json(obj): + if isinstance(obj, np.ndarray): + return obj.tolist() + elif isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + elif isinstance(obj, torch.Tensor): + return obj.cpu().numpy().tolist() + raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable") diff --git a/triton_viz/interface.py b/triton_viz/interface.py index 472f01d..e89e29a 100644 --- a/triton_viz/interface.py +++ b/triton_viz/interface.py @@ -1,99 +1,203 @@ -import gradio as gr -import triton_viz -import tempfile +import threading +from flask import Flask, render_template, jsonify, request from .analysis import analyze_records -from .tooltip import create_tooltip +from .draw import get_visualization_data +from .tooltip import get_tooltip_data import pandas as pd +import os +import torch +from flask_cloudflared import _run_cloudflared +import requests +import time +app = Flask( + __name__, + template_folder=os.path.join(os.path.dirname(__file__), "templates"), + static_folder=os.path.join(os.path.dirname(__file__), "static"), +) -def launch(share=True): - cache = {} +# Global variables to store the data +global_data = None +raw_tensor_data = None +precomputed_c_values = {} +current_fullscreen_op = None + + +def precompute_c_values(op_data): + input_data = op_data["input_data"] + other_data = op_data["other_data"] + rows, inner_dim = input_data.shape + cols = other_data.shape[1] + + precomputed = {} + for i in range(rows): + for j in range(cols): + precomputed[(i, j)] = [0] * (inner_dim + 1) + for k in range(1, inner_dim + 1): + precomputed[(i, j)][k] = torch.dot( + input_data[i, :k], other_data[:k, j] + ).item() + + return precomputed + + +def update_global_data(): + global global_data, raw_tensor_data, precomputed_c_values analysis_data = analyze_records() - program_records, tt, failures = triton_viz.collect_grid() - m = [0, 0, 0] - size = [0, 0] - for k in program_records.keys(): - m[0] = max(k[0] + 1, m[0]) - m[1] = max(k[1] + 1, m[1]) - m[2] = max(k[2] + 1, m[2]) - w, h = triton_viz.draw_record(program_records[(0, 0, 0)], tt, "tmp.svg") - size[0] = w - size[1] = h - height = 600 * size[1] / size[0] - with gr.Blocks( - css=".gradio-container button {overflow: auto} img.with-caption {height: %fpx !important; } .thumbnails { display: none; } " - % height - ) as demo: - with gr.Row(): - with gr.Column(scale=3, min_width=500): - img = gr.Gallery( - height=500, - min_width=500, - show_label=False, - selected_index=0, - preview=True, - object_fit="cover", - ) - with gr.Column(scale=1): - s1 = gr.Slider(0, m[0] - 1, value=0, step=1, label="Program Id 0") - s2 = gr.Slider(0, m[1] - 1, value=0, step=1, label="Program Id 1") - s3 = gr.Slider(0, m[2] - 1, value=0, step=1, label="Program Id 2") - b1 = gr.Button("Precompute") - gr.Markdown("## Analysis") - df = pd.DataFrame(analysis_data, columns=["Metric", "Value"]) - analysis_with_tooltip = create_tooltip(df) - gr.HTML(analysis_with_tooltip) - if failures: - gr.Markdown( - show_label=False, - value="## Invalid memory access in " - + "\n * " - + "\n* ".join(list(map(str, failures.keys()))), - ) - - def cache_block(idx): - name = tempfile.NamedTemporaryFile(suffix=".svg") - w, h = triton_viz.draw_record(program_records[idx], tt, name.name) - size[0] = w - size[1] = h - cache[idx] = (name, len(cache)) - - def update(inp): - a = inp[s1] - b = inp[s2] - c = inp[s3] - idx = (a, b, c) - - if idx not in cache: - cache_block(idx) - return gr.Gallery( - value=[(cache[k][0].name, str(k)) for k in cache.keys()], - selected_index=cache[idx][1], - height=700, - ), gr.Slider() - # * size[1]/size[0] - return gr.Gallery(selected_index=cache[idx][1]), gr.Slider() - - def precompute(inp): - a = inp[s1] - b = inp[s2] - c = inp[s3] - idx = (a, b, c) - for i in range(m[0]): - for j in range(m[1]): - for k in range(m[2]): - if (i, j, k) not in cache: - cache_block((i, j, k)) - return gr.Gallery( - value=[(cache[k][0].name, str(k)) for k in cache.keys()], - selected_index=cache[idx][1], - ) - - s1.change(update, inputs={s1, s2, s3}, outputs=[img, b1], show_progress=False) - s2.change(update, inputs={s1, s2, s3}, outputs=[img, b1], show_progress=False) - s3.change(update, inputs={s1, s2, s3}, outputs=[img, b1], show_progress=False) - b1.click(precompute, inputs={s1, s2, s3}, outputs=img, show_progress=True) - demo.load(update, inputs={s1, s2, s3}, outputs=[img, b1]) - - demo.launch(share=share, debug=False, height=800, quiet=True, show_api=False) - return failures + viz_data = get_visualization_data() + global_data = { + "ops": { + "visualization_data": viz_data["visualization_data"], + "failures": viz_data["failures"], + "kernel_src": viz_data["kernel_src"], + } + } + raw_tensor_data = viz_data["raw_tensor_data"] + + # Precompute C values for each Dot operation + precomputed_c_values = {} + for uuid, op_data in raw_tensor_data.items(): + if "input_data" in op_data and "other_data" in op_data: + precomputed_c_values[uuid] = precompute_c_values(op_data) + + df = pd.DataFrame(analysis_data, columns=["Metric", "Value"]) + analysis_with_tooltip = get_tooltip_data(df) + global_data["analysis"] = analysis_with_tooltip + + +@app.route("/") +def index(): + return render_template("index.html") + + +@app.route("/api/data") +def get_data(): + global global_data + if global_data is None: + update_global_data() + return jsonify(global_data) + + +@app.route("/api/update_data") +def update_data(): + update_global_data() + return jsonify({"status": "Data updated successfully"}) + + +@app.route("/api/setop", methods=["POST"]) +def set_current_op(): + global current_fullscreen_op + data = request.json + current_fullscreen_op = data.get("uuid") + return jsonify( + {"status": "Current op set successfully", "uuid": current_fullscreen_op} + ) + + +@app.route("/api/getValue", methods=["POST"]) +def get_value(): + global raw_tensor_data, precomputed_c_values, current_fullscreen_op + print(current_fullscreen_op) + data = request.json + uuid = data.get("uuid") + matrix_name = data.get("matrixName") + row = data.get("row") + col = data.get("col") + + if uuid not in raw_tensor_data: + return jsonify({"error": "Operation not found"}), 404 + + op_data = raw_tensor_data[uuid] + + if matrix_name == "A": + value = ( + op_data["input_data"][row, col].item() if "input_data" in op_data else None + ) + return jsonify({"value": value}) + elif matrix_name == "B": + value = ( + op_data["other_data"][row, col].item() if "other_data" in op_data else None + ) + return jsonify({"value": value}) + elif matrix_name == "C": + current_step = data.get("currentStep", 0) + + if uuid not in precomputed_c_values: + return jsonify({"error": "Precomputed values not found"}), 404 + + precomputed = precomputed_c_values[uuid] + current_value = precomputed[(row, col)][current_step] + + return jsonify( + { + "value": current_value, + } + ) + else: + return jsonify({"error": "Invalid matrix name"}), 400 + + +@app.route("/api/getLoadValue", methods=["POST"]) +def get_load_value(): + global raw_tensor_data, current_fullscreen_op + + data = request.json + uuid = data.get("uuid") + x = data.get("x") + y = data.get("y") + z = data.get("z") + print(x, y, z) + if uuid is None or uuid not in raw_tensor_data: + return jsonify({"error": "Operation not found"}), 404 + + op_data = raw_tensor_data[uuid] + + if "global_tensor" in op_data and ( + x is not None and y is not None and z is not None + ): + try: + value = 0.0 + if op_data["dims"] == 3: + value = op_data["global_tensor"][x, y, z].item() + elif op_data["dims"] == 2: + value = op_data["global_tensor"][x, y].item() + elif op_data["dims"] == 1: + value = op_data["global_tensor"][x].item() + + return jsonify({"value": value}) + except IndexError: + return jsonify({"error": "Coordinates out of bounds"}), 200 + else: + return jsonify({"error": "Global tensor data not found"}), 200 + + +def run_flask_with_cloudflared(): + cloudflared_port = 8000 # You can change this port if needed + tunnel_url = _run_cloudflared(cloudflared_port, 8001) # not too important + print(f"Cloudflare tunnel URL: {tunnel_url}") + app.run(port=cloudflared_port) + + +def launch(share=True): + print("Launching Triton viz tool") + if share: + flask_thread = threading.Thread(target=run_flask_with_cloudflared) + flask_thread.start() + + # Wait for the server to start + time.sleep(5) + + # Try to get the tunnel URL by making a request to the local server + try: + response = requests.get("http://localhost:8000") + print(f"Your app is now available at: {response.url}") + except requests.exceptions.RequestException: + print("Please wait for URL:") + else: + app.run(port=5001) + + +# This function can be called to stop the Flask server if needed +def stop_server(flask_thread): + # Implement a way to stop the Flask server + pass diff --git a/triton_viz/interpreter.py b/triton_viz/interpreter.py index a572e1e..953d56b 100644 --- a/triton_viz/interpreter.py +++ b/triton_viz/interpreter.py @@ -44,6 +44,19 @@ def _unpatch_lang(): class RecordBuilder: + def to_dict(self): + """Convert the recorded data to a dictionary format for JSON serialization.""" + return { + "launches": [ + { + "grid": launch.grid, + "tensors": [tensor.__dict__ for tensor in launch.tensors], + "records": [record.__dict__ for record in launch.records], + } + for launch in self._launches + ] + } + def __init__(self) -> None: self.reset() @@ -136,7 +149,9 @@ def _check_storage_contiguous(tensor): def _grid_executor_call(self, *args_dev, **kwargs): # Removes reserved keywords from kwargs + # kwargs has src_map and src src is the raw straight spurce code, src_map is a mapping between source code line and an op kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS} + if kwargs.pop("warmup", False): return args_hst, kwargs_hst = self._init_args_hst(args_dev, kwargs) @@ -162,6 +177,7 @@ def _grid_executor_call(self, *args_dev, **kwargs): arg.stride(), arg.shape, arg.element_size(), + arg, ) ) call_args[name] = ret @@ -291,15 +307,31 @@ def wrapper(lhs, rhs, op): def _create_dot(fn): @wraps(fn) - def wrapper(a, b, d, allow_tf32, maxNumImpreciseAcc): - ret = fn(a, b, d, allow_tf32, maxNumImpreciseAcc) + def wrapper(a, b, c, allow_tf32, max_num_imprecise_acc): dot_record = Dot( - input_shape=(a.data.shape, b.data.shape), - other_shape=d.data.shape, - output_shape=ret.data.shape, + input_shape=a.data.shape, + other_shape=b.data.shape, + output_shape=c.data.shape, + input_data=a.data.tolist(), + other_data=b.data.tolist(), ) record_builder.add_record(dot_record) - return ret + + def capture_intermediate(row, col, result): + dot_record.update_intermediate(row, col, float(result)) + + # Modify the original function to call capture_intermediate at each step + def modified_fn(a, b, c, allow_tf32, max_num_imprecise_acc): + for i in range(a.data.shape[0]): + for j in range(b.data.shape[1]): + A_row = a.data[i, :] + B_column = b.data[:, j] + result = np.dot(A_row, B_column) + capture_intermediate(i, j, result) + c.data[i, j] = result + + modified_fn(a, b, c, allow_tf32, max_num_imprecise_acc) + return c return wrapper @@ -377,3 +409,8 @@ def patch(): interpreter_builder.binary_op = old_binary_op interpreter_builder.create_dot = old_create_dot interpreter_builder.create_masked_store = old_create_masked_store + + +def get_recorded_data(): + """Return the recorded data in a format suitable for JSON serialization.""" + return record_builder.to_dict() diff --git a/triton_viz/static/gridblock.js b/triton_viz/static/gridblock.js new file mode 100644 index 0000000..6c999cc --- /dev/null +++ b/triton_viz/static/gridblock.js @@ -0,0 +1,241 @@ +import { createMatMulVisualization } from './matmul.js'; +import { createLoadVisualization } from './load.js'; +import { createStoreVisualization } from './store.js'; + +export class GridBlock { + constructor(x, y, width, height, gridX, gridY, gridZ, blockData, onClose, containerElement, canvas, drawFunction) { + this.rect = { x, y, width, height }; + this.gridPosition = { x: gridX, y: gridY, z: gridZ }; + this.blockData = blockData; + this.isHovered = false; + this.visualizationContainer = null; + this.visualizationCleanupFunction = null; + this.onClose = onClose; + this.containerElement = containerElement; + this.canvas = canvas; + this.drawFunction = drawFunction; + this.isDetailedViewVisible = false; + this.contentArea = null; + } + + + draw(ctx) { + // Draw background + ctx.fillStyle = this.isHovered ? '#4a4a4a' : '#323232'; + ctx.fillRect(this.rect.x, this.rect.y, this.rect.width, this.rect.height); + + // Draw border + ctx.strokeStyle = '#000000'; + ctx.lineWidth = 1; + ctx.strokeRect(this.rect.x, this.rect.y, this.rect.width, this.rect.height); + + // Draw grid position + ctx.fillStyle = '#c8c8c8'; + ctx.font = '12px Arial'; + ctx.textAlign = 'center'; + ctx.textBaseline = 'top'; + const posText = `${this.gridPosition.x},${this.gridPosition.y},${this.gridPosition.z}`; + ctx.fillText(posText, this.rect.x + this.rect.width / 2, this.rect.y + 2); + + // Draw operation types + ctx.font = '10px Arial'; + ctx.textAlign = 'left'; + ctx.textBaseline = 'top'; + this.blockData.forEach((op, index) => { + ctx.fillText(op.type, this.rect.x + 5, this.rect.y + 20 + index * 15); + }); + } + + isPointInside(x, y) { + return x >= this.rect.x && x <= this.rect.x + this.rect.width && + y >= this.rect.y && y <= this.rect.y + this.rect.height; + } + + handleMouseMove(x, y) { + const wasHovered = this.isHovered; + this.isHovered = this.isPointInside(x, y); + return wasHovered !== this.isHovered; // Return true if hover state changed + } + + + showDetailedView() { + if (this.isDetailedViewVisible) return; + + this.visualizationContainer = this.createVisualizationContainer(); + document.body.appendChild(this.visualizationContainer); + + const title = this.createTitle(); + const headerBar = this.createHeaderBar(); + this.contentArea = this.createContentArea(); + + this.visualizationContainer.appendChild(title); + this.visualizationContainer.appendChild(headerBar); + this.visualizationContainer.appendChild(this.contentArea); + + const closeButton = this.createCloseButton(); + this.visualizationContainer.appendChild(closeButton); + + this.isDetailedViewVisible = true; + this.canvas.style.display = 'none'; + this.containerElement.style.display = 'block'; + + // Display the first operation visualization after the content area is added to the DOM + if (this.blockData.length > 0) { + this.displayOpVisualization(this.blockData[0]); + } + } + + createVisualizationContainer() { + const container = document.createElement('div'); + Object.assign(container.style, { + position: 'fixed', + top: '0', + left: '0', + width: '100vw', + height: '100vh', + backgroundColor: '#1e1e28', + zIndex: '1000', + display: 'flex', + flexDirection: 'column', + color: '#ffffff' + }); + return container; + } + + createTitle() { + const title = document.createElement('h2'); + title.textContent = `Operations at (${this.gridPosition.x}, ${this.gridPosition.y}, ${this.gridPosition.z})`; + title.style.textAlign = 'center'; + title.style.margin = '10px 0'; + return title; + } + + createHeaderBar() { + const headerBar = document.createElement('div'); + Object.assign(headerBar.style, { + display: 'flex', + flexDirection: 'row', + backgroundColor: '#333', + padding: '5px', + overflowX: 'auto' + }); + + let currentSelectedTab = null; + this.blockData.forEach((op, index) => { + const opTab = this.createOperationTab(op, index === 0); + opTab.addEventListener('click', () => this.handleTabClick(opTab, op, currentSelectedTab)); + headerBar.appendChild(opTab); + if (index === 0) currentSelectedTab = opTab; + }); + + return headerBar; + } + + createOperationTab(op, isFirst) { + const opTab = document.createElement('button'); + opTab.textContent = op.type; + Object.assign(opTab.style, { + flex: '0 0 auto', + marginRight: '5px', + backgroundColor: isFirst ? '#555' : '#333', + color: '#fff', + border: 'none', + padding: '10px', + cursor: 'pointer' + }); + return opTab; + } + + handleTabClick(clickedTab, op, currentSelectedTab) { + if (currentSelectedTab) currentSelectedTab.style.backgroundColor = '#333'; + clickedTab.style.backgroundColor = '#555'; + this.displayOpVisualization(op); + } + + createContentArea() { + const contentArea = document.createElement('div'); + Object.assign(contentArea.style, { + flex: '1', + padding: '10px', + overflow: 'hidden', + position: 'relative' + }); + + if (this.blockData.length === 0) { + const noDataMsg = document.createElement('p'); + noDataMsg.textContent = 'No operation data available'; + noDataMsg.style.textAlign = 'center'; + contentArea.appendChild(noDataMsg); + } + + return contentArea; + } + displayOpVisualization(op) { + if (!this.contentArea) { + console.error('Content area is not initialized'); + return; + } + + if (this.visualizationCleanupFunction) { + this.visualizationCleanupFunction(); + this.visualizationCleanupFunction = null; + } + + this.contentArea.innerHTML = ''; + + switch (op.type) { + case 'Dot': + this.visualizationCleanupFunction = createMatMulVisualization(this.contentArea, op); + break; + case 'Load': + this.visualizationCleanupFunction = createLoadVisualization(this.contentArea, op); + break; + case 'Store': + this.visualizationCleanupFunction = createStoreVisualization(this.contentArea, op); + break; + default: + const unsupportedMsg = document.createElement('p'); + unsupportedMsg.textContent = `Visualization not supported for ${op.type} operation`; + unsupportedMsg.style.textAlign = 'center'; + this.contentArea.appendChild(unsupportedMsg); + } + } + + + createCloseButton() { + const closeButton = document.createElement('button'); + closeButton.textContent = 'Close'; + Object.assign(closeButton.style, { + position: 'fixed', + top: '10px', + right: '10px', + zIndex: '1001' + }); + closeButton.addEventListener('click', () => this.hideDetailedView()); + return closeButton; + } + + hideDetailedView() { + if (!this.isDetailedViewVisible) return; + + if (this.visualizationCleanupFunction) { + this.visualizationCleanupFunction(); + this.visualizationCleanupFunction = null; + } + + if (this.visualizationContainer) { + document.body.removeChild(this.visualizationContainer); + this.visualizationContainer = null; + } + + this.isDetailedViewVisible = false; + + if (this.onClose) { + this.onClose(); + } + + this.canvas.style.display = 'block'; + this.containerElement.style.display = 'none'; + this.drawFunction(); + } +} diff --git a/triton_viz/static/images/accelerator-memory-hierarchy.png b/triton_viz/static/images/accelerator-memory-hierarchy.png new file mode 100644 index 0000000..3ca48d1 Binary files /dev/null and b/triton_viz/static/images/accelerator-memory-hierarchy.png differ diff --git a/triton_viz/static/images/tensor-slicing-diagram.png b/triton_viz/static/images/tensor-slicing-diagram.png new file mode 100644 index 0000000..d6a300f Binary files /dev/null and b/triton_viz/static/images/tensor-slicing-diagram.png differ diff --git a/triton_viz/static/infoPopup.js b/triton_viz/static/infoPopup.js new file mode 100644 index 0000000..0693a77 --- /dev/null +++ b/triton_viz/static/infoPopup.js @@ -0,0 +1,211 @@ +const infoContent = [ + { + title: "Visualization Tool Overview", + content: ` +

This tool allows you to explore and visualize tensor operations.

+ + +

Need a bit of a tutorial?

+

If you're new to Triton or need to brush up on your skills, check out Triton Puzzles. It's a collection of Triton problems designed to get you up to speed:

+

Triton Puzzles on GitHub

+

These puzzles provide hands-on experience with Triton concepts, which can help you better understand the visualizations in this tool.

+ ` + }, + { + title: "Using the Sliders", + content: ` +

The sliders on the right side of the screen allow you to filter the grid blocks:

+ +

Set a slider to -1 to show all blocks along that dimension.

+ ` + }, + { + title: "Detailed View", + content: ` +

Clicking on a grid block will open a detailed view of the operations for that block.

+

In the detailed view, you can:

+ + + ` + }, + { + title: "Tensor Slicing in Triton", + content: ` +

Triton operates on slices of tensors rather than entire tensors. This is a key concept in understanding how Triton kernels process data efficiently.

+ +

Global Tensor vs Slice Tensor

+ + +

Loading Tensor Slices

+

Inside Triton kernels, tl.load is used to load slices of the global tensor. For example:

+
a = tl.load(a_ptrs, mask=offs_k[None, :])
+

This operation loads only the necessary slice of data, optimizing memory access and computation.

+ +

Key Points

+ + + Tensor Slicing Diagram +

Diagram showing how a global PyTorch tensor is sliced for processing in Triton kernels.

+ ` + }, + { + title: "Memory Hierarchy in AI Accelerators and Triton Operations", + content: ` +

Understanding the memory hierarchy in AI accelerators is crucial for optimizing Triton kernels. While this example uses a GPU, it's important to note that Triton is designed for various AI accelerators, not just GPUs.

+ +

Memory Types in AI Accelerators

+ + +

Memory Access in Triton

+

Triton provides mechanisms to efficiently move data between these memory types across various accelerators:

+ + +

* It's important to note that Triton doesn't always transfer tensor slices to shared memory. This operation is performed only when it makes sense for performance optimization.

+ +

Key Points

+ + + AI Accelerator Memory Hierarchy +

Diagram showing the memory hierarchy of an RTX 4090 GPU as an example of AI accelerator architecture. Global memory (magenta) and on-chip memory including shared memory* (red) are highlighted. Note that this structure is similar across many AI accelerators.

+ ` + } +]; + +let currentPage = 0; + +export function createInfoPopup() { + const infoPopup = document.createElement('div'); + infoPopup.style.display = 'none'; + infoPopup.style.position = 'fixed'; + infoPopup.style.top = '5%'; + infoPopup.style.left = '5%'; + infoPopup.style.width = '85%'; + infoPopup.style.height = '75%'; + infoPopup.style.backgroundColor = '#1a1a1a'; + infoPopup.style.color = '#fff'; + infoPopup.style.padding = '30px'; + infoPopup.style.borderRadius = '15px'; + infoPopup.style.boxShadow = '0 4px 6px rgba(0, 0, 0, 0.5)'; + infoPopup.style.zIndex = '1000'; + infoPopup.style.overflow = 'auto'; + + const closeButton = createButton('×', () => { infoPopup.style.display = 'none'; }); + closeButton.style.position = 'absolute'; + closeButton.style.top = '20px'; + closeButton.style.right = '20px'; + closeButton.style.fontSize = '36px'; + closeButton.style.color = '#fff'; + + const content = document.createElement('div'); + content.id = 'info-content'; + content.style.fontSize = '18px'; + content.style.lineHeight = '1.6'; + + const navigation = document.createElement('div'); + navigation.style.display = 'flex'; + navigation.style.justifyContent = 'space-between'; + navigation.style.marginTop = '30px'; + + const prevButton = createNavButton('← Previous', () => navigatePage(-1)); + const nextButton = createNavButton('Next →', () => navigatePage(1)); + + navigation.appendChild(prevButton); + navigation.appendChild(nextButton); + + infoPopup.appendChild(closeButton); + infoPopup.appendChild(content); + infoPopup.appendChild(navigation); + + document.body.appendChild(infoPopup); + + updateContent(); + + return infoPopup; +} + +function createButton(text, onClick) { + const button = document.createElement('button'); + button.innerHTML = text; + button.style.border = 'none'; + button.style.background = 'none'; + button.style.cursor = 'pointer'; + button.style.color = '#fff'; + button.onclick = onClick; + return button; +} + +function createNavButton(text, onClick) { + const button = createButton(text, onClick); + button.style.fontSize = '20px'; + button.style.padding = '15px 25px'; + button.style.backgroundColor = '#333'; + button.style.borderRadius = '8px'; + button.style.transition = 'background-color 0.3s'; + button.onmouseover = () => { button.style.backgroundColor = '#444'; }; + button.onmouseout = () => { button.style.backgroundColor = '#333'; }; + return button; +} + +function navigatePage(direction) { + currentPage += direction; + if (currentPage < 0) currentPage = infoContent.length - 1; + if (currentPage >= infoContent.length) currentPage = 0; + updateContent(); +} + +function updateContent() { + const contentDiv = document.getElementById('info-content'); + const pageContent = infoContent[currentPage]; + contentDiv.innerHTML = ` +

${pageContent.title}

+ ${pageContent.content} + `; + + // Ensure all text in the content is white and styled appropriately + contentDiv.querySelectorAll('p, li').forEach(element => { + element.style.color = '#fff'; + element.style.marginBottom = '15px'; + }); + + contentDiv.querySelectorAll('ul').forEach(element => { + element.style.paddingLeft = '30px'; + }); +} + +export function showInfoPopup(infoPopup) { + infoPopup.style.display = 'block'; +} \ No newline at end of file diff --git a/triton_viz/static/load.js b/triton_viz/static/load.js new file mode 100644 index 0000000..8574cf6 --- /dev/null +++ b/triton_viz/static/load.js @@ -0,0 +1,221 @@ +import * as THREE from 'https://cdn.jsdelivr.net/npm/three@0.155.0/build/three.module.js'; +import { + setupScene, + setupGeometries, + createTensor, + calculateTensorSize, + updateCubeColor, + setupCamera, + setupEventListeners, + cameraControls +} from './load_utils.js'; + +export function createLoadVisualization(containerElement, op) { + + console.log(op.uuid); + fetch('/api/setop', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ uuid: op.uuid }), + }) + .then(response => response.json()) + .then(data => console.log('Set current op:', data)) + .catch((error) => console.error('Error:', error)); + + let currentStep = 0; + let frame = 0; + let isPaused = false; + + const sideMenu = createSideMenu(containerElement); + let hoveredCube = null; + + const COLOR_GLOBAL = new THREE.Color(0.2, 0.2, 0.2); // Dark Gray + const COLOR_SLICE = new THREE.Color(0.0, 0.7, 1.0); // Cyan (starting color for global slice) + const COLOR_LEFT_SLICE = new THREE.Color(1.0, 0.0, 1.0); // Magenta (starting color for left slice) + const COLOR_LOADED = new THREE.Color(1.0, 0.8, 0.0); // Gold (final color for both slices) + const COLOR_BACKGROUND = new THREE.Color(0.0, 0.0, 0.0); // Black + + const { scene, camera, renderer } = setupScene(containerElement, COLOR_BACKGROUND); + const { cubeGeometry, edgesGeometry, lineMaterial } = setupGeometries(); + + const globalTensor = createTensor(op.global_shape, op.global_coords, COLOR_GLOBAL, 'Global', cubeGeometry, edgesGeometry, lineMaterial); + const sliceTensor = createTensor(op.slice_shape, op.slice_coords, COLOR_LEFT_SLICE, 'Slice', cubeGeometry, edgesGeometry, lineMaterial); + + // Position slice tensor + const globalSize = calculateTensorSize(op.global_shape); + sliceTensor.position.set(globalSize.x + 5, 0, 0); // Adjusted tensor spacing + + scene.add(globalTensor); + scene.add(sliceTensor); + + addLabels(scene, globalTensor, sliceTensor); + setupCamera(scene, camera); + + const totalFrames = op.global_coords.length * 2 + 30; + + const raycaster = new THREE.Raycaster(); + const mouse = new THREE.Vector2(); + + const onKeyDown = cameraControls(camera, new THREE.Euler(0, 0, 0, 'YXZ')); + setupEventListeners(containerElement, camera, renderer, onMouseMove, onKeyDown); + animate(); + + async function onMouseMove(event) { + mouse.x = (event.clientX / containerElement.clientWidth) * 2 - 1; + mouse.y = -(event.clientY / containerElement.clientHeight) * 2 + 1; + + raycaster.setFromCamera(mouse, camera); + + const allTensorChildren = [ + ...globalTensor.children, + ...sliceTensor.children + ]; + + const intersects = raycaster.intersectObjects(allTensorChildren, true); + + if (hoveredCube) { + hoveredCube.getObjectByName('hoverOutline').visible = false; + hoveredCube = null; + } + + if (intersects.length > 0) { + hoveredCube = intersects[0].object; + while (hoveredCube && !hoveredCube.tensorName) { + hoveredCube = hoveredCube.parent; + } + + if (hoveredCube) { + const hoverOutline = hoveredCube.getObjectByName('hoverOutline'); + if (hoverOutline) { + hoverOutline.visible = true; + } + + updateSideMenu(hoveredCube.tensorName, hoveredCube.tensor0, hoveredCube.tensor1, hoveredCube.tensor2, undefined); + + const res = await getElementValue(hoveredCube.tensorName, hoveredCube.tensor0, hoveredCube.tensor1, hoveredCube.tensor2); + + updateSideMenu(hoveredCube.tensorName, hoveredCube.tensor0, hoveredCube.tensor1, hoveredCube.tensor2, res.value); + + console.log(`Value: ${res.value}`); + } + } else { + updateSideMenu(null); + } + } + + function animate() { + requestAnimationFrame(animate); + + if (!isPaused && frame < totalFrames) { + const index = Math.floor(frame / 2); + const factor = (frame % 2) / 1.0; + + if (index < op.global_coords.length) { + const globalCoord = op.global_coords[index]; + const sliceCoord = op.slice_coords[index]; + + updateCubeColor(globalTensor, globalCoord, COLOR_GLOBAL, COLOR_SLICE, factor); + updateCubeColor(sliceTensor, sliceCoord, COLOR_LEFT_SLICE, COLOR_LOADED, factor); + + highlightCurrentOperation(globalTensor, globalCoord, sliceTensor, sliceCoord); + updateInfoPanel(globalCoord, sliceCoord, index); + } + + frame++; + } + + renderer.render(scene, camera); + } + + function highlightCurrentOperation(globalTensor, globalCoord, sliceTensor, sliceCoord) { + globalTensor.children.forEach(cube => cube.material.emissive.setHex(0x000000)); + sliceTensor.children.forEach(cube => cube.material.emissive.setHex(0x000000)); + + const globalCube = globalTensor.children.find(c => + c.tensor0 === globalCoord[0] && c.tensor1 === globalCoord[1] && c.tensor2 === globalCoord[2] + ); + const sliceCube = sliceTensor.children.find(c => + c.tensor0 === sliceCoord[0] && c.tensor1 === sliceCoord[1] && c.tensor2 === sliceCoord[2] + ); + + if (globalCube) globalCube.material.emissive.setHex(0x444444); + if (sliceCube) sliceCube.material.emissive.setHex(0x444444); + } + + async function getElementValue(tensorName, x, y, z) { + let uuid = op.uuid; + const response = await fetch('/api/getLoadValue', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ uuid, tensorName, x, y, z }), + }); + return await response.json(); + } + + function updateSideMenu(tensorName, x, y, z, value) { + if (!tensorName) { + sideMenu.innerHTML = ''; + return; + } + + let dims = tensorName === 'Global' ? op.global_shape : op.slice_shape; + sideMenu.innerHTML = ` +

${tensorName} Tensor

+

X: ${x + 1}

+

Y: ${y + 1}

+

Z: ${z + 1}

+

Dimensions: ${dims.join(' x ')}

+

Value: ${value !== undefined ? value : 'Loading...'}

+ `; + } + + function updateInfoPanel(globalCoord, sliceCoord, index) { + sideMenu.innerHTML = ` +

Current Operation

+

Global Coords: (${globalCoord.join(', ')})

+

Slice Coords: (${sliceCoord.join(', ')})

+

Progress: ${index + 1}/${op.global_coords.length}

+ `; + } + + function createSideMenu(container) { + const menu = document.createElement('div'); + menu.style.position = 'absolute'; + menu.style.top = '10px'; + menu.style.right = '10px'; + menu.style.width = '200px'; + menu.style.padding = '10px'; + menu.style.backgroundColor = 'rgba(0, 0, 0, 0.7)'; + menu.style.color = 'white'; + menu.style.fontFamily = 'Arial, sans-serif'; + menu.style.fontSize = '14px'; + menu.style.borderRadius = '5px'; + container.appendChild(menu); + return menu; + } + + function addLabels(scene, globalTensor, sliceTensor) { + addLabel(scene, "Global Tensor", globalTensor.position); + addLabel(scene, "Slice Tensor", sliceTensor.position); + } + + function addLabel(scene, text, position) { + const canvas = document.createElement('canvas'); + const context = canvas.getContext('2d'); + context.font = 'Bold 24px Arial'; + context.fillStyle = 'white'; + context.fillText(text, 0, 24); + + const texture = new THREE.CanvasTexture(canvas); + const material = new THREE.SpriteMaterial({ map: texture }); + const sprite = new THREE.Sprite(material); + sprite.position.set(position.x, position.y + 2, position.z); + sprite.scale.set(4, 2, 1); + scene.add(sprite); + } + +} diff --git a/triton_viz/static/load_utils.js b/triton_viz/static/load_utils.js new file mode 100644 index 0000000..2729d04 --- /dev/null +++ b/triton_viz/static/load_utils.js @@ -0,0 +1,200 @@ +import * as THREE from 'https://cdn.jsdelivr.net/npm/three@0.155.0/build/three.module.js'; + +export const CUBE_SIZE = 0.2; +export const GAP = 0.05; +export const COLOR_HOVER = new THREE.Color(1.0, 1.0, 0.0); // Yellow for hover effect +export const COLOR_EDGE = new THREE.Color(0.5, 0.5, 0.5); // Gray (for cube edges) + +const COLOR_SLICE = new THREE.Color(0.0, 0.7, 1.0); // Cyan (starting color for global slice) + +export function setupScene(container, backgroundColor = 0x000000) { + const scene = new THREE.Scene(); + scene.background = new THREE.Color(backgroundColor); + const camera = new THREE.PerspectiveCamera(45, container.clientWidth / container.clientHeight, 0.1, 1000); + const renderer = new THREE.WebGLRenderer({ antialias: true }); + renderer.setSize(container.clientWidth, container.clientHeight); + container.appendChild(renderer.domElement); + + const ambientLight = new THREE.AmbientLight(0xffffff, 0.5); + scene.add(ambientLight); + const directionalLight = new THREE.DirectionalLight(0xffffff, 0.5); + directionalLight.position.set(10, 10, 10); + scene.add(directionalLight); + + return { scene, camera, renderer }; +} + +export function setupGeometries() { + const cubeGeometry = new THREE.BoxGeometry(CUBE_SIZE, CUBE_SIZE, CUBE_SIZE); + const edgesGeometry = new THREE.EdgesGeometry(cubeGeometry); + const lineMaterial = new THREE.LineBasicMaterial({ color: COLOR_EDGE }); + return { cubeGeometry, edgesGeometry, lineMaterial }; +} + +export function createCube(color, tensorName, x, y, z, cubeGeometry, edgesGeometry, lineMaterial) { + const cubeMaterial = new THREE.MeshPhongMaterial({ color: color }); + const cube = new THREE.Mesh(cubeGeometry, cubeMaterial); + const edges = new THREE.LineSegments(edgesGeometry, lineMaterial); + cube.add(edges); + + const hoverGeometry = new THREE.BoxGeometry(CUBE_SIZE * 1.05, CUBE_SIZE * 1.05, CUBE_SIZE * 1.05); + const hoverEdgesGeometry = new THREE.EdgesGeometry(hoverGeometry); + const hoverOutline = new THREE.LineSegments(hoverEdgesGeometry, new THREE.LineBasicMaterial({ color: COLOR_HOVER })); + hoverOutline.visible = false; + hoverOutline.name = 'hoverOutline'; + cube.add(hoverOutline); + + // Add custom properties to store tensor coordinates + cube.userData.tensor0 = z; + cube.userData.tensor1 = y; + cube.userData.tensor2 = x; + cube.userData.tensorName = tensorName; + + cube.name = `${tensorName}_cube_${x}_${y}_${z}`; + + return cube; +} + +export function createTensor(shape, coords, color, tensorName, cubeGeometry, edgesGeometry, lineMaterial) { + console.log(`Creating ${tensorName} tensor:`, shape, coords); + const tensor = new THREE.Group(); + let [width, height, depth] = shape; + depth = depth || 1; + height = height || 1; + + if (tensorName === 'Global') { + console.log(`Creating global tensor with dimensions: ${width}x${height}x${depth}`); + for (let z = 0; z < depth; z++) { + for (let y = 0; y < height; y++) { + for (let x = 0; x < width; x++) { + const cube = createCube(color, tensorName, x, y, z, cubeGeometry, edgesGeometry, lineMaterial); + cube.position.set( + x * (CUBE_SIZE + GAP), + -y * (CUBE_SIZE + GAP), + -z * (CUBE_SIZE + GAP) + ); + tensor.add(cube); + } + } + } + + console.log(`Highlighting ${coords.length} coordinates in global tensor`); + coords.forEach(([x, y, z]) => { + const cube = tensor.children.find(c => + c.userData.tensor0 === x && c.userData.tensor1 === y && c.userData.tensor2 === z + ); + if (cube) { + cube.material.color.set(COLOR_SLICE); + console.log(`Highlighted cube at (${x}, ${y}, ${z})`); + } else { + console.warn(`Could not find cube at (${x}, ${y}, ${z})`); + } + }); + } else { + console.log(`Creating slice tensor with ${coords.length} coordinates`); + coords.forEach(([x, y, z]) => { + const cube = createCube(color, tensorName, x, y, z, cubeGeometry, edgesGeometry, lineMaterial); + cube.position.set( + x * (CUBE_SIZE + GAP), + -y * (CUBE_SIZE + GAP), + -z * (CUBE_SIZE + GAP) + ); + tensor.add(cube); + }); + } + + console.log(`Created ${tensorName} tensor with ${tensor.children.length} cubes`); + return tensor; +} + +export function calculateTensorSize(shape) { + const [width, height, depth] = shape; + return new THREE.Vector3( + width * (CUBE_SIZE + GAP), + height * (CUBE_SIZE + GAP), + depth * (CUBE_SIZE + GAP) + ); +} + +export function interpolateColor(color1, color2, factor) { + return new THREE.Color().lerpColors(color1, color2, factor); +} + +export function updateCubeColor(tensor, coord, startColor, endColor, factor) { + const cube = tensor.children.find(c => + c.tensor0 === coord[0] && c.tensor1 === coord[1] && c.tensor2 === coord[2] + ); + if (cube) { + cube.material.color.copy(interpolateColor(startColor, endColor, factor)); + } +} + +export function setupCamera(scene, camera) { + const box = new THREE.Box3().setFromObject(scene); + const center = box.getCenter(new THREE.Vector3()); + const size = box.getSize(new THREE.Vector3()); + const maxDim = Math.max(size.x, size.y, size.z); + const fov = camera.fov * (Math.PI / 180); + let cameraZ = Math.abs(maxDim / 2 / Math.tan(fov / 2)); + cameraZ *= 1.5; + + camera.position.set(center.x, center.y, center.z + cameraZ); + camera.lookAt(center); + + return { center, cameraZ }; +} + +export function setupEventListeners(containerElement, camera, renderer, onMouseMove, onKeyDown) { + window.addEventListener('resize', () => { + camera.aspect = containerElement.clientWidth / containerElement.clientHeight; + camera.updateProjectionMatrix(); + renderer.setSize(containerElement.clientWidth, containerElement.clientHeight); + }); + containerElement.addEventListener('mousemove', onMouseMove); + window.addEventListener('keydown', onKeyDown); +} + +export function cameraControls(camera, cameraRotation) { + const PAN_SPEED = 0.1; + const TILT_SPEED = 0.02; + const ZOOM_SPEED = 0.5; + + return function onKeyDown(event) { + switch (event.key.toLowerCase()) { + case 'w': + camera.position.y += PAN_SPEED; + break; + case 's': + camera.position.y -= PAN_SPEED; + break; + case 'a': + camera.position.x -= PAN_SPEED; + break; + case 'd': + camera.position.x += PAN_SPEED; + break; + case 'arrowup': + cameraRotation.x -= TILT_SPEED; + break; + case 'arrowdown': + cameraRotation.x += TILT_SPEED; + break; + case 'arrowleft': + cameraRotation.y -= TILT_SPEED; + break; + case 'arrowright': + cameraRotation.y += TILT_SPEED; + break; + case 'o': + camera.position.z += ZOOM_SPEED; + break; + case 'p': + camera.position.z -= ZOOM_SPEED; + break; + case ' ': + break; + } + camera.setRotationFromEuler(cameraRotation); + camera.updateProjectionMatrix(); + }; +} diff --git a/triton_viz/static/matmul.js b/triton_viz/static/matmul.js new file mode 100644 index 0000000..d87d78b --- /dev/null +++ b/triton_viz/static/matmul.js @@ -0,0 +1,360 @@ +import * as THREE from 'https://cdn.jsdelivr.net/npm/three@0.155.0/build/three.module.js'; + +export function createMatMulVisualization(containerElement, op) { + const { input_shape, other_shape, output_shape } = op; + console.log(op.uuid) + fetch('/api/setop', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ uuid: op.uuid }), + }) + .then(response => response.json()) + .then(data => console.log('Set current op:', data)) + .catch((error) => console.error('Error:', error)); + let currentStep = 0; + const totalSteps = input_shape[1]; + let frame = 0; + + + const sideMenu = document.createElement('div'); + sideMenu.style.position = 'absolute'; + sideMenu.style.top = '10px'; + sideMenu.style.right = '10px'; + sideMenu.style.width = '200px'; + sideMenu.style.padding = '10px'; + sideMenu.style.backgroundColor = 'rgba(0, 0, 0, 0.7)'; + sideMenu.style.color = 'white'; + sideMenu.style.fontFamily = 'Arial, sans-serif'; + sideMenu.style.fontSize = '14px'; + sideMenu.style.borderRadius = '5px'; + containerElement.appendChild(sideMenu); + let hoveredCube = null; + + + + + async function getElementValue( matrixName, row, col) { + let uuid = op.uuid; + const response = await fetch('/api/getValue', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ uuid, matrixName, row, col, currentStep }), + }); + return await response.json(); + } + + + function updateSideMenu(matrix, x, y) { + if (!matrix) { + sideMenu.innerHTML = ''; + return; + } + + let matrixName; + let dims; + if (matrix === matrixA) { + matrixName = 'A'; + dims = input_shape; + } else if (matrix === matrixB) { + matrixName = 'B'; + dims = other_shape; + } else if (matrix === matrixC) { + matrixName = 'C'; + dims = output_shape; + } else { + sideMenu.innerHTML = ''; + return; + } + console.log(matrixName, "x:", (x + 1), "y:", (y + 1)); + sideMenu.innerHTML = ` +

Matrix ${matrixName}

+

Row: ${y + 1}

+

Column: ${x + 1}

+

Dimensions: ${dims[0]} x ${dims[1]}

+ `; + } + + const raycaster = new THREE.Raycaster(); + const mouse = new THREE.Vector2(); + + + async function onMouseMove(event) { + mouse.x = (event.clientX / containerElement.clientWidth) * 2 - 1; + mouse.y = -(event.clientY / containerElement.clientHeight) * 2 + 1; + + raycaster.setFromCamera(mouse, camera); + + const allMatrixChildren = [ + ...(matrixA ? matrixA.children : []), + ...(matrixB ? matrixB.children : []), + ...(matrixC ? matrixC.children : []) + ]; + + const intersects = raycaster.intersectObjects(allMatrixChildren, true); + + if (hoveredCube) { + hoveredCube.getObjectByName('hoverOutline').visible = false; + hoveredCube = null; + } + + if (intersects.length > 0) { + // Find the actual cube (parent of the intersected object) + hoveredCube = intersects[0].object; + while (hoveredCube && !hoveredCube.matrixName) { + hoveredCube = hoveredCube.parent; + } + + if (hoveredCube) { + const hoverOutline = hoveredCube.getObjectByName('hoverOutline'); + if (hoverOutline) { + hoverOutline.visible = true; + } + const res = await getElementValue(hoveredCube.matrixName, hoveredCube.matrixRow, hoveredCube.matrixCol); + // Log the matrix name, row, and column of the hovered cube + console.log( + // `Matrix: ${hoveredCube.matrixName}, ` + + // `Row: ${hoveredCube.matrixRow + 1}, ` + + // `Column: ${hoveredCube.matrixCol + 1}`+ + `Value: ${res.value}` + ); + + updateSideMenu(hoveredCube.matrixName, hoveredCube.matrixRow, hoveredCube.matrixCol); + } + } else { + updateSideMenu(null); + } + } + + + + + + const CUBE_SIZE = 0.2; + const GAP = 0.05; + + const COLOR_A = new THREE.Color(0.53, 0.81, 0.98); + const COLOR_B = new THREE.Color(1.0, 0.65, 0.0); + const COLOR_C = new THREE.Color(1.0, 1.0, 1.0); + const COLOR_HIGHLIGHT = new THREE.Color(0.0, 0.0, 1.0); + const COLOR_FILLED = new THREE.Color(0.0, 0.0, 1.0); + const COLOR_BACKGROUND = new THREE.Color(0.0, 0.0, 0.0); + const COLOR_EDGE = new THREE.Color(0.3, 0.3, 0.3); + const COLOR_HOVER = new THREE.Color(1.0, 1.0, 0.0); + + const scene = new THREE.Scene(); + scene.background = COLOR_BACKGROUND; + const camera = new THREE.PerspectiveCamera(45, containerElement.clientWidth / containerElement.clientHeight, 0.1, 1000); + + const renderer = new THREE.WebGLRenderer({ antialias: true }); + renderer.setSize(containerElement.clientWidth, containerElement.clientHeight); + containerElement.appendChild(renderer.domElement); + + const ambientLight = new THREE.AmbientLight(0xffffff, 0.5); + scene.add(ambientLight); + const directionalLight = new THREE.DirectionalLight(0xffffff, 0.5); + directionalLight.position.set(10, 10, 10); + scene.add(directionalLight); + + const cubeGeometry = new THREE.BoxGeometry(CUBE_SIZE, CUBE_SIZE, CUBE_SIZE); + const edgesGeometry = new THREE.EdgesGeometry(cubeGeometry); + const lineMaterial = new THREE.LineBasicMaterial({ color: COLOR_EDGE }); + + function createCube(color, matrixName, i, j) { + const cubeMaterial = new THREE.MeshPhongMaterial({ color: color }); + const cube = new THREE.Mesh(cubeGeometry, cubeMaterial); + const edges = new THREE.LineSegments(edgesGeometry, lineMaterial); + cube.add(edges); + + const hoverGeometry = new THREE.BoxGeometry(CUBE_SIZE * 1.05, CUBE_SIZE * 1.05, CUBE_SIZE * 1.05); + const hoverEdgesGeometry = new THREE.EdgesGeometry(hoverGeometry); + const hoverOutline = new THREE.LineSegments(hoverEdgesGeometry, new THREE.LineBasicMaterial({ color: COLOR_HOVER })); + hoverOutline.visible = false; + hoverOutline.name = 'hoverOutline'; + cube.add(hoverOutline); + + cube.name = `${matrixName}_cube_${i}_${j}`; + cube.matrixName = matrixName; + cube.matrixRow = i; + cube.matrixCol = j; + + return cube; + } + + function createMatrix(dimensions, position, color, matrixName) { + const matrix = new THREE.Group(); + matrix.userData.dimensions = dimensions; + for (let i = 0; i < dimensions[0]; i++) { + for (let j = 0; j < dimensions[1]; j++) { + const cube = createCube(color, matrixName, i, j); + cube.position.set( + position.x + j * (CUBE_SIZE + GAP), + position.y - i * (CUBE_SIZE + GAP), + position.z + ); + matrix.add(cube); + } + } + return matrix; + } + + const matrixA = createMatrix(input_shape, new THREE.Vector3(-10, 10, 0), COLOR_A, 'A'); + const matrixB = createMatrix(other_shape, new THREE.Vector3(0, 10, 0), COLOR_B, 'B'); + const matrixC = createMatrix(output_shape, new THREE.Vector3(-5, -4, 0), COLOR_C, 'C'); + + scene.add(matrixA); + scene.add(matrixB); + scene.add(matrixC); + + const center = new THREE.Vector3(); + const size = new THREE.Vector3(); + const box = new THREE.Box3().setFromObject(scene); + box.getCenter(center); + box.getSize(size); + const maxDim = Math.max(size.x, size.y, size.z); + const fov = camera.fov * (Math.PI / 180); + let cameraZ = Math.abs(maxDim / 2 / Math.tan(fov / 2)); + cameraZ *= 1.5; + + camera.position.set(center.x, center.y, center.z + cameraZ); + camera.lookAt(center); + + let isPaused = false; + + const totalFrames = input_shape[0] * other_shape[1]; + + function highlightCubes(matrix, indices, highlightColor) { + indices.forEach(([i, j]) => { + if (i >= 0 && i < matrix.userData.dimensions[0] && j >= 0 && j < matrix.userData.dimensions[1]) { + const index = i * matrix.userData.dimensions[1] + j; + if (index < matrix.children.length) { + matrix.children[index].material.color.copy(highlightColor); + + } + } + }); + } + + function resetColors() { + matrixA.children.forEach(cube => cube.material.color.copy(COLOR_A)); + matrixB.children.forEach(cube => cube.material.color.copy(COLOR_B)); + } + + function animate() { + requestAnimationFrame(animate); + + if (!isPaused && frame < totalFrames) { + resetColors(); + + const row = Math.floor(frame / other_shape[1]); + const col = frame % other_shape[1]; + currentStep = frame % totalSteps + 1; + + const highlightA = Array.from({ length: input_shape[1] }, (_, i) => [row, i]); + const highlightB = Array.from({ length: other_shape[0] }, (_, i) => [i, col]); + const highlightC = [[row, col]]; + + highlightCubes(matrixA, highlightA, COLOR_HIGHLIGHT); + highlightCubes(matrixB, highlightB, COLOR_HIGHLIGHT); + highlightCubes(matrixC, highlightC, COLOR_FILLED); + + frame++; + } + + renderer.render(scene, camera); + } + + function onResize() { + camera.aspect = containerElement.clientWidth / containerElement.clientHeight; + camera.updateProjectionMatrix(); + renderer.setSize(containerElement.clientWidth, containerElement.clientHeight); + } + + const controlPanel = document.createElement('div'); + controlPanel.style.position = 'absolute'; + controlPanel.style.bottom = '10px'; + controlPanel.style.left = '10px'; + controlPanel.style.display = 'flex'; + controlPanel.style.gap = '10px'; + + const playPauseButton = document.createElement('button'); + playPauseButton.textContent = 'Play/Pause'; + playPauseButton.addEventListener('click', () => { + isPaused = !isPaused; + }); + + const resetButton = document.createElement('button'); + resetButton.textContent = 'Reset'; + resetButton.addEventListener('click', () => { + frame = 0; + resetColors(); + }); + + controlPanel.appendChild(playPauseButton); + controlPanel.appendChild(resetButton); + + containerElement.appendChild(controlPanel); + + const PAN_SPEED = 0.1; + const TILT_SPEED = 0.02; + const ZOOM_SPEED = 0.5; + let cameraRotation = new THREE.Euler(0, 0, 0, 'YXZ'); + + function onKeyDown(event) { + switch (event.key.toLowerCase()) { + case 'w': + camera.position.y += PAN_SPEED; + break; + case 's': + camera.position.y -= PAN_SPEED; + break; + case 'a': + camera.position.x -= PAN_SPEED; + break; + case 'd': + camera.position.x += PAN_SPEED; + break; + case 'arrowup': + cameraRotation.x -= TILT_SPEED; + break; + case 'arrowdown': + cameraRotation.x += TILT_SPEED; + break; + case 'arrowleft': + cameraRotation.y -= TILT_SPEED; + break; + case 'arrowright': + cameraRotation.y += TILT_SPEED; + break; + case 'o': + camera.position.z += ZOOM_SPEED; + break; + case 'p': + camera.position.z -= ZOOM_SPEED; + break; + } + camera.setRotationFromEuler(cameraRotation); + camera.updateProjectionMatrix(); + } + + window.addEventListener('resize', onResize); + window.addEventListener('keydown', onKeyDown); + containerElement.addEventListener('mousemove', onMouseMove); + + animate(); + + + function cleanup() { + window.removeEventListener('resize', onResize); + window.removeEventListener('keydown', onKeyDown); + containerElement.removeEventListener('mousemove', onMouseMove); + containerElement.innerHTML = ''; + renderer.dispose(); + scene.clear(); + } + + return cleanup; +} diff --git a/triton_viz/static/store-utils.js b/triton_viz/static/store-utils.js new file mode 100644 index 0000000..d87d78b --- /dev/null +++ b/triton_viz/static/store-utils.js @@ -0,0 +1,360 @@ +import * as THREE from 'https://cdn.jsdelivr.net/npm/three@0.155.0/build/three.module.js'; + +export function createMatMulVisualization(containerElement, op) { + const { input_shape, other_shape, output_shape } = op; + console.log(op.uuid) + fetch('/api/setop', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ uuid: op.uuid }), + }) + .then(response => response.json()) + .then(data => console.log('Set current op:', data)) + .catch((error) => console.error('Error:', error)); + let currentStep = 0; + const totalSteps = input_shape[1]; + let frame = 0; + + + const sideMenu = document.createElement('div'); + sideMenu.style.position = 'absolute'; + sideMenu.style.top = '10px'; + sideMenu.style.right = '10px'; + sideMenu.style.width = '200px'; + sideMenu.style.padding = '10px'; + sideMenu.style.backgroundColor = 'rgba(0, 0, 0, 0.7)'; + sideMenu.style.color = 'white'; + sideMenu.style.fontFamily = 'Arial, sans-serif'; + sideMenu.style.fontSize = '14px'; + sideMenu.style.borderRadius = '5px'; + containerElement.appendChild(sideMenu); + let hoveredCube = null; + + + + + async function getElementValue( matrixName, row, col) { + let uuid = op.uuid; + const response = await fetch('/api/getValue', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ uuid, matrixName, row, col, currentStep }), + }); + return await response.json(); + } + + + function updateSideMenu(matrix, x, y) { + if (!matrix) { + sideMenu.innerHTML = ''; + return; + } + + let matrixName; + let dims; + if (matrix === matrixA) { + matrixName = 'A'; + dims = input_shape; + } else if (matrix === matrixB) { + matrixName = 'B'; + dims = other_shape; + } else if (matrix === matrixC) { + matrixName = 'C'; + dims = output_shape; + } else { + sideMenu.innerHTML = ''; + return; + } + console.log(matrixName, "x:", (x + 1), "y:", (y + 1)); + sideMenu.innerHTML = ` +

Matrix ${matrixName}

+

Row: ${y + 1}

+

Column: ${x + 1}

+

Dimensions: ${dims[0]} x ${dims[1]}

+ `; + } + + const raycaster = new THREE.Raycaster(); + const mouse = new THREE.Vector2(); + + + async function onMouseMove(event) { + mouse.x = (event.clientX / containerElement.clientWidth) * 2 - 1; + mouse.y = -(event.clientY / containerElement.clientHeight) * 2 + 1; + + raycaster.setFromCamera(mouse, camera); + + const allMatrixChildren = [ + ...(matrixA ? matrixA.children : []), + ...(matrixB ? matrixB.children : []), + ...(matrixC ? matrixC.children : []) + ]; + + const intersects = raycaster.intersectObjects(allMatrixChildren, true); + + if (hoveredCube) { + hoveredCube.getObjectByName('hoverOutline').visible = false; + hoveredCube = null; + } + + if (intersects.length > 0) { + // Find the actual cube (parent of the intersected object) + hoveredCube = intersects[0].object; + while (hoveredCube && !hoveredCube.matrixName) { + hoveredCube = hoveredCube.parent; + } + + if (hoveredCube) { + const hoverOutline = hoveredCube.getObjectByName('hoverOutline'); + if (hoverOutline) { + hoverOutline.visible = true; + } + const res = await getElementValue(hoveredCube.matrixName, hoveredCube.matrixRow, hoveredCube.matrixCol); + // Log the matrix name, row, and column of the hovered cube + console.log( + // `Matrix: ${hoveredCube.matrixName}, ` + + // `Row: ${hoveredCube.matrixRow + 1}, ` + + // `Column: ${hoveredCube.matrixCol + 1}`+ + `Value: ${res.value}` + ); + + updateSideMenu(hoveredCube.matrixName, hoveredCube.matrixRow, hoveredCube.matrixCol); + } + } else { + updateSideMenu(null); + } + } + + + + + + const CUBE_SIZE = 0.2; + const GAP = 0.05; + + const COLOR_A = new THREE.Color(0.53, 0.81, 0.98); + const COLOR_B = new THREE.Color(1.0, 0.65, 0.0); + const COLOR_C = new THREE.Color(1.0, 1.0, 1.0); + const COLOR_HIGHLIGHT = new THREE.Color(0.0, 0.0, 1.0); + const COLOR_FILLED = new THREE.Color(0.0, 0.0, 1.0); + const COLOR_BACKGROUND = new THREE.Color(0.0, 0.0, 0.0); + const COLOR_EDGE = new THREE.Color(0.3, 0.3, 0.3); + const COLOR_HOVER = new THREE.Color(1.0, 1.0, 0.0); + + const scene = new THREE.Scene(); + scene.background = COLOR_BACKGROUND; + const camera = new THREE.PerspectiveCamera(45, containerElement.clientWidth / containerElement.clientHeight, 0.1, 1000); + + const renderer = new THREE.WebGLRenderer({ antialias: true }); + renderer.setSize(containerElement.clientWidth, containerElement.clientHeight); + containerElement.appendChild(renderer.domElement); + + const ambientLight = new THREE.AmbientLight(0xffffff, 0.5); + scene.add(ambientLight); + const directionalLight = new THREE.DirectionalLight(0xffffff, 0.5); + directionalLight.position.set(10, 10, 10); + scene.add(directionalLight); + + const cubeGeometry = new THREE.BoxGeometry(CUBE_SIZE, CUBE_SIZE, CUBE_SIZE); + const edgesGeometry = new THREE.EdgesGeometry(cubeGeometry); + const lineMaterial = new THREE.LineBasicMaterial({ color: COLOR_EDGE }); + + function createCube(color, matrixName, i, j) { + const cubeMaterial = new THREE.MeshPhongMaterial({ color: color }); + const cube = new THREE.Mesh(cubeGeometry, cubeMaterial); + const edges = new THREE.LineSegments(edgesGeometry, lineMaterial); + cube.add(edges); + + const hoverGeometry = new THREE.BoxGeometry(CUBE_SIZE * 1.05, CUBE_SIZE * 1.05, CUBE_SIZE * 1.05); + const hoverEdgesGeometry = new THREE.EdgesGeometry(hoverGeometry); + const hoverOutline = new THREE.LineSegments(hoverEdgesGeometry, new THREE.LineBasicMaterial({ color: COLOR_HOVER })); + hoverOutline.visible = false; + hoverOutline.name = 'hoverOutline'; + cube.add(hoverOutline); + + cube.name = `${matrixName}_cube_${i}_${j}`; + cube.matrixName = matrixName; + cube.matrixRow = i; + cube.matrixCol = j; + + return cube; + } + + function createMatrix(dimensions, position, color, matrixName) { + const matrix = new THREE.Group(); + matrix.userData.dimensions = dimensions; + for (let i = 0; i < dimensions[0]; i++) { + for (let j = 0; j < dimensions[1]; j++) { + const cube = createCube(color, matrixName, i, j); + cube.position.set( + position.x + j * (CUBE_SIZE + GAP), + position.y - i * (CUBE_SIZE + GAP), + position.z + ); + matrix.add(cube); + } + } + return matrix; + } + + const matrixA = createMatrix(input_shape, new THREE.Vector3(-10, 10, 0), COLOR_A, 'A'); + const matrixB = createMatrix(other_shape, new THREE.Vector3(0, 10, 0), COLOR_B, 'B'); + const matrixC = createMatrix(output_shape, new THREE.Vector3(-5, -4, 0), COLOR_C, 'C'); + + scene.add(matrixA); + scene.add(matrixB); + scene.add(matrixC); + + const center = new THREE.Vector3(); + const size = new THREE.Vector3(); + const box = new THREE.Box3().setFromObject(scene); + box.getCenter(center); + box.getSize(size); + const maxDim = Math.max(size.x, size.y, size.z); + const fov = camera.fov * (Math.PI / 180); + let cameraZ = Math.abs(maxDim / 2 / Math.tan(fov / 2)); + cameraZ *= 1.5; + + camera.position.set(center.x, center.y, center.z + cameraZ); + camera.lookAt(center); + + let isPaused = false; + + const totalFrames = input_shape[0] * other_shape[1]; + + function highlightCubes(matrix, indices, highlightColor) { + indices.forEach(([i, j]) => { + if (i >= 0 && i < matrix.userData.dimensions[0] && j >= 0 && j < matrix.userData.dimensions[1]) { + const index = i * matrix.userData.dimensions[1] + j; + if (index < matrix.children.length) { + matrix.children[index].material.color.copy(highlightColor); + + } + } + }); + } + + function resetColors() { + matrixA.children.forEach(cube => cube.material.color.copy(COLOR_A)); + matrixB.children.forEach(cube => cube.material.color.copy(COLOR_B)); + } + + function animate() { + requestAnimationFrame(animate); + + if (!isPaused && frame < totalFrames) { + resetColors(); + + const row = Math.floor(frame / other_shape[1]); + const col = frame % other_shape[1]; + currentStep = frame % totalSteps + 1; + + const highlightA = Array.from({ length: input_shape[1] }, (_, i) => [row, i]); + const highlightB = Array.from({ length: other_shape[0] }, (_, i) => [i, col]); + const highlightC = [[row, col]]; + + highlightCubes(matrixA, highlightA, COLOR_HIGHLIGHT); + highlightCubes(matrixB, highlightB, COLOR_HIGHLIGHT); + highlightCubes(matrixC, highlightC, COLOR_FILLED); + + frame++; + } + + renderer.render(scene, camera); + } + + function onResize() { + camera.aspect = containerElement.clientWidth / containerElement.clientHeight; + camera.updateProjectionMatrix(); + renderer.setSize(containerElement.clientWidth, containerElement.clientHeight); + } + + const controlPanel = document.createElement('div'); + controlPanel.style.position = 'absolute'; + controlPanel.style.bottom = '10px'; + controlPanel.style.left = '10px'; + controlPanel.style.display = 'flex'; + controlPanel.style.gap = '10px'; + + const playPauseButton = document.createElement('button'); + playPauseButton.textContent = 'Play/Pause'; + playPauseButton.addEventListener('click', () => { + isPaused = !isPaused; + }); + + const resetButton = document.createElement('button'); + resetButton.textContent = 'Reset'; + resetButton.addEventListener('click', () => { + frame = 0; + resetColors(); + }); + + controlPanel.appendChild(playPauseButton); + controlPanel.appendChild(resetButton); + + containerElement.appendChild(controlPanel); + + const PAN_SPEED = 0.1; + const TILT_SPEED = 0.02; + const ZOOM_SPEED = 0.5; + let cameraRotation = new THREE.Euler(0, 0, 0, 'YXZ'); + + function onKeyDown(event) { + switch (event.key.toLowerCase()) { + case 'w': + camera.position.y += PAN_SPEED; + break; + case 's': + camera.position.y -= PAN_SPEED; + break; + case 'a': + camera.position.x -= PAN_SPEED; + break; + case 'd': + camera.position.x += PAN_SPEED; + break; + case 'arrowup': + cameraRotation.x -= TILT_SPEED; + break; + case 'arrowdown': + cameraRotation.x += TILT_SPEED; + break; + case 'arrowleft': + cameraRotation.y -= TILT_SPEED; + break; + case 'arrowright': + cameraRotation.y += TILT_SPEED; + break; + case 'o': + camera.position.z += ZOOM_SPEED; + break; + case 'p': + camera.position.z -= ZOOM_SPEED; + break; + } + camera.setRotationFromEuler(cameraRotation); + camera.updateProjectionMatrix(); + } + + window.addEventListener('resize', onResize); + window.addEventListener('keydown', onKeyDown); + containerElement.addEventListener('mousemove', onMouseMove); + + animate(); + + + function cleanup() { + window.removeEventListener('resize', onResize); + window.removeEventListener('keydown', onKeyDown); + containerElement.removeEventListener('mousemove', onMouseMove); + containerElement.innerHTML = ''; + renderer.dispose(); + scene.clear(); + } + + return cleanup; +} diff --git a/triton_viz/static/store.js b/triton_viz/static/store.js new file mode 100644 index 0000000..470b327 --- /dev/null +++ b/triton_viz/static/store.js @@ -0,0 +1,221 @@ +import * as THREE from 'https://cdn.jsdelivr.net/npm/three@0.155.0/build/three.module.js'; +import { + setupScene, + setupGeometries, + createTensor, + calculateTensorSize, + updateCubeColor, + setupCamera, + setupEventListeners, + cameraControls +} from './store-utils.js'; + +export function createStoreVisualization(containerElement, op) { + + console.log(op.uuid); + fetch('/api/setop', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ uuid: op.uuid }), + }) + .then(response => response.json()) + .then(data => console.log('Set current op:', data)) + .catch((error) => console.error('Error:', error)); + + let currentStep = 0; + let frame = 0; + let isPaused = false; + + const sideMenu = createSideMenu(containerElement); + let hoveredCube = null; + + const COLOR_GLOBAL = new THREE.Color(0.2, 0.2, 0.2); // Dark Gray + const COLOR_SLICE = new THREE.Color(0.0, 0.7, 1.0); // Cyan (starting color for global slice) + const COLOR_LEFT_SLICE = new THREE.Color(1.0, 0.0, 1.0); // Magenta (starting color for left slice) + const COLOR_LOADED = new THREE.Color(1.0, 0.8, 0.0); // Gold (final color for both slices) + const COLOR_BACKGROUND = new THREE.Color(0.0, 0.0, 0.0); // Black + + const { scene, camera, renderer } = setupScene(containerElement, COLOR_BACKGROUND); + const { cubeGeometry, edgesGeometry, lineMaterial } = setupGeometries(); + + const globalTensor = createTensor(op.global_shape, op.global_coords, COLOR_GLOBAL, 'Global', cubeGeometry, edgesGeometry, lineMaterial); + const sliceTensor = createTensor(op.slice_shape, op.slice_coords, COLOR_LEFT_SLICE, 'Slice', cubeGeometry, edgesGeometry, lineMaterial); + + // Position slice tensor + const globalSize = calculateTensorSize(op.global_shape); + sliceTensor.position.set(globalSize.x + 5, 0, 0); // Adjusted tensor spacing + + scene.add(globalTensor); + scene.add(sliceTensor); + + addLabels(scene, globalTensor, sliceTensor); + setupCamera(scene, camera); + + const totalFrames = op.global_coords.length * 2 + 30; + + const raycaster = new THREE.Raycaster(); + const mouse = new THREE.Vector2(); + + const onKeyDown = cameraControls(camera, new THREE.Euler(0, 0, 0, 'YXZ')); + setupEventListeners(containerElement, camera, renderer, onMouseMove, onKeyDown); + animate(); + + async function onMouseMove(event) { + mouse.x = (event.clientX / containerElement.clientWidth) * 2 - 1; + mouse.y = -(event.clientY / containerElement.clientHeight) * 2 + 1; + + raycaster.setFromCamera(mouse, camera); + + const allTensorChildren = [ + ...globalTensor.children, + ...sliceTensor.children + ]; + + const intersects = raycaster.intersectObjects(allTensorChildren, true); + + if (hoveredCube) { + hoveredCube.getObjectByName('hoverOutline').visible = false; + hoveredCube = null; + } + + if (intersects.length > 0) { + hoveredCube = intersects[0].object; + while (hoveredCube && !hoveredCube.tensorName) { + hoveredCube = hoveredCube.parent; + } + + if (hoveredCube) { + const hoverOutline = hoveredCube.getObjectByName('hoverOutline'); + if (hoverOutline) { + hoverOutline.visible = true; + } + + updateSideMenu(hoveredCube.tensorName, hoveredCube.tensor0, hoveredCube.tensor1, hoveredCube.tensor2, undefined); + + const res = await getElementValue(hoveredCube.tensorName, hoveredCube.tensor0, hoveredCube.tensor1, hoveredCube.tensor2); + + updateSideMenu(hoveredCube.tensorName, hoveredCube.tensor0, hoveredCube.tensor1, hoveredCube.tensor2, res.value); + + console.log(`Value: ${res.value}`); + } + } else { + updateSideMenu(null); + } + } + + function animate() { + requestAnimationFrame(animate); + + if (!isPaused && frame < totalFrames) { + const index = Math.floor(frame / 2); + const factor = (frame % 2) / 1.0; + + if (index < op.global_coords.length) { + const globalCoord = op.global_coords[index]; + const sliceCoord = op.slice_coords[index]; + + updateCubeColor(globalTensor, globalCoord, COLOR_GLOBAL, COLOR_SLICE, factor); + updateCubeColor(sliceTensor, sliceCoord, COLOR_LEFT_SLICE, COLOR_LOADED, factor); + + highlightCurrentOperation(globalTensor, globalCoord, sliceTensor, sliceCoord); + updateInfoPanel(globalCoord, sliceCoord, index); + } + + frame++; + } + + renderer.render(scene, camera); + } + + function highlightCurrentOperation(globalTensor, globalCoord, sliceTensor, sliceCoord) { + globalTensor.children.forEach(cube => cube.material.emissive.setHex(0x000000)); + sliceTensor.children.forEach(cube => cube.material.emissive.setHex(0x000000)); + + const globalCube = globalTensor.children.find(c => + c.tensor0 === globalCoord[0] && c.tensor1 === globalCoord[1] && c.tensor2 === globalCoord[2] + ); + const sliceCube = sliceTensor.children.find(c => + c.tensor0 === sliceCoord[0] && c.tensor1 === sliceCoord[1] && c.tensor2 === sliceCoord[2] + ); + + if (globalCube) globalCube.material.emissive.setHex(0x444444); + if (sliceCube) sliceCube.material.emissive.setHex(0x444444); + } + + async function getElementValue(tensorName, x, y, z) { + let uuid = op.uuid; + const response = await fetch('/api/getStoreValue', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ uuid, tensorName, x, y, z }), + }); + return await response.json(); + } + + function updateSideMenu(tensorName, x, y, z, value) { + if (!tensorName) { + sideMenu.innerHTML = ''; + return; + } + + let dims = tensorName === 'Global' ? op.global_shape : op.slice_shape; + sideMenu.innerHTML = ` +

${tensorName} Tensor

+

X: ${x + 1}

+

Y: ${y + 1}

+

Z: ${z + 1}

+

Dimensions: ${dims.join(' x ')}

+

Value: ${value !== undefined ? value : 'Storeing...'}

+ `; + } + + function updateInfoPanel(globalCoord, sliceCoord, index) { + sideMenu.innerHTML = ` +

Current Operation

+

Global Coords: (${globalCoord.join(', ')})

+

Slice Coords: (${sliceCoord.join(', ')})

+

Progress: ${index + 1}/${op.global_coords.length}

+ `; + } + + function createSideMenu(container) { + const menu = document.createElement('div'); + menu.style.position = 'absolute'; + menu.style.top = '10px'; + menu.style.right = '10px'; + menu.style.width = '200px'; + menu.style.padding = '10px'; + menu.style.backgroundColor = 'rgba(0, 0, 0, 0.7)'; + menu.style.color = 'white'; + menu.style.fontFamily = 'Arial, sans-serif'; + menu.style.fontSize = '14px'; + menu.style.borderRadius = '5px'; + container.appendChild(menu); + return menu; + } + + function addLabels(scene, globalTensor, sliceTensor) { + addLabel(scene, "Global Tensor", globalTensor.position); + addLabel(scene, "Slice Tensor", sliceTensor.position); + } + + function addLabel(scene, text, position) { + const canvas = document.createElement('canvas'); + const context = canvas.getContext('2d'); + context.font = 'Bold 24px Arial'; + context.fillStyle = 'white'; + context.fillText(text, 0, 24); + + const texture = new THREE.CanvasTexture(canvas); + const material = new THREE.SpriteMaterial({ map: texture }); + const sprite = new THREE.Sprite(material); + sprite.position.set(position.x, position.y + 2, position.z); + sprite.scale.set(4, 2, 1); + scene.add(sprite); + } + +} diff --git a/triton_viz/static/visualization.js b/triton_viz/static/visualization.js new file mode 100644 index 0000000..e0d157c --- /dev/null +++ b/triton_viz/static/visualization.js @@ -0,0 +1,389 @@ +import { GridBlock } from './gridblock.js'; +import { createInfoPopup, showInfoPopup } from './infoPopup.js'; +let globalData; +let currentView = 'main'; +let canvas, ctx; +let maxX = 0, maxY = 0, maxZ = 0; +let sliders = [], zSlider, precomputeButton, kernelGrid; +let backButton; +let currentBlockData = null; +let isInitialized = false; +let containerElement; +let infoPopup; +let infoButton; +function switchToMainView() { + currentView = 'main'; + if (currentBlockData) { + currentBlockData.hideDetailedView(); + currentBlockData = null; + } + containerElement.style.pointerEvents = 'none'; + containerElement.style.display = 'none'; + containerElement.innerHTML = ''; + + canvas.style.display = 'block'; + draw(); +} + +function initializeApp() { + canvas = document.getElementById('canvas'); + if (!canvas) { + console.error('Canvas element not found'); + return; + } + ctx = canvas.getContext('2d'); + if (!ctx) { + console.error('Unable to get 2D context from canvas'); + return; + } + canvas.width = window.innerWidth; + canvas.height = window.innerHeight; + + canvas.addEventListener('mousedown', handleMouseEvent); + canvas.addEventListener('mouseup', handleMouseEvent); + canvas.addEventListener('mousemove', handleMouseEvent); + + containerElement = document.getElementById('visualization-container'); + if (!containerElement) { + console.error('Visualization container element not found'); + return; + } + + containerElement.style.pointerEvents = 'none'; + containerElement.style.display = 'none'; + + fetchData(); +} + +class Slider { + constructor(x, y, width, height, label, min_value = -1, max_value = 100) { + this.rect = { x, y, width, height }; + this.label = label; + this.min = min_value; + this.max = max_value; + this.value = min_value; + this.grabbed = false; + this.enabled = true; + } + + draw(ctx) { + if (!this.enabled) return; + ctx.fillStyle = '#3c3c46'; + ctx.fillRect(this.rect.x, this.rect.y, this.rect.width, this.rect.height); + const buttonX = this.rect.x + (this.value - this.min) / (this.max - this.min) * this.rect.width; + ctx.fillStyle = '#c8c8c8'; + ctx.fillRect(buttonX - 5, this.rect.y - 2, 10, this.rect.height + 4); + + ctx.fillStyle = '#c8c8c8'; + ctx.font = '18px Arial'; + ctx.fillText(this.label, this.rect.x, this.rect.y - 10); + ctx.fillText(this.value.toString(), this.rect.x + this.rect.width + 10, this.rect.y + this.rect.height / 2 + 5); + } + + handleEvent(event) { + if (!this.enabled) return; + if (event.type === 'mousedown') { + if (this.isPointInside(event.offsetX, event.offsetY)) { + this.grabbed = true; + } + } else if (event.type === 'mouseup') { + this.grabbed = false; + } else if (event.type === 'mousemove' && this.grabbed) { + const mouseX = event.offsetX; + this.value = Math.round((mouseX - this.rect.x) / this.rect.width * (this.max - this.min) + this.min); + this.value = Math.max(this.min, Math.min(this.max, this.value)); + } + } + + isPointInside(x, y) { + return x >= this.rect.x && x <= this.rect.x + this.rect.width && + y >= this.rect.y && y <= this.rect.y + this.rect.height; + } +} + +class Button { + constructor(x, y, width, height, text, isIcon = false) { + this.rect = { x, y, width, height }; + this.text = text; + this.isIcon = isIcon; + this.color = '#3c3c46'; + this.hoverColor = '#50505a'; + this.clickColor = '#64646e'; + this.isHovered = false; + this.isClicked = false; + this.clickTime = 0; + } + + draw(ctx) { + let color = this.color; + if (this.isClicked && Date.now() - this.clickTime < 100) { + color = this.clickColor; + } else if (this.isHovered) { + color = this.hoverColor; + } + + ctx.fillStyle = color; + ctx.fillRect(this.rect.x, this.rect.y, this.rect.width, this.rect.height); + ctx.strokeStyle = '#c8c8c8'; + ctx.strokeRect(this.rect.x, this.rect.y, this.rect.width, this.rect.height); + + ctx.fillStyle = '#c8c8c8'; + ctx.font = this.isIcon ? 'bold 24px Arial' : '18px Arial'; + ctx.textAlign = 'center'; + ctx.textBaseline = 'middle'; + ctx.fillText(this.text, this.rect.x + this.rect.width / 2, this.rect.y + this.rect.height / 2); + } + + handleEvent(event) { + const { offsetX, offsetY } = event; + this.isHovered = this.isPointInside(offsetX, offsetY); + if (event.type === 'mousedown' && this.isHovered) { + this.isClicked = true; + this.clickTime = Date.now(); + console.log(`Button '${this.text}' clicked!`); + } else if (event.type === 'mouseup') { + this.isClicked = false; + } + } + + isPointInside(x, y) { + return x >= this.rect.x && x <= this.rect.x + this.rect.width && + y >= this.rect.y && y <= this.rect.y + this.rect.height; + } +} + +class KernelGrid { + constructor(x, y, width, height, gridSize, visualizationData) { + this.rect = { x, y, width, height }; + this.gridSize = gridSize; + this.visualizationData = visualizationData; + this.currentZ = 0; + this.blocks = []; + this.calculateBlockSize(); + this.createBlocks(); + this.selectedBlock = null; + this.filterValues = [-1, -1, -1]; // Default filter values for x, y, z + } + + calculateBlockSize() { + this.blockWidth = Math.floor(this.rect.width / this.gridSize[0]) - 1; + this.blockHeight = Math.floor(this.rect.height / this.gridSize[1]) - 1; + } + + createBlocks() { + this.blocks = []; + for (let y = 0; y < this.gridSize[1]; y++) { + for (let x = 0; x < this.gridSize[0]; x++) { + const blockX = this.rect.x + x * (this.blockWidth + 1); + const blockY = this.rect.y + y * (this.blockHeight + 1); + const gridKey = `(${x}, ${y}, ${this.currentZ})`; + const blockData = this.visualizationData[gridKey] || []; + const block = new GridBlock( + blockX, blockY, this.blockWidth, this.blockHeight, + x, y, this.currentZ, blockData, + switchToMainView, + containerElement, + canvas, + draw + ); + this.blocks.push(block); + } + } + } + + draw(ctx) { + ctx.fillStyle = '#F0F0F0'; + ctx.fillRect(this.rect.x, this.rect.y, this.rect.width, this.rect.height); + this.blocks.forEach(block => { + if (this.shouldDrawBlock(block)) { + block.draw(ctx); + } + }); + } + + shouldDrawBlock(block) { + return (this.filterValues[0] === -1 || block.gridPosition.x === this.filterValues[0]) && + (this.filterValues[1] === -1 || block.gridPosition.y === this.filterValues[1]) && + (this.filterValues[2] === -1 || block.gridPosition.z === this.filterValues[2]); + } + + updateZ(z) { + this.currentZ = z; + this.filterValues[2] = z; + this.blocks.forEach(block => { + block.gridPosition.z = z; + const gridKey = `(${block.gridPosition.x}, ${block.gridPosition.y}, ${z})`; + block.blockData = this.visualizationData[gridKey] || []; + }); + } + + handleClick(x, y) { + const clickedBlock = this.blocks.find(block => + block.isPointInside(x, y) && this.shouldDrawBlock(block) + ); + if (clickedBlock) { + console.log(`Clicked block at (${clickedBlock.gridPosition.x}, ${clickedBlock.gridPosition.y}, ${clickedBlock.gridPosition.z})`); + if (this.selectedBlock) { + this.selectedBlock.hideDetailedView(); + } + this.selectedBlock = clickedBlock; + clickedBlock.showDetailedView(); + return clickedBlock; + } + return null; + } + + handleMouseMove(x, y) { + this.blocks.forEach(block => { + if (this.shouldDrawBlock(block)) { + block.handleMouseMove(x, y); + } else { + block.isHovered = false; + } + }); + } + + updateFilter(dimension, value) { + this.filterValues[dimension] = value; + } +} + +function determineMaxValues(visualizationData) { + maxX = 0; + maxY = 0; + maxZ = 0; + const keys = Object.keys(visualizationData); + keys.forEach(key => { + const [x, y, z] = key.replace(/[()]/g, '').split(', ').map(Number); + if (x > maxX) maxX = x; + if (y > maxY) maxY = y; + if (z > maxZ) maxZ = z; + }); +} + +function initializeUIElements() { + sliders = [ + new Slider(1300, 50, 250, 20, "Program Id 0", -1, maxX), + new Slider(1300, 120, 250, 20, "Program Id 1", -1, maxY), + new Slider(1300, 190, 250, 20, "Program Id 2", -1, maxZ) + ]; + + zSlider = new Slider(50, 860, 1200, 20, "Z-axis", 0, maxZ); + zSlider.enabled = maxZ > 0; + + precomputeButton = new Button(1300, 260, 250, 40, "Precompute"); + kernelGrid = new KernelGrid(50, 50, 1200, 800, [maxX + 1, maxY + 1, maxZ + 1], globalData.ops.visualization_data); + backButton = new Button(50, 50, 100, 40, "Back"); + const buttonSize = 40; + const margin = 10; + infoButton = new Button( + canvas.width - buttonSize - margin, + margin, + buttonSize, + buttonSize, + "i", + true + ); + + isInitialized = true; + + infoPopup = createInfoPopup(); +} + +function switchToTensorView(clickedBlock) { + currentView = 'tensor'; + currentBlockData = clickedBlock; + console.log("Switched to tensor view. Block data:", clickedBlock); + + containerElement.style.pointerEvents = 'auto'; + containerElement.style.display = 'block'; + clickedBlock.showDetailedView(); + + canvas.style.display = 'none'; +} + +function handleMouseEvent(event) { + if (!isInitialized) { + console.warn('UI elements not initialized yet'); + return; + } + if (infoButton) { + infoButton.handleEvent(event); + if (event.type === 'mousedown' && infoButton.isHovered) { + showInfoPopup(infoPopup); + } + } + const { offsetX, offsetY } = event; + if (currentView === 'main') { + sliders.forEach((slider, index) => { + slider.handleEvent(event); + if (kernelGrid) { + kernelGrid.updateFilter(index, slider.value); + } + }); + if (zSlider && zSlider.enabled) { + zSlider.handleEvent(event); + if (kernelGrid) { + kernelGrid.updateZ(zSlider.value); + } + } + if (precomputeButton) { + precomputeButton.handleEvent(event); + } + if (kernelGrid) { + kernelGrid.handleMouseMove(offsetX, offsetY); + if (event.type === 'mousedown') { + const clickedBlock = kernelGrid.handleClick(offsetX, offsetY); + if (clickedBlock) { + switchToTensorView(clickedBlock); + } + } + } + } else if (currentView === 'tensor') { + if (backButton) { + backButton.handleEvent(event); + if (event.type === 'mousedown' && backButton.isHovered) { + switchToMainView(); + } + } + } + draw(); +} + +function draw() { + if (!ctx) { + console.error('Canvas context not available'); + return; + } + + ctx.fillStyle = '#1e1e28'; + ctx.fillRect(0, 0, canvas.width, canvas.height); + + if (currentView === 'main' || currentView === 'main') { + if (kernelGrid) kernelGrid.draw(ctx); + sliders.forEach(slider => slider.draw(ctx)); + if (zSlider && zSlider.enabled) { + zSlider.draw(ctx); + } + if (precomputeButton) precomputeButton.draw(ctx); + if (infoButton) { + infoButton.draw(ctx); + } + } +} + +async function fetchData() { + try { + const response = await fetch('/api/data'); + globalData = await response.json(); + console.log(globalData); + + determineMaxValues(globalData.ops.visualization_data); + initializeUIElements(); + draw(); + } catch (error) { + console.error('Error fetching data:', error); + } +} + +initializeApp(); \ No newline at end of file diff --git a/triton_viz/templates/index.html b/triton_viz/templates/index.html new file mode 100644 index 0000000..b8eaebe --- /dev/null +++ b/triton_viz/templates/index.html @@ -0,0 +1,47 @@ + + + + + + Kernel Launch Configuration Visualization + + + + +
+ + + + diff --git a/triton_viz/tooltip.py b/triton_viz/tooltip.py index 2a3f069..8ff971d 100644 --- a/triton_viz/tooltip.py +++ b/triton_viz/tooltip.py @@ -7,13 +7,18 @@ "ExpandDims": "Shows the total number of expand_dims operations performed in this kernel.", "Dot": "Shows the total number of dot operations performed in this kernel.", "Reduce": "Shows the total number of reduce operations performed in this kernel.", - "Total number of bytes loaded": "Shows the total number of bytes loaded (mask=True).Note:On GPUs, this metric does not equate to the total number of bytes loaded from global memory (DRAM), as some data accesses may be handled through GPU caches.", + "Total number of bytes loaded": "Shows the total number of bytes loaded (mask=True). Note: On GPUs, this metric does not equate to the total number of bytes loaded from global memory (DRAM), as some data accesses may be handled through GPU caches.", "Masked Load Ratio": "Ratio of total number of bytes loaded (mask=True)/total number of bytes loaded (mask=True) + (mask=False).", "Total number of bytes stored": "Shows the total number of bytes stored (mask=True).", "Masked Store Ratio": "Ratio of total number of bytes stored (mask=True)/total number of bytes stored (mask=True) + (mask=False).", } +def get_tooltip_data(df): + """Return the tooltip data in a format suitable for JSON serialization.""" + return df.to_dict() + + def create_tooltip(df): styles = """