diff --git a/examples/3dims.py b/examples/3dims.py new file mode 100644 index 0000000..bfe766c --- /dev/null +++ b/examples/3dims.py @@ -0,0 +1,99 @@ +import torch +import triton +import triton.language as tl + + + +import triton_viz +BLOCK_SIZE_X = 32 +BLOCK_SIZE_Y = 8 +BLOCK_SIZE_Z = 4 + +@triton_viz.trace +@triton.jit +def add_3d_slices_kernel( + input_ptr1, input_ptr2, output_ptr, + stride_x, stride_y, stride_z, + slice_x, slice_y, slice_z, + BLOCK_SIZE_X: tl.constexpr, BLOCK_SIZE_Y: tl.constexpr, BLOCK_SIZE_Z: tl.constexpr +): + # Compute the 3D position in the output tensor + pid_x = tl.program_id(0) + pid_y = tl.program_id(1) + pid_z = tl.program_id(2) + + # Compute the starting position for this block + x_start = pid_x * BLOCK_SIZE_X + y_start = pid_y * BLOCK_SIZE_Y + z_start = pid_z * BLOCK_SIZE_Z + + # Compute offsets within the block + x_offsets = x_start + tl.arange(0, BLOCK_SIZE_X) + y_offsets = y_start + tl.arange(0, BLOCK_SIZE_Y) + z_offsets = z_start + tl.arange(0, BLOCK_SIZE_Z) + + # Create a mask to handle boundary conditions + mask = (x_offsets < slice_x) & (y_offsets < slice_y)[:, None] & (z_offsets < slice_z)[:, None, None] + + # Compute the input and output offsets + offsets = ( + z_offsets[:, None, None] * stride_z + + y_offsets[:, None] * stride_y + + x_offsets * stride_x + ) + + # Load input slices + slice1 = tl.load(input_ptr1 + offsets, mask=mask) + slice2 = tl.load(input_ptr2 + offsets, mask=mask) + + # Perform addition + result = slice1 + slice2 + + # Store the result + tl.store(output_ptr + offsets, result, mask=mask) + +def add_3d_slices(input1, input2, output): + # Get tensor shapes + slice_z, slice_y, slice_x = input1.shape + + # Compute strides + stride_z, stride_y, stride_x = input1.stride() + + # Determine grid size + grid = ( + triton.cdiv(slice_x, BLOCK_SIZE_X), + triton.cdiv(slice_y, BLOCK_SIZE_Y), + triton.cdiv(slice_z, BLOCK_SIZE_Z) + ) + + # Launch kernel + add_3d_slices_kernel[grid]( + input1, input2, output, + stride_x, stride_y, stride_z, + slice_x, slice_y, slice_z, + BLOCK_SIZE_X=BLOCK_SIZE_X, + BLOCK_SIZE_Y=BLOCK_SIZE_Y, + BLOCK_SIZE_Z=BLOCK_SIZE_Z + ) + +if __name__ == "__main__": + # Set random seed for reproducibility + torch.manual_seed(0) + + # Create example input tensor + input1 = torch.randn(16, 16, 32, device='cpu') + input2 = torch.randn(16, 16, 32, device='cpu') + output = torch.empty_like(input1) + + # Call the kernel + add_3d_slices(input1, input2, output) + triton_viz.launch() + + # Verify the result + expected_output = input1 + input2 + assert torch.allclose(output, expected_output), "Kernel output does not match expected result" + + + + + diff --git a/setup.py b/setup.py index 6898fed..0ea16c3 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="triton-viz", - version="0.1", + version="0.2", packages=find_packages(), description="A visualization tool for Triton", author="Deep Learning Profiling Tools Team", @@ -11,10 +11,11 @@ install_requires=[ "setuptools", "triton", - "gradio", - "chalk-diagrams @ git+https://github.com/chalk-diagrams/chalk.git", + "flask", "pyarrow", "pre-commit", "pytest", + "flask_cloudflared", + "requests", ], ) diff --git a/triton_viz/__init__.py b/triton_viz/__init__.py index 737dd36..7c5bab8 100644 --- a/triton_viz/__init__.py +++ b/triton_viz/__init__.py @@ -1,5 +1,5 @@ from .trace import trace, dump, sample -from .draw import collect_grid, draw_record +from .draw import collect_grid from .interface import launch -__all__ = ["trace", "launch", "dump", "sample", "collect_grid", "draw_record"] +__all__ = ["trace", "launch", "dump", "sample", "collect_grid"] diff --git a/triton_viz/data.py b/triton_viz/data.py index 819a398..ad9678c 100644 --- a/triton_viz/data.py +++ b/triton_viz/data.py @@ -1,9 +1,11 @@ from dataclasses import dataclass, field -from typing import List, Tuple, Any +from typing import List, Tuple, Any, Dict import traceback import numpy.typing as npt import numpy as np +import torch + @dataclass class Op: @@ -73,6 +75,15 @@ class Dot(Op): input_shape: Tuple other_shape: Tuple output_shape: Tuple + input_data: List[List[float]] + other_data: List[List[float]] + intermediate_results: Dict[Tuple[int, int], float] = field( + default_factory=dict + ) # Only storing the result now + + def update_intermediate(self, row: int, col: int, result: float): + # Store only the result as a float + self.intermediate_results[(row, col)] = result @dataclass @@ -91,6 +102,7 @@ class Tensor: stride: Tuple shape: Tuple element_size: int + data: torch.Tensor @dataclass diff --git a/triton_viz/draw.py b/triton_viz/draw.py index fda62c0..2de3c52 100644 --- a/triton_viz/draw.py +++ b/triton_viz/draw.py @@ -1,69 +1,10 @@ -from colour import Color -from triton_viz.data import ( - Tensor, - Grid, - Store, - Load, - Op, - MakeRange, - Reduce, - Dot, -) -from .interpreter import record_builder +from triton_viz.data import Tensor, Grid, Store, Load, Dot, ExpandDims +import uuid import numpy as np -import numpy.typing as npt -import planar -import math -import chalk -from typing import Tuple, Union, Optional, List, Dict -from chalk import Diagram, rectangle, text, hcat, vcat, empty, Path, Trail, V2, concat -from dataclasses import dataclass -from numpy.typing import ArrayLike -import sys - -sys.setrecursionlimit(100000) - - -planar.EPSILON = 0.0 -chalk.set_svg_draw_height(500) -BG = Color("white") -WHITE = Color("white") -DEFAULT = Color("grey") -BLACK = Color("black") -GREY = Color("grey") -palette = [ - "#f29f05", - "#f25c05", - "#d6568c", - "#4d8584", - "#a62f03", - "#400d01", - "#274001", - "#828a00", -] -ACTIVE = [Color(p) for p in palette] - -MRATIO = 1 / 3 - - -# Generic render helpers - - -def box(d: Diagram, width: float, height: float, outer=0.2) -> Diagram: - "Put diagram in a box of shape height, width" - h, w = d.get_envelope().height, d.get_envelope().width - m = max(h, w) - back = rectangle(outer + m, outer + m).line_width(0).fill_color(BG).center_xy() - d = (back + d.center_xy()).with_envelope(back) - return d.scale_x(width / m).scale_y(height / m) - - -def reshape(d: Diagram) -> Diagram: - "Use log-scale if ratio is too sharp" - h, w = d.get_envelope().height, d.get_envelope().width - if (h / w > MRATIO) or (w / h > MRATIO): - d = d.scale_y(math.log(h + 1, 2) / h).scale_x(math.log(w + 1, 2) / w) - return d +import torch +from .interpreter import record_builder + +from typing import Tuple, List def collect_grid(): @@ -96,141 +37,37 @@ def collect_launch(launch): return all_grids, tensor_table, failures -def draw_record(program_record, tensor_table, output): - return draw_launch(program_record, tensor_table, output) - - -def draw_launch(program_records, tensor_table, base) -> Diagram: - def draw(x): - "Dispatch" - if isinstance(x, Tensor): - return draw_tensor(x) - if isinstance(x, Grid): - return draw_grid(x) - if isinstance(x, Store): - return draw_store(x, tensor_table) - if isinstance(x, Load): - return draw_load(x, tensor_table) - if isinstance(x, Op): - return draw_op(x) - if isinstance(x, MakeRange): - return draw_make_range(x) - if isinstance(x, Reduce): - return draw_reduce(x) - if isinstance(x, Dot): - return None # draw_dot(x) - - def draw_record(x): - "Render one record" - y = draw(x) - if y is None: - return empty() - - return (chalk.vstrut(0.2) / y).center_xy() - - records = [] - for r in program_records: - dr = draw_record(r) - # env = dr.get_envelope() - # dr = dr.center_xy().with_envelope(rectangle(env.width, env.height).center_xy()) - records.append(dr) - - dr = vcat(records) - dr = dr.center_xy() - env = dr.get_envelope() - dr = rectangle(env.width + 1, env.height + 1).fill_color(BG).center_xy() + dr - dr.render_svg(base, 2500) - return env.width, env.height - - -def delinearize(shape: Tuple, x: npt.NDArray, dtype, mask) -> List[npt.NDArray]: - if len(shape) == 1: - shape = (1, 1, shape[0]) - x = x.copy() // (dtype.element_ty.primitive_bitwidth // 8) - vals = [] - for s in list(reversed(shape[1:])) + [10000]: - vals.append(((x % s) * mask - (1 - mask)).ravel()) - x = x // s - return vals - - -trail = Trail.from_offsets([V2(0, 1), V2(1, 0), V2(0, -1), V2(-1, 0)], closed=True) - - -def cover( - shape: Tuple, dtype, load: Tensor, mask: npt.NDArray, color: Color -) -> Diagram: - shape = make_3d(shape) - "Draw the values from load on top of the loading tensor" - x, y, z = delinearize(shape, load, dtype, mask) - return draw_tensor_3d(shape, z, y, x, color) - - -def pair_draw(x: Diagram, y: Diagram, command: str) -> Diagram: - "Draw two diagrams next to each other with a command in the middle." - return hcat([box(x, 3, 2.5), box(y, 3, 2.5)], 1).center_xy() + text( - command, 0.2 - ).fill_color(BLACK).line_width(0).translate(0, -1) - - -# Individual renderers - - -def draw_tensor(x: Tensor) -> Optional[Diagram]: - return None - - -def draw_grid(x: Grid) -> Optional[Diagram]: - return None - - -def draw_make_range(x: MakeRange) -> Optional[Diagram]: - return None - - -def draw_reduce(x: Reduce) -> Optional[Diagram]: - color = ACTIVE[0] - inp = draw_tensor_3d(make_3d(x.input_shape), None, None, None, color) - if x.index == 0 and len(x.input_shape) == 2: - inp = hcat( - [ - rectangle(0.1, inp.get_envelope().height) - .align_t() - .line_width(0) - .fill_color(BLACK), - inp, - ], - 0.5, - ) - else: - inp = vcat( - [ - rectangle(inp.get_envelope().width, 0.1) - .align_l() - .line_width(0) - .fill_color(BLACK), - inp, - ], - 0.5, - ) - out = draw_tensor_3d(x.output_shape, None, None, None, color) - return pair_draw(reshape(inp), reshape(out), x.op) +def extract_load_coords( + record: Load, global_tensor: Tensor +) -> Tuple[List[Tuple[float, float, float]], List[Tuple[float, float, float]]]: + # Extract coordinates for the global tensor + global_shape = make_3d(global_tensor.shape) + global_z, global_y, global_x = delinearized( + global_shape, + record.original_offsets, + global_tensor.dtype, + record.original_masks, + ) + global_coords = [ + (float(xi), float(yi), float(zi)) + for xi, yi, zi in zip(global_z, global_y, global_x) + if xi != -1 and yi != -1 and zi != -1 + ] -def draw_load(x, tensor_table) -> Optional[Diagram]: - inp, out = store_load(x, tensor_table) - out = reshape(out) - return pair_draw(inp, out, "load") + # Extract coordinates for the slice tensor + slice_shape = make_3d(record.shape) + slice_z, slice_y, slice_x = record.original_masks.reshape(*slice_shape).nonzero() + slice_coords = [ + (float(xi), float(yi), float(zi)) + for xi, yi, zi in zip(slice_x, slice_y, slice_z) + ] -def draw_store(x, tensor_table) -> Optional[Diagram]: - inp, out = store_load(x, tensor_table) - out = reshape(out) - return pair_draw(out, inp, "store") + return global_coords, slice_coords -def make_3d(shape): - "Make a 3d shape" +def make_3d(shape: Tuple[int, ...]): if len(shape) == 1: return (1, 1, shape[0]) if len(shape) == 2: @@ -238,233 +75,114 @@ def make_3d(shape): return shape -def store_load( - x: Union[Store, Load], tensor_table: Dict[int, Tuple[Tensor, int]] -) -> Tuple[Diagram, Diagram]: - tensor, tensor_id = tensor_table[x.ptr] - # inp = base_tensor(tensor.shape, DEFAULT) - color = ACTIVE[tensor_id] - invalid = x.invalid_access_masks.any() - if invalid: - color = Color("red") - inp = cover(tensor.shape, tensor.dtype, x.original_offsets, x.original_masks, color) - inp = reshape(inp) - s = make_3d(x.original_offsets.shape) - a, b, c = x.original_masks.reshape(*s).nonzero() - out = draw_tensor_3d(s, a, b, c, color) - return inp, out - - -def draw_op(x: Op) -> Optional[Diagram]: - return None - - -def draw_dot(x: Dot) -> Optional[Diagram]: - if x.input_shape == (1,): - return None - inp = draw_tensor_3d(x.input_shape[0], None, None, None) - # inp = reshape(base_tensor(x.input_shape[0], color=ACTIVE)) - # inp = add_whiskers(inp, x.input_shape[0]) - inp2 = draw_tensor_3d(x.input_shape[1], None, None, None) - # inp2 = reshape(base_tensor(x.input_shape[0], color=ACTIVE)) - # inp2 = add_whiskers(inp2, x.input_shape[0]) - out = draw_tensor_3d(x.output_shape, None, None, None) - # out = reshape(base_tensor(x.output_shape, color=ACTIVE)) - # out = add_whiskers(out, x.output_shape) - return hcat( - [box(inp, 1.5, 2), box(inp2, 1.5, 2), box(out, 1.5, 2)], 1 - ).center_xy() + text("dot", 0.2).fill_color(BLACK).line_width(0).translate(0, -1) - - -# For 3d - - -def lookAt(eye: ArrayLike, center: ArrayLike, up: ArrayLike): - "Python version of the haskell lookAt function in linear.projections" - f = (center - eye) / np.linalg.norm(center - eye) - s = np.cross(f, up) / np.linalg.norm(np.cross(f, up)) - u = np.cross(s, f) - return np.array([[*s, 0], [*u, 0], [*-f, 0], [0, 0, 0, 1]]) - - -def scale3(x, y, z): - return np.array([[x, 0, 0, 0], [0, y, 0, 0], [0, 0, z, 0], [0, 0, 0, 1]]) - - -@dataclass -class D3: - x: float - y: float - z: float - - def to_np(self): - return np.array([self.x, self.y, self.z]) - - -V3 = D3 - - -def homogeneous(trails: List[List[D3]]): - "Convert list of directions to a np.array of homogeneous coordinates" - return np.array([[[*o.to_np(), 1] for o in offsets] for offsets in trails]) - - -def cube(): - "3 faces of a cube drawn as offsets from the origin." - return homogeneous( - [ - [D3(*v) for v in offset] - for offset in [ - [(1, 0, 0), (0, 1, 0), (-1, 0, 0), (0, -1, 0)], - [(1, 0, 0), (0, 0, 1), (-1, 0, 0), (0, 0, -1)], - [(0, 0, 1), (0, 1, 0), (0, 0, -1), (0, -1, 0)], - ] - ] - ) - - -def to_trail(trail: ArrayLike, locations: ArrayLike): - return [ - ( - Path( - [ - Trail.from_offsets([V2(*v[:2]) for v in trail]) - .close() - .at(V2(*loc[:2])) - ] - ), - loc[2], - ) - for loc in locations - ] - - -def project(projection, shape3, positions): - p = homogeneous([positions for _ in range(shape3.shape[0])]) - locations = p @ projection.T - trails = shape3 @ projection.T - return [out for t, loc in zip(trails, locations) for out in to_trail(t, loc)] - - -def draw_tensor_3d(shape, a, b, c, color=WHITE): - shape = make_3d(shape) - - # Big Cube - s = scale3(*shape) - big_cube = cube() @ s.T - back = scale3(0, shape[1], shape[2]) - back_cube = cube() @ back.T - - # Isometric projection of tensor - projection = lookAt( - V3(-1.0, -0.3, -0.15).to_np(), - V3(0, 0, 0).to_np(), - V3(0, 1, 0).to_np(), - ) - outer = project(projection, big_cube, [V3(0, 0, 0)]) - outer2 = project(projection, back_cube, [V3(shape[0], 0, 0)]) - d = ( - concat([p.stroke().fill_color(GREY).fill_opacity(0.1) for p, _ in outer2]) - .line_width(0.005) - .line_color(GREY) - ) - d += ( - concat([p.stroke().fill_color(GREY).fill_opacity(0.05) for p, _ in outer]) - .line_width(0.01) - .line_color(BLACK) - ) - if a is not None: - out = group(a, b, c) - d2 = [ - (b, loc) - for i in range(len(out)) - for b, loc in make_cube(projection, out[i][0], out[i][1], color) - ] - d2.sort(key=lambda x: x[1], reverse=True) - d2 = concat([b.with_envelope(empty()) for b, _ in d2]) - d = d2.with_envelope(d) + d - return d - - -def lines(s): - "Draw lines to mimic a cube of cubes" - bs = [np.array([1, 0, 0]), np.array([0, 1, 0]), np.array([0, 0, 1])] - return [ - homogeneous( - [ - [D3(*p) for _ in range(s[i]) for p in [a / s[i], b, -b]] - for b in bs - if not np.all(a == b) - ] +def delinearized( + shape: Tuple[int, int, int], x: np.ndarray, dtype, mask +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + x = x.copy() // (dtype.element_ty.primitive_bitwidth // 8) + z = ((x // (shape[1] * shape[2])) * mask - (1 - mask)).ravel() + y = (((x // shape[2]) % shape[1]) * mask - (1 - mask)).ravel() + x = ((x % shape[2]) * mask - (1 - mask)).ravel() + return z, y, x + + +def prepare_visualization_data(program_records, tensor_table): + """Prepare visualization data for the frntend and raw tensor data for the server.""" + # global idx + visualization_data = [] + raw_tensor_data = {} + for record in program_records: + record_uuid = str(uuid.uuid4())[:8] + + if isinstance(record, ExpandDims): + print(record.input_shape, record.output_shape, record.index) + if isinstance(record, Dot): + visualization_data.append( + { + "type": "Dot", + "input_shape": record.input_shape, + "other_shape": record.other_shape, + "output_shape": record.output_shape, + "uuid": record_uuid, + } + ) + + raw_tensor_data[record_uuid] = { + "input_data": torch.tensor(record.input_data), + "other_data": torch.tensor(record.other_data), + "intermediate_results": record.intermediate_results, + } + + elif isinstance(record, Load): + global_tensor, slice_tensor = tensor_table[record.ptr] + print(global_tensor) + global_coords, slice_coords = extract_load_coords(record, global_tensor) + + visualization_data.append( + { + "type": "Load", + "global_shape": global_tensor.shape, + "slice_shape": record.shape, + "global_coords": global_coords, + "slice_coords": slice_coords, + "uuid": record_uuid, + } + ) + + raw_tensor_data[record_uuid] = { + "global_tensor": global_tensor.data.cpu(), # Ensure it's on CPU + "dims": len(global_tensor.data.cpu().shape), + } + print(record.shape) + + elif isinstance(record, Store): + global_tensor, slice_tensor = tensor_table[record.ptr] + + global_coords, slice_coords = extract_load_coords(record, global_tensor) + + visualization_data.append( + { + "type": "Store", + "global_shape": global_tensor.shape, + "slice_shape": record.shape, + "global_coords": global_coords, + "slice_coords": slice_coords, + "uuid": record_uuid, + } + ) + + return visualization_data, raw_tensor_data, "" + + +def get_visualization_data(): + """Return the visualization data and raw tensor data.""" + records, tensor_table, failures = collect_grid() + visualization_data = {} + raw_tensor_data = {} + + for grid_idx, program_records in records.items(): + viz_data, raw_data, kernel_src = prepare_visualization_data( + program_records, tensor_table ) - for i, a in enumerate(bs) - ] - - -def make_cube(projection, start, end, color): - "Draws a cube from start position to end position." - start = np.array(start).astype(int) - end = np.array(end).astype(int) - s2 = end - start + 1 - s = scale3(*s2) - small_cube = cube() @ s.T - loc = [ - project(projection, l2 @ s.T, [V3(*start)]) - for l2 in lines(s2) - if l2.shape[1] > 0 - ] - outer2 = project(projection, small_cube, [V3(*start)]) - ls = loc - box = [ - (p.stroke().fill_color(color).fill_opacity(0.4).line_width(0), loc) - for p, loc in outer2 - ] - line = [ - (p.stroke().line_width(0.001).line_color(BLACK), l_) - for loc in ls - for p, l_ in loc - ] - return [(b, loc) for b, loc in box + line] - - -def group( - x: ArrayLike, y: ArrayLike, z: ArrayLike -) -> List[Tuple[Tuple[float, float, float], Tuple[float, float, float]]]: - "Groups together cubes into bigger cubes" - x = list(zip(zip(x, y, z), zip(x, y, z))) - x = [(a, b) for a, b in x if not (a[0] == -1 and a[1] == -1 and a[2] == -1)] - - start = x - - def remove_dups(ls): - "Remove duplicates" - out = [] - for y in ls: - if not out or y != out[-1]: - out.append(y) - return out - - for j in range(2, -1, -1): - x = remove_dups(start) - start = [] - while True: - if len(x) <= 1: - break - _, _, rest = x[0], x[1], x[2:] - m = 0 - for k in range(2): - a = x[0][k] - b = x[1][k] - if ( - (k == 0 or a[j % 3] == b[j % 3] - 1) - and a[(j + 1) % 3] == b[(j + 1) % 3] - and a[(j + 2) % 3] == b[(j + 2) % 3] - ): - m += 1 - if m == 2: - x = [[x[0][0], x[1][1]]] + rest - else: - start.append(x[0]) - x = [x[1]] + rest - start += x - return start + visualization_data[str(grid_idx)] = viz_data + raw_tensor_data.update(raw_data) + + # Get the kernel source code + + return { + "visualization_data": visualization_data, + "raw_tensor_data": raw_tensor_data, + "failures": failures, + "kernel_src": kernel_src, + } + + +def serialize_for_json(obj): + if isinstance(obj, np.ndarray): + return obj.tolist() + elif isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + elif isinstance(obj, torch.Tensor): + return obj.cpu().numpy().tolist() + raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable") diff --git a/triton_viz/interface.py b/triton_viz/interface.py index 472f01d..e89e29a 100644 --- a/triton_viz/interface.py +++ b/triton_viz/interface.py @@ -1,99 +1,203 @@ -import gradio as gr -import triton_viz -import tempfile +import threading +from flask import Flask, render_template, jsonify, request from .analysis import analyze_records -from .tooltip import create_tooltip +from .draw import get_visualization_data +from .tooltip import get_tooltip_data import pandas as pd +import os +import torch +from flask_cloudflared import _run_cloudflared +import requests +import time +app = Flask( + __name__, + template_folder=os.path.join(os.path.dirname(__file__), "templates"), + static_folder=os.path.join(os.path.dirname(__file__), "static"), +) -def launch(share=True): - cache = {} +# Global variables to store the data +global_data = None +raw_tensor_data = None +precomputed_c_values = {} +current_fullscreen_op = None + + +def precompute_c_values(op_data): + input_data = op_data["input_data"] + other_data = op_data["other_data"] + rows, inner_dim = input_data.shape + cols = other_data.shape[1] + + precomputed = {} + for i in range(rows): + for j in range(cols): + precomputed[(i, j)] = [0] * (inner_dim + 1) + for k in range(1, inner_dim + 1): + precomputed[(i, j)][k] = torch.dot( + input_data[i, :k], other_data[:k, j] + ).item() + + return precomputed + + +def update_global_data(): + global global_data, raw_tensor_data, precomputed_c_values analysis_data = analyze_records() - program_records, tt, failures = triton_viz.collect_grid() - m = [0, 0, 0] - size = [0, 0] - for k in program_records.keys(): - m[0] = max(k[0] + 1, m[0]) - m[1] = max(k[1] + 1, m[1]) - m[2] = max(k[2] + 1, m[2]) - w, h = triton_viz.draw_record(program_records[(0, 0, 0)], tt, "tmp.svg") - size[0] = w - size[1] = h - height = 600 * size[1] / size[0] - with gr.Blocks( - css=".gradio-container button {overflow: auto} img.with-caption {height: %fpx !important; } .thumbnails { display: none; } " - % height - ) as demo: - with gr.Row(): - with gr.Column(scale=3, min_width=500): - img = gr.Gallery( - height=500, - min_width=500, - show_label=False, - selected_index=0, - preview=True, - object_fit="cover", - ) - with gr.Column(scale=1): - s1 = gr.Slider(0, m[0] - 1, value=0, step=1, label="Program Id 0") - s2 = gr.Slider(0, m[1] - 1, value=0, step=1, label="Program Id 1") - s3 = gr.Slider(0, m[2] - 1, value=0, step=1, label="Program Id 2") - b1 = gr.Button("Precompute") - gr.Markdown("## Analysis") - df = pd.DataFrame(analysis_data, columns=["Metric", "Value"]) - analysis_with_tooltip = create_tooltip(df) - gr.HTML(analysis_with_tooltip) - if failures: - gr.Markdown( - show_label=False, - value="## Invalid memory access in " - + "\n * " - + "\n* ".join(list(map(str, failures.keys()))), - ) - - def cache_block(idx): - name = tempfile.NamedTemporaryFile(suffix=".svg") - w, h = triton_viz.draw_record(program_records[idx], tt, name.name) - size[0] = w - size[1] = h - cache[idx] = (name, len(cache)) - - def update(inp): - a = inp[s1] - b = inp[s2] - c = inp[s3] - idx = (a, b, c) - - if idx not in cache: - cache_block(idx) - return gr.Gallery( - value=[(cache[k][0].name, str(k)) for k in cache.keys()], - selected_index=cache[idx][1], - height=700, - ), gr.Slider() - # * size[1]/size[0] - return gr.Gallery(selected_index=cache[idx][1]), gr.Slider() - - def precompute(inp): - a = inp[s1] - b = inp[s2] - c = inp[s3] - idx = (a, b, c) - for i in range(m[0]): - for j in range(m[1]): - for k in range(m[2]): - if (i, j, k) not in cache: - cache_block((i, j, k)) - return gr.Gallery( - value=[(cache[k][0].name, str(k)) for k in cache.keys()], - selected_index=cache[idx][1], - ) - - s1.change(update, inputs={s1, s2, s3}, outputs=[img, b1], show_progress=False) - s2.change(update, inputs={s1, s2, s3}, outputs=[img, b1], show_progress=False) - s3.change(update, inputs={s1, s2, s3}, outputs=[img, b1], show_progress=False) - b1.click(precompute, inputs={s1, s2, s3}, outputs=img, show_progress=True) - demo.load(update, inputs={s1, s2, s3}, outputs=[img, b1]) - - demo.launch(share=share, debug=False, height=800, quiet=True, show_api=False) - return failures + viz_data = get_visualization_data() + global_data = { + "ops": { + "visualization_data": viz_data["visualization_data"], + "failures": viz_data["failures"], + "kernel_src": viz_data["kernel_src"], + } + } + raw_tensor_data = viz_data["raw_tensor_data"] + + # Precompute C values for each Dot operation + precomputed_c_values = {} + for uuid, op_data in raw_tensor_data.items(): + if "input_data" in op_data and "other_data" in op_data: + precomputed_c_values[uuid] = precompute_c_values(op_data) + + df = pd.DataFrame(analysis_data, columns=["Metric", "Value"]) + analysis_with_tooltip = get_tooltip_data(df) + global_data["analysis"] = analysis_with_tooltip + + +@app.route("/") +def index(): + return render_template("index.html") + + +@app.route("/api/data") +def get_data(): + global global_data + if global_data is None: + update_global_data() + return jsonify(global_data) + + +@app.route("/api/update_data") +def update_data(): + update_global_data() + return jsonify({"status": "Data updated successfully"}) + + +@app.route("/api/setop", methods=["POST"]) +def set_current_op(): + global current_fullscreen_op + data = request.json + current_fullscreen_op = data.get("uuid") + return jsonify( + {"status": "Current op set successfully", "uuid": current_fullscreen_op} + ) + + +@app.route("/api/getValue", methods=["POST"]) +def get_value(): + global raw_tensor_data, precomputed_c_values, current_fullscreen_op + print(current_fullscreen_op) + data = request.json + uuid = data.get("uuid") + matrix_name = data.get("matrixName") + row = data.get("row") + col = data.get("col") + + if uuid not in raw_tensor_data: + return jsonify({"error": "Operation not found"}), 404 + + op_data = raw_tensor_data[uuid] + + if matrix_name == "A": + value = ( + op_data["input_data"][row, col].item() if "input_data" in op_data else None + ) + return jsonify({"value": value}) + elif matrix_name == "B": + value = ( + op_data["other_data"][row, col].item() if "other_data" in op_data else None + ) + return jsonify({"value": value}) + elif matrix_name == "C": + current_step = data.get("currentStep", 0) + + if uuid not in precomputed_c_values: + return jsonify({"error": "Precomputed values not found"}), 404 + + precomputed = precomputed_c_values[uuid] + current_value = precomputed[(row, col)][current_step] + + return jsonify( + { + "value": current_value, + } + ) + else: + return jsonify({"error": "Invalid matrix name"}), 400 + + +@app.route("/api/getLoadValue", methods=["POST"]) +def get_load_value(): + global raw_tensor_data, current_fullscreen_op + + data = request.json + uuid = data.get("uuid") + x = data.get("x") + y = data.get("y") + z = data.get("z") + print(x, y, z) + if uuid is None or uuid not in raw_tensor_data: + return jsonify({"error": "Operation not found"}), 404 + + op_data = raw_tensor_data[uuid] + + if "global_tensor" in op_data and ( + x is not None and y is not None and z is not None + ): + try: + value = 0.0 + if op_data["dims"] == 3: + value = op_data["global_tensor"][x, y, z].item() + elif op_data["dims"] == 2: + value = op_data["global_tensor"][x, y].item() + elif op_data["dims"] == 1: + value = op_data["global_tensor"][x].item() + + return jsonify({"value": value}) + except IndexError: + return jsonify({"error": "Coordinates out of bounds"}), 200 + else: + return jsonify({"error": "Global tensor data not found"}), 200 + + +def run_flask_with_cloudflared(): + cloudflared_port = 8000 # You can change this port if needed + tunnel_url = _run_cloudflared(cloudflared_port, 8001) # not too important + print(f"Cloudflare tunnel URL: {tunnel_url}") + app.run(port=cloudflared_port) + + +def launch(share=True): + print("Launching Triton viz tool") + if share: + flask_thread = threading.Thread(target=run_flask_with_cloudflared) + flask_thread.start() + + # Wait for the server to start + time.sleep(5) + + # Try to get the tunnel URL by making a request to the local server + try: + response = requests.get("http://localhost:8000") + print(f"Your app is now available at: {response.url}") + except requests.exceptions.RequestException: + print("Please wait for URL:") + else: + app.run(port=5001) + + +# This function can be called to stop the Flask server if needed +def stop_server(flask_thread): + # Implement a way to stop the Flask server + pass diff --git a/triton_viz/interpreter.py b/triton_viz/interpreter.py index a572e1e..953d56b 100644 --- a/triton_viz/interpreter.py +++ b/triton_viz/interpreter.py @@ -44,6 +44,19 @@ def _unpatch_lang(): class RecordBuilder: + def to_dict(self): + """Convert the recorded data to a dictionary format for JSON serialization.""" + return { + "launches": [ + { + "grid": launch.grid, + "tensors": [tensor.__dict__ for tensor in launch.tensors], + "records": [record.__dict__ for record in launch.records], + } + for launch in self._launches + ] + } + def __init__(self) -> None: self.reset() @@ -136,7 +149,9 @@ def _check_storage_contiguous(tensor): def _grid_executor_call(self, *args_dev, **kwargs): # Removes reserved keywords from kwargs + # kwargs has src_map and src src is the raw straight spurce code, src_map is a mapping between source code line and an op kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS} + if kwargs.pop("warmup", False): return args_hst, kwargs_hst = self._init_args_hst(args_dev, kwargs) @@ -162,6 +177,7 @@ def _grid_executor_call(self, *args_dev, **kwargs): arg.stride(), arg.shape, arg.element_size(), + arg, ) ) call_args[name] = ret @@ -291,15 +307,31 @@ def wrapper(lhs, rhs, op): def _create_dot(fn): @wraps(fn) - def wrapper(a, b, d, allow_tf32, maxNumImpreciseAcc): - ret = fn(a, b, d, allow_tf32, maxNumImpreciseAcc) + def wrapper(a, b, c, allow_tf32, max_num_imprecise_acc): dot_record = Dot( - input_shape=(a.data.shape, b.data.shape), - other_shape=d.data.shape, - output_shape=ret.data.shape, + input_shape=a.data.shape, + other_shape=b.data.shape, + output_shape=c.data.shape, + input_data=a.data.tolist(), + other_data=b.data.tolist(), ) record_builder.add_record(dot_record) - return ret + + def capture_intermediate(row, col, result): + dot_record.update_intermediate(row, col, float(result)) + + # Modify the original function to call capture_intermediate at each step + def modified_fn(a, b, c, allow_tf32, max_num_imprecise_acc): + for i in range(a.data.shape[0]): + for j in range(b.data.shape[1]): + A_row = a.data[i, :] + B_column = b.data[:, j] + result = np.dot(A_row, B_column) + capture_intermediate(i, j, result) + c.data[i, j] = result + + modified_fn(a, b, c, allow_tf32, max_num_imprecise_acc) + return c return wrapper @@ -377,3 +409,8 @@ def patch(): interpreter_builder.binary_op = old_binary_op interpreter_builder.create_dot = old_create_dot interpreter_builder.create_masked_store = old_create_masked_store + + +def get_recorded_data(): + """Return the recorded data in a format suitable for JSON serialization.""" + return record_builder.to_dict() diff --git a/triton_viz/static/gridblock.js b/triton_viz/static/gridblock.js new file mode 100644 index 0000000..6c999cc --- /dev/null +++ b/triton_viz/static/gridblock.js @@ -0,0 +1,241 @@ +import { createMatMulVisualization } from './matmul.js'; +import { createLoadVisualization } from './load.js'; +import { createStoreVisualization } from './store.js'; + +export class GridBlock { + constructor(x, y, width, height, gridX, gridY, gridZ, blockData, onClose, containerElement, canvas, drawFunction) { + this.rect = { x, y, width, height }; + this.gridPosition = { x: gridX, y: gridY, z: gridZ }; + this.blockData = blockData; + this.isHovered = false; + this.visualizationContainer = null; + this.visualizationCleanupFunction = null; + this.onClose = onClose; + this.containerElement = containerElement; + this.canvas = canvas; + this.drawFunction = drawFunction; + this.isDetailedViewVisible = false; + this.contentArea = null; + } + + + draw(ctx) { + // Draw background + ctx.fillStyle = this.isHovered ? '#4a4a4a' : '#323232'; + ctx.fillRect(this.rect.x, this.rect.y, this.rect.width, this.rect.height); + + // Draw border + ctx.strokeStyle = '#000000'; + ctx.lineWidth = 1; + ctx.strokeRect(this.rect.x, this.rect.y, this.rect.width, this.rect.height); + + // Draw grid position + ctx.fillStyle = '#c8c8c8'; + ctx.font = '12px Arial'; + ctx.textAlign = 'center'; + ctx.textBaseline = 'top'; + const posText = `${this.gridPosition.x},${this.gridPosition.y},${this.gridPosition.z}`; + ctx.fillText(posText, this.rect.x + this.rect.width / 2, this.rect.y + 2); + + // Draw operation types + ctx.font = '10px Arial'; + ctx.textAlign = 'left'; + ctx.textBaseline = 'top'; + this.blockData.forEach((op, index) => { + ctx.fillText(op.type, this.rect.x + 5, this.rect.y + 20 + index * 15); + }); + } + + isPointInside(x, y) { + return x >= this.rect.x && x <= this.rect.x + this.rect.width && + y >= this.rect.y && y <= this.rect.y + this.rect.height; + } + + handleMouseMove(x, y) { + const wasHovered = this.isHovered; + this.isHovered = this.isPointInside(x, y); + return wasHovered !== this.isHovered; // Return true if hover state changed + } + + + showDetailedView() { + if (this.isDetailedViewVisible) return; + + this.visualizationContainer = this.createVisualizationContainer(); + document.body.appendChild(this.visualizationContainer); + + const title = this.createTitle(); + const headerBar = this.createHeaderBar(); + this.contentArea = this.createContentArea(); + + this.visualizationContainer.appendChild(title); + this.visualizationContainer.appendChild(headerBar); + this.visualizationContainer.appendChild(this.contentArea); + + const closeButton = this.createCloseButton(); + this.visualizationContainer.appendChild(closeButton); + + this.isDetailedViewVisible = true; + this.canvas.style.display = 'none'; + this.containerElement.style.display = 'block'; + + // Display the first operation visualization after the content area is added to the DOM + if (this.blockData.length > 0) { + this.displayOpVisualization(this.blockData[0]); + } + } + + createVisualizationContainer() { + const container = document.createElement('div'); + Object.assign(container.style, { + position: 'fixed', + top: '0', + left: '0', + width: '100vw', + height: '100vh', + backgroundColor: '#1e1e28', + zIndex: '1000', + display: 'flex', + flexDirection: 'column', + color: '#ffffff' + }); + return container; + } + + createTitle() { + const title = document.createElement('h2'); + title.textContent = `Operations at (${this.gridPosition.x}, ${this.gridPosition.y}, ${this.gridPosition.z})`; + title.style.textAlign = 'center'; + title.style.margin = '10px 0'; + return title; + } + + createHeaderBar() { + const headerBar = document.createElement('div'); + Object.assign(headerBar.style, { + display: 'flex', + flexDirection: 'row', + backgroundColor: '#333', + padding: '5px', + overflowX: 'auto' + }); + + let currentSelectedTab = null; + this.blockData.forEach((op, index) => { + const opTab = this.createOperationTab(op, index === 0); + opTab.addEventListener('click', () => this.handleTabClick(opTab, op, currentSelectedTab)); + headerBar.appendChild(opTab); + if (index === 0) currentSelectedTab = opTab; + }); + + return headerBar; + } + + createOperationTab(op, isFirst) { + const opTab = document.createElement('button'); + opTab.textContent = op.type; + Object.assign(opTab.style, { + flex: '0 0 auto', + marginRight: '5px', + backgroundColor: isFirst ? '#555' : '#333', + color: '#fff', + border: 'none', + padding: '10px', + cursor: 'pointer' + }); + return opTab; + } + + handleTabClick(clickedTab, op, currentSelectedTab) { + if (currentSelectedTab) currentSelectedTab.style.backgroundColor = '#333'; + clickedTab.style.backgroundColor = '#555'; + this.displayOpVisualization(op); + } + + createContentArea() { + const contentArea = document.createElement('div'); + Object.assign(contentArea.style, { + flex: '1', + padding: '10px', + overflow: 'hidden', + position: 'relative' + }); + + if (this.blockData.length === 0) { + const noDataMsg = document.createElement('p'); + noDataMsg.textContent = 'No operation data available'; + noDataMsg.style.textAlign = 'center'; + contentArea.appendChild(noDataMsg); + } + + return contentArea; + } + displayOpVisualization(op) { + if (!this.contentArea) { + console.error('Content area is not initialized'); + return; + } + + if (this.visualizationCleanupFunction) { + this.visualizationCleanupFunction(); + this.visualizationCleanupFunction = null; + } + + this.contentArea.innerHTML = ''; + + switch (op.type) { + case 'Dot': + this.visualizationCleanupFunction = createMatMulVisualization(this.contentArea, op); + break; + case 'Load': + this.visualizationCleanupFunction = createLoadVisualization(this.contentArea, op); + break; + case 'Store': + this.visualizationCleanupFunction = createStoreVisualization(this.contentArea, op); + break; + default: + const unsupportedMsg = document.createElement('p'); + unsupportedMsg.textContent = `Visualization not supported for ${op.type} operation`; + unsupportedMsg.style.textAlign = 'center'; + this.contentArea.appendChild(unsupportedMsg); + } + } + + + createCloseButton() { + const closeButton = document.createElement('button'); + closeButton.textContent = 'Close'; + Object.assign(closeButton.style, { + position: 'fixed', + top: '10px', + right: '10px', + zIndex: '1001' + }); + closeButton.addEventListener('click', () => this.hideDetailedView()); + return closeButton; + } + + hideDetailedView() { + if (!this.isDetailedViewVisible) return; + + if (this.visualizationCleanupFunction) { + this.visualizationCleanupFunction(); + this.visualizationCleanupFunction = null; + } + + if (this.visualizationContainer) { + document.body.removeChild(this.visualizationContainer); + this.visualizationContainer = null; + } + + this.isDetailedViewVisible = false; + + if (this.onClose) { + this.onClose(); + } + + this.canvas.style.display = 'block'; + this.containerElement.style.display = 'none'; + this.drawFunction(); + } +} diff --git a/triton_viz/static/images/accelerator-memory-hierarchy.png b/triton_viz/static/images/accelerator-memory-hierarchy.png new file mode 100644 index 0000000..3ca48d1 Binary files /dev/null and b/triton_viz/static/images/accelerator-memory-hierarchy.png differ diff --git a/triton_viz/static/images/tensor-slicing-diagram.png b/triton_viz/static/images/tensor-slicing-diagram.png new file mode 100644 index 0000000..d6a300f Binary files /dev/null and b/triton_viz/static/images/tensor-slicing-diagram.png differ diff --git a/triton_viz/static/infoPopup.js b/triton_viz/static/infoPopup.js new file mode 100644 index 0000000..0693a77 --- /dev/null +++ b/triton_viz/static/infoPopup.js @@ -0,0 +1,211 @@ +const infoContent = [ + { + title: "Visualization Tool Overview", + content: ` +
This tool allows you to explore and visualize tensor operations.
+If you're new to Triton or need to brush up on your skills, check out Triton Puzzles. It's a collection of Triton problems designed to get you up to speed:
+ +These puzzles provide hands-on experience with Triton concepts, which can help you better understand the visualizations in this tool.
+ ` + }, + { + title: "Using the Sliders", + content: ` +The sliders on the right side of the screen allow you to filter the grid blocks:
+Set a slider to -1 to show all blocks along that dimension.
+ ` + }, + { + title: "Detailed View", + content: ` +Clicking on a grid block will open a detailed view of the operations for that block.
+In the detailed view, you can:
+Triton operates on slices of tensors rather than entire tensors. This is a key concept in understanding how Triton kernels process data efficiently.
+ +Inside Triton kernels, tl.load
is used to load slices of the global tensor. For example:
a = tl.load(a_ptrs, mask=offs_k[None, :])
+ This operation loads only the necessary slice of data, optimizing memory access and computation.
+ +tl.load
and tl.store
operations in Triton handle the slicing automatically.Diagram showing how a global PyTorch tensor is sliced for processing in Triton kernels.
+ ` + }, + { + title: "Memory Hierarchy in AI Accelerators and Triton Operations", + content: ` +Understanding the memory hierarchy in AI accelerators is crucial for optimizing Triton kernels. While this example uses a GPU, it's important to note that Triton is designed for various AI accelerators, not just GPUs.
+ +Triton provides mechanisms to efficiently move data between these memory types across various accelerators:
+tl.load
: Generally used to load a specified portion of the global tensor into shared memory*. This operation can significantly speed up subsequent computations.tl.store
: Used to write processed data back to global memory.* It's important to note that Triton doesn't always transfer tensor slices to shared memory. This operation is performed only when it makes sense for performance optimization.
+ +tl.load
and tl.store
operations abstract the complexity of memory transfers, making code portable across different accelerator types.Diagram showing the memory hierarchy of an RTX 4090 GPU as an example of AI accelerator architecture. Global memory (magenta) and on-chip memory including shared memory* (red) are highlighted. Note that this structure is similar across many AI accelerators.
+ ` + } +]; + +let currentPage = 0; + +export function createInfoPopup() { + const infoPopup = document.createElement('div'); + infoPopup.style.display = 'none'; + infoPopup.style.position = 'fixed'; + infoPopup.style.top = '5%'; + infoPopup.style.left = '5%'; + infoPopup.style.width = '85%'; + infoPopup.style.height = '75%'; + infoPopup.style.backgroundColor = '#1a1a1a'; + infoPopup.style.color = '#fff'; + infoPopup.style.padding = '30px'; + infoPopup.style.borderRadius = '15px'; + infoPopup.style.boxShadow = '0 4px 6px rgba(0, 0, 0, 0.5)'; + infoPopup.style.zIndex = '1000'; + infoPopup.style.overflow = 'auto'; + + const closeButton = createButton('×', () => { infoPopup.style.display = 'none'; }); + closeButton.style.position = 'absolute'; + closeButton.style.top = '20px'; + closeButton.style.right = '20px'; + closeButton.style.fontSize = '36px'; + closeButton.style.color = '#fff'; + + const content = document.createElement('div'); + content.id = 'info-content'; + content.style.fontSize = '18px'; + content.style.lineHeight = '1.6'; + + const navigation = document.createElement('div'); + navigation.style.display = 'flex'; + navigation.style.justifyContent = 'space-between'; + navigation.style.marginTop = '30px'; + + const prevButton = createNavButton('← Previous', () => navigatePage(-1)); + const nextButton = createNavButton('Next →', () => navigatePage(1)); + + navigation.appendChild(prevButton); + navigation.appendChild(nextButton); + + infoPopup.appendChild(closeButton); + infoPopup.appendChild(content); + infoPopup.appendChild(navigation); + + document.body.appendChild(infoPopup); + + updateContent(); + + return infoPopup; +} + +function createButton(text, onClick) { + const button = document.createElement('button'); + button.innerHTML = text; + button.style.border = 'none'; + button.style.background = 'none'; + button.style.cursor = 'pointer'; + button.style.color = '#fff'; + button.onclick = onClick; + return button; +} + +function createNavButton(text, onClick) { + const button = createButton(text, onClick); + button.style.fontSize = '20px'; + button.style.padding = '15px 25px'; + button.style.backgroundColor = '#333'; + button.style.borderRadius = '8px'; + button.style.transition = 'background-color 0.3s'; + button.onmouseover = () => { button.style.backgroundColor = '#444'; }; + button.onmouseout = () => { button.style.backgroundColor = '#333'; }; + return button; +} + +function navigatePage(direction) { + currentPage += direction; + if (currentPage < 0) currentPage = infoContent.length - 1; + if (currentPage >= infoContent.length) currentPage = 0; + updateContent(); +} + +function updateContent() { + const contentDiv = document.getElementById('info-content'); + const pageContent = infoContent[currentPage]; + contentDiv.innerHTML = ` +X: ${x + 1}
+Y: ${y + 1}
+Z: ${z + 1}
+Dimensions: ${dims.join(' x ')}
+Value: ${value !== undefined ? value : 'Loading...'}
+ `; + } + + function updateInfoPanel(globalCoord, sliceCoord, index) { + sideMenu.innerHTML = ` +Global Coords: (${globalCoord.join(', ')})
+Slice Coords: (${sliceCoord.join(', ')})
+Progress: ${index + 1}/${op.global_coords.length}
+ `; + } + + function createSideMenu(container) { + const menu = document.createElement('div'); + menu.style.position = 'absolute'; + menu.style.top = '10px'; + menu.style.right = '10px'; + menu.style.width = '200px'; + menu.style.padding = '10px'; + menu.style.backgroundColor = 'rgba(0, 0, 0, 0.7)'; + menu.style.color = 'white'; + menu.style.fontFamily = 'Arial, sans-serif'; + menu.style.fontSize = '14px'; + menu.style.borderRadius = '5px'; + container.appendChild(menu); + return menu; + } + + function addLabels(scene, globalTensor, sliceTensor) { + addLabel(scene, "Global Tensor", globalTensor.position); + addLabel(scene, "Slice Tensor", sliceTensor.position); + } + + function addLabel(scene, text, position) { + const canvas = document.createElement('canvas'); + const context = canvas.getContext('2d'); + context.font = 'Bold 24px Arial'; + context.fillStyle = 'white'; + context.fillText(text, 0, 24); + + const texture = new THREE.CanvasTexture(canvas); + const material = new THREE.SpriteMaterial({ map: texture }); + const sprite = new THREE.Sprite(material); + sprite.position.set(position.x, position.y + 2, position.z); + sprite.scale.set(4, 2, 1); + scene.add(sprite); + } + +} diff --git a/triton_viz/static/load_utils.js b/triton_viz/static/load_utils.js new file mode 100644 index 0000000..2729d04 --- /dev/null +++ b/triton_viz/static/load_utils.js @@ -0,0 +1,200 @@ +import * as THREE from 'https://cdn.jsdelivr.net/npm/three@0.155.0/build/three.module.js'; + +export const CUBE_SIZE = 0.2; +export const GAP = 0.05; +export const COLOR_HOVER = new THREE.Color(1.0, 1.0, 0.0); // Yellow for hover effect +export const COLOR_EDGE = new THREE.Color(0.5, 0.5, 0.5); // Gray (for cube edges) + +const COLOR_SLICE = new THREE.Color(0.0, 0.7, 1.0); // Cyan (starting color for global slice) + +export function setupScene(container, backgroundColor = 0x000000) { + const scene = new THREE.Scene(); + scene.background = new THREE.Color(backgroundColor); + const camera = new THREE.PerspectiveCamera(45, container.clientWidth / container.clientHeight, 0.1, 1000); + const renderer = new THREE.WebGLRenderer({ antialias: true }); + renderer.setSize(container.clientWidth, container.clientHeight); + container.appendChild(renderer.domElement); + + const ambientLight = new THREE.AmbientLight(0xffffff, 0.5); + scene.add(ambientLight); + const directionalLight = new THREE.DirectionalLight(0xffffff, 0.5); + directionalLight.position.set(10, 10, 10); + scene.add(directionalLight); + + return { scene, camera, renderer }; +} + +export function setupGeometries() { + const cubeGeometry = new THREE.BoxGeometry(CUBE_SIZE, CUBE_SIZE, CUBE_SIZE); + const edgesGeometry = new THREE.EdgesGeometry(cubeGeometry); + const lineMaterial = new THREE.LineBasicMaterial({ color: COLOR_EDGE }); + return { cubeGeometry, edgesGeometry, lineMaterial }; +} + +export function createCube(color, tensorName, x, y, z, cubeGeometry, edgesGeometry, lineMaterial) { + const cubeMaterial = new THREE.MeshPhongMaterial({ color: color }); + const cube = new THREE.Mesh(cubeGeometry, cubeMaterial); + const edges = new THREE.LineSegments(edgesGeometry, lineMaterial); + cube.add(edges); + + const hoverGeometry = new THREE.BoxGeometry(CUBE_SIZE * 1.05, CUBE_SIZE * 1.05, CUBE_SIZE * 1.05); + const hoverEdgesGeometry = new THREE.EdgesGeometry(hoverGeometry); + const hoverOutline = new THREE.LineSegments(hoverEdgesGeometry, new THREE.LineBasicMaterial({ color: COLOR_HOVER })); + hoverOutline.visible = false; + hoverOutline.name = 'hoverOutline'; + cube.add(hoverOutline); + + // Add custom properties to store tensor coordinates + cube.userData.tensor0 = z; + cube.userData.tensor1 = y; + cube.userData.tensor2 = x; + cube.userData.tensorName = tensorName; + + cube.name = `${tensorName}_cube_${x}_${y}_${z}`; + + return cube; +} + +export function createTensor(shape, coords, color, tensorName, cubeGeometry, edgesGeometry, lineMaterial) { + console.log(`Creating ${tensorName} tensor:`, shape, coords); + const tensor = new THREE.Group(); + let [width, height, depth] = shape; + depth = depth || 1; + height = height || 1; + + if (tensorName === 'Global') { + console.log(`Creating global tensor with dimensions: ${width}x${height}x${depth}`); + for (let z = 0; z < depth; z++) { + for (let y = 0; y < height; y++) { + for (let x = 0; x < width; x++) { + const cube = createCube(color, tensorName, x, y, z, cubeGeometry, edgesGeometry, lineMaterial); + cube.position.set( + x * (CUBE_SIZE + GAP), + -y * (CUBE_SIZE + GAP), + -z * (CUBE_SIZE + GAP) + ); + tensor.add(cube); + } + } + } + + console.log(`Highlighting ${coords.length} coordinates in global tensor`); + coords.forEach(([x, y, z]) => { + const cube = tensor.children.find(c => + c.userData.tensor0 === x && c.userData.tensor1 === y && c.userData.tensor2 === z + ); + if (cube) { + cube.material.color.set(COLOR_SLICE); + console.log(`Highlighted cube at (${x}, ${y}, ${z})`); + } else { + console.warn(`Could not find cube at (${x}, ${y}, ${z})`); + } + }); + } else { + console.log(`Creating slice tensor with ${coords.length} coordinates`); + coords.forEach(([x, y, z]) => { + const cube = createCube(color, tensorName, x, y, z, cubeGeometry, edgesGeometry, lineMaterial); + cube.position.set( + x * (CUBE_SIZE + GAP), + -y * (CUBE_SIZE + GAP), + -z * (CUBE_SIZE + GAP) + ); + tensor.add(cube); + }); + } + + console.log(`Created ${tensorName} tensor with ${tensor.children.length} cubes`); + return tensor; +} + +export function calculateTensorSize(shape) { + const [width, height, depth] = shape; + return new THREE.Vector3( + width * (CUBE_SIZE + GAP), + height * (CUBE_SIZE + GAP), + depth * (CUBE_SIZE + GAP) + ); +} + +export function interpolateColor(color1, color2, factor) { + return new THREE.Color().lerpColors(color1, color2, factor); +} + +export function updateCubeColor(tensor, coord, startColor, endColor, factor) { + const cube = tensor.children.find(c => + c.tensor0 === coord[0] && c.tensor1 === coord[1] && c.tensor2 === coord[2] + ); + if (cube) { + cube.material.color.copy(interpolateColor(startColor, endColor, factor)); + } +} + +export function setupCamera(scene, camera) { + const box = new THREE.Box3().setFromObject(scene); + const center = box.getCenter(new THREE.Vector3()); + const size = box.getSize(new THREE.Vector3()); + const maxDim = Math.max(size.x, size.y, size.z); + const fov = camera.fov * (Math.PI / 180); + let cameraZ = Math.abs(maxDim / 2 / Math.tan(fov / 2)); + cameraZ *= 1.5; + + camera.position.set(center.x, center.y, center.z + cameraZ); + camera.lookAt(center); + + return { center, cameraZ }; +} + +export function setupEventListeners(containerElement, camera, renderer, onMouseMove, onKeyDown) { + window.addEventListener('resize', () => { + camera.aspect = containerElement.clientWidth / containerElement.clientHeight; + camera.updateProjectionMatrix(); + renderer.setSize(containerElement.clientWidth, containerElement.clientHeight); + }); + containerElement.addEventListener('mousemove', onMouseMove); + window.addEventListener('keydown', onKeyDown); +} + +export function cameraControls(camera, cameraRotation) { + const PAN_SPEED = 0.1; + const TILT_SPEED = 0.02; + const ZOOM_SPEED = 0.5; + + return function onKeyDown(event) { + switch (event.key.toLowerCase()) { + case 'w': + camera.position.y += PAN_SPEED; + break; + case 's': + camera.position.y -= PAN_SPEED; + break; + case 'a': + camera.position.x -= PAN_SPEED; + break; + case 'd': + camera.position.x += PAN_SPEED; + break; + case 'arrowup': + cameraRotation.x -= TILT_SPEED; + break; + case 'arrowdown': + cameraRotation.x += TILT_SPEED; + break; + case 'arrowleft': + cameraRotation.y -= TILT_SPEED; + break; + case 'arrowright': + cameraRotation.y += TILT_SPEED; + break; + case 'o': + camera.position.z += ZOOM_SPEED; + break; + case 'p': + camera.position.z -= ZOOM_SPEED; + break; + case ' ': + break; + } + camera.setRotationFromEuler(cameraRotation); + camera.updateProjectionMatrix(); + }; +} diff --git a/triton_viz/static/matmul.js b/triton_viz/static/matmul.js new file mode 100644 index 0000000..d87d78b --- /dev/null +++ b/triton_viz/static/matmul.js @@ -0,0 +1,360 @@ +import * as THREE from 'https://cdn.jsdelivr.net/npm/three@0.155.0/build/three.module.js'; + +export function createMatMulVisualization(containerElement, op) { + const { input_shape, other_shape, output_shape } = op; + console.log(op.uuid) + fetch('/api/setop', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ uuid: op.uuid }), + }) + .then(response => response.json()) + .then(data => console.log('Set current op:', data)) + .catch((error) => console.error('Error:', error)); + let currentStep = 0; + const totalSteps = input_shape[1]; + let frame = 0; + + + const sideMenu = document.createElement('div'); + sideMenu.style.position = 'absolute'; + sideMenu.style.top = '10px'; + sideMenu.style.right = '10px'; + sideMenu.style.width = '200px'; + sideMenu.style.padding = '10px'; + sideMenu.style.backgroundColor = 'rgba(0, 0, 0, 0.7)'; + sideMenu.style.color = 'white'; + sideMenu.style.fontFamily = 'Arial, sans-serif'; + sideMenu.style.fontSize = '14px'; + sideMenu.style.borderRadius = '5px'; + containerElement.appendChild(sideMenu); + let hoveredCube = null; + + + + + async function getElementValue( matrixName, row, col) { + let uuid = op.uuid; + const response = await fetch('/api/getValue', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ uuid, matrixName, row, col, currentStep }), + }); + return await response.json(); + } + + + function updateSideMenu(matrix, x, y) { + if (!matrix) { + sideMenu.innerHTML = ''; + return; + } + + let matrixName; + let dims; + if (matrix === matrixA) { + matrixName = 'A'; + dims = input_shape; + } else if (matrix === matrixB) { + matrixName = 'B'; + dims = other_shape; + } else if (matrix === matrixC) { + matrixName = 'C'; + dims = output_shape; + } else { + sideMenu.innerHTML = ''; + return; + } + console.log(matrixName, "x:", (x + 1), "y:", (y + 1)); + sideMenu.innerHTML = ` +Row: ${y + 1}
+Column: ${x + 1}
+Dimensions: ${dims[0]} x ${dims[1]}
+ `; + } + + const raycaster = new THREE.Raycaster(); + const mouse = new THREE.Vector2(); + + + async function onMouseMove(event) { + mouse.x = (event.clientX / containerElement.clientWidth) * 2 - 1; + mouse.y = -(event.clientY / containerElement.clientHeight) * 2 + 1; + + raycaster.setFromCamera(mouse, camera); + + const allMatrixChildren = [ + ...(matrixA ? matrixA.children : []), + ...(matrixB ? matrixB.children : []), + ...(matrixC ? matrixC.children : []) + ]; + + const intersects = raycaster.intersectObjects(allMatrixChildren, true); + + if (hoveredCube) { + hoveredCube.getObjectByName('hoverOutline').visible = false; + hoveredCube = null; + } + + if (intersects.length > 0) { + // Find the actual cube (parent of the intersected object) + hoveredCube = intersects[0].object; + while (hoveredCube && !hoveredCube.matrixName) { + hoveredCube = hoveredCube.parent; + } + + if (hoveredCube) { + const hoverOutline = hoveredCube.getObjectByName('hoverOutline'); + if (hoverOutline) { + hoverOutline.visible = true; + } + const res = await getElementValue(hoveredCube.matrixName, hoveredCube.matrixRow, hoveredCube.matrixCol); + // Log the matrix name, row, and column of the hovered cube + console.log( + // `Matrix: ${hoveredCube.matrixName}, ` + + // `Row: ${hoveredCube.matrixRow + 1}, ` + + // `Column: ${hoveredCube.matrixCol + 1}`+ + `Value: ${res.value}` + ); + + updateSideMenu(hoveredCube.matrixName, hoveredCube.matrixRow, hoveredCube.matrixCol); + } + } else { + updateSideMenu(null); + } + } + + + + + + const CUBE_SIZE = 0.2; + const GAP = 0.05; + + const COLOR_A = new THREE.Color(0.53, 0.81, 0.98); + const COLOR_B = new THREE.Color(1.0, 0.65, 0.0); + const COLOR_C = new THREE.Color(1.0, 1.0, 1.0); + const COLOR_HIGHLIGHT = new THREE.Color(0.0, 0.0, 1.0); + const COLOR_FILLED = new THREE.Color(0.0, 0.0, 1.0); + const COLOR_BACKGROUND = new THREE.Color(0.0, 0.0, 0.0); + const COLOR_EDGE = new THREE.Color(0.3, 0.3, 0.3); + const COLOR_HOVER = new THREE.Color(1.0, 1.0, 0.0); + + const scene = new THREE.Scene(); + scene.background = COLOR_BACKGROUND; + const camera = new THREE.PerspectiveCamera(45, containerElement.clientWidth / containerElement.clientHeight, 0.1, 1000); + + const renderer = new THREE.WebGLRenderer({ antialias: true }); + renderer.setSize(containerElement.clientWidth, containerElement.clientHeight); + containerElement.appendChild(renderer.domElement); + + const ambientLight = new THREE.AmbientLight(0xffffff, 0.5); + scene.add(ambientLight); + const directionalLight = new THREE.DirectionalLight(0xffffff, 0.5); + directionalLight.position.set(10, 10, 10); + scene.add(directionalLight); + + const cubeGeometry = new THREE.BoxGeometry(CUBE_SIZE, CUBE_SIZE, CUBE_SIZE); + const edgesGeometry = new THREE.EdgesGeometry(cubeGeometry); + const lineMaterial = new THREE.LineBasicMaterial({ color: COLOR_EDGE }); + + function createCube(color, matrixName, i, j) { + const cubeMaterial = new THREE.MeshPhongMaterial({ color: color }); + const cube = new THREE.Mesh(cubeGeometry, cubeMaterial); + const edges = new THREE.LineSegments(edgesGeometry, lineMaterial); + cube.add(edges); + + const hoverGeometry = new THREE.BoxGeometry(CUBE_SIZE * 1.05, CUBE_SIZE * 1.05, CUBE_SIZE * 1.05); + const hoverEdgesGeometry = new THREE.EdgesGeometry(hoverGeometry); + const hoverOutline = new THREE.LineSegments(hoverEdgesGeometry, new THREE.LineBasicMaterial({ color: COLOR_HOVER })); + hoverOutline.visible = false; + hoverOutline.name = 'hoverOutline'; + cube.add(hoverOutline); + + cube.name = `${matrixName}_cube_${i}_${j}`; + cube.matrixName = matrixName; + cube.matrixRow = i; + cube.matrixCol = j; + + return cube; + } + + function createMatrix(dimensions, position, color, matrixName) { + const matrix = new THREE.Group(); + matrix.userData.dimensions = dimensions; + for (let i = 0; i < dimensions[0]; i++) { + for (let j = 0; j < dimensions[1]; j++) { + const cube = createCube(color, matrixName, i, j); + cube.position.set( + position.x + j * (CUBE_SIZE + GAP), + position.y - i * (CUBE_SIZE + GAP), + position.z + ); + matrix.add(cube); + } + } + return matrix; + } + + const matrixA = createMatrix(input_shape, new THREE.Vector3(-10, 10, 0), COLOR_A, 'A'); + const matrixB = createMatrix(other_shape, new THREE.Vector3(0, 10, 0), COLOR_B, 'B'); + const matrixC = createMatrix(output_shape, new THREE.Vector3(-5, -4, 0), COLOR_C, 'C'); + + scene.add(matrixA); + scene.add(matrixB); + scene.add(matrixC); + + const center = new THREE.Vector3(); + const size = new THREE.Vector3(); + const box = new THREE.Box3().setFromObject(scene); + box.getCenter(center); + box.getSize(size); + const maxDim = Math.max(size.x, size.y, size.z); + const fov = camera.fov * (Math.PI / 180); + let cameraZ = Math.abs(maxDim / 2 / Math.tan(fov / 2)); + cameraZ *= 1.5; + + camera.position.set(center.x, center.y, center.z + cameraZ); + camera.lookAt(center); + + let isPaused = false; + + const totalFrames = input_shape[0] * other_shape[1]; + + function highlightCubes(matrix, indices, highlightColor) { + indices.forEach(([i, j]) => { + if (i >= 0 && i < matrix.userData.dimensions[0] && j >= 0 && j < matrix.userData.dimensions[1]) { + const index = i * matrix.userData.dimensions[1] + j; + if (index < matrix.children.length) { + matrix.children[index].material.color.copy(highlightColor); + + } + } + }); + } + + function resetColors() { + matrixA.children.forEach(cube => cube.material.color.copy(COLOR_A)); + matrixB.children.forEach(cube => cube.material.color.copy(COLOR_B)); + } + + function animate() { + requestAnimationFrame(animate); + + if (!isPaused && frame < totalFrames) { + resetColors(); + + const row = Math.floor(frame / other_shape[1]); + const col = frame % other_shape[1]; + currentStep = frame % totalSteps + 1; + + const highlightA = Array.from({ length: input_shape[1] }, (_, i) => [row, i]); + const highlightB = Array.from({ length: other_shape[0] }, (_, i) => [i, col]); + const highlightC = [[row, col]]; + + highlightCubes(matrixA, highlightA, COLOR_HIGHLIGHT); + highlightCubes(matrixB, highlightB, COLOR_HIGHLIGHT); + highlightCubes(matrixC, highlightC, COLOR_FILLED); + + frame++; + } + + renderer.render(scene, camera); + } + + function onResize() { + camera.aspect = containerElement.clientWidth / containerElement.clientHeight; + camera.updateProjectionMatrix(); + renderer.setSize(containerElement.clientWidth, containerElement.clientHeight); + } + + const controlPanel = document.createElement('div'); + controlPanel.style.position = 'absolute'; + controlPanel.style.bottom = '10px'; + controlPanel.style.left = '10px'; + controlPanel.style.display = 'flex'; + controlPanel.style.gap = '10px'; + + const playPauseButton = document.createElement('button'); + playPauseButton.textContent = 'Play/Pause'; + playPauseButton.addEventListener('click', () => { + isPaused = !isPaused; + }); + + const resetButton = document.createElement('button'); + resetButton.textContent = 'Reset'; + resetButton.addEventListener('click', () => { + frame = 0; + resetColors(); + }); + + controlPanel.appendChild(playPauseButton); + controlPanel.appendChild(resetButton); + + containerElement.appendChild(controlPanel); + + const PAN_SPEED = 0.1; + const TILT_SPEED = 0.02; + const ZOOM_SPEED = 0.5; + let cameraRotation = new THREE.Euler(0, 0, 0, 'YXZ'); + + function onKeyDown(event) { + switch (event.key.toLowerCase()) { + case 'w': + camera.position.y += PAN_SPEED; + break; + case 's': + camera.position.y -= PAN_SPEED; + break; + case 'a': + camera.position.x -= PAN_SPEED; + break; + case 'd': + camera.position.x += PAN_SPEED; + break; + case 'arrowup': + cameraRotation.x -= TILT_SPEED; + break; + case 'arrowdown': + cameraRotation.x += TILT_SPEED; + break; + case 'arrowleft': + cameraRotation.y -= TILT_SPEED; + break; + case 'arrowright': + cameraRotation.y += TILT_SPEED; + break; + case 'o': + camera.position.z += ZOOM_SPEED; + break; + case 'p': + camera.position.z -= ZOOM_SPEED; + break; + } + camera.setRotationFromEuler(cameraRotation); + camera.updateProjectionMatrix(); + } + + window.addEventListener('resize', onResize); + window.addEventListener('keydown', onKeyDown); + containerElement.addEventListener('mousemove', onMouseMove); + + animate(); + + + function cleanup() { + window.removeEventListener('resize', onResize); + window.removeEventListener('keydown', onKeyDown); + containerElement.removeEventListener('mousemove', onMouseMove); + containerElement.innerHTML = ''; + renderer.dispose(); + scene.clear(); + } + + return cleanup; +} diff --git a/triton_viz/static/store-utils.js b/triton_viz/static/store-utils.js new file mode 100644 index 0000000..d87d78b --- /dev/null +++ b/triton_viz/static/store-utils.js @@ -0,0 +1,360 @@ +import * as THREE from 'https://cdn.jsdelivr.net/npm/three@0.155.0/build/three.module.js'; + +export function createMatMulVisualization(containerElement, op) { + const { input_shape, other_shape, output_shape } = op; + console.log(op.uuid) + fetch('/api/setop', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ uuid: op.uuid }), + }) + .then(response => response.json()) + .then(data => console.log('Set current op:', data)) + .catch((error) => console.error('Error:', error)); + let currentStep = 0; + const totalSteps = input_shape[1]; + let frame = 0; + + + const sideMenu = document.createElement('div'); + sideMenu.style.position = 'absolute'; + sideMenu.style.top = '10px'; + sideMenu.style.right = '10px'; + sideMenu.style.width = '200px'; + sideMenu.style.padding = '10px'; + sideMenu.style.backgroundColor = 'rgba(0, 0, 0, 0.7)'; + sideMenu.style.color = 'white'; + sideMenu.style.fontFamily = 'Arial, sans-serif'; + sideMenu.style.fontSize = '14px'; + sideMenu.style.borderRadius = '5px'; + containerElement.appendChild(sideMenu); + let hoveredCube = null; + + + + + async function getElementValue( matrixName, row, col) { + let uuid = op.uuid; + const response = await fetch('/api/getValue', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ uuid, matrixName, row, col, currentStep }), + }); + return await response.json(); + } + + + function updateSideMenu(matrix, x, y) { + if (!matrix) { + sideMenu.innerHTML = ''; + return; + } + + let matrixName; + let dims; + if (matrix === matrixA) { + matrixName = 'A'; + dims = input_shape; + } else if (matrix === matrixB) { + matrixName = 'B'; + dims = other_shape; + } else if (matrix === matrixC) { + matrixName = 'C'; + dims = output_shape; + } else { + sideMenu.innerHTML = ''; + return; + } + console.log(matrixName, "x:", (x + 1), "y:", (y + 1)); + sideMenu.innerHTML = ` +Row: ${y + 1}
+Column: ${x + 1}
+Dimensions: ${dims[0]} x ${dims[1]}
+ `; + } + + const raycaster = new THREE.Raycaster(); + const mouse = new THREE.Vector2(); + + + async function onMouseMove(event) { + mouse.x = (event.clientX / containerElement.clientWidth) * 2 - 1; + mouse.y = -(event.clientY / containerElement.clientHeight) * 2 + 1; + + raycaster.setFromCamera(mouse, camera); + + const allMatrixChildren = [ + ...(matrixA ? matrixA.children : []), + ...(matrixB ? matrixB.children : []), + ...(matrixC ? matrixC.children : []) + ]; + + const intersects = raycaster.intersectObjects(allMatrixChildren, true); + + if (hoveredCube) { + hoveredCube.getObjectByName('hoverOutline').visible = false; + hoveredCube = null; + } + + if (intersects.length > 0) { + // Find the actual cube (parent of the intersected object) + hoveredCube = intersects[0].object; + while (hoveredCube && !hoveredCube.matrixName) { + hoveredCube = hoveredCube.parent; + } + + if (hoveredCube) { + const hoverOutline = hoveredCube.getObjectByName('hoverOutline'); + if (hoverOutline) { + hoverOutline.visible = true; + } + const res = await getElementValue(hoveredCube.matrixName, hoveredCube.matrixRow, hoveredCube.matrixCol); + // Log the matrix name, row, and column of the hovered cube + console.log( + // `Matrix: ${hoveredCube.matrixName}, ` + + // `Row: ${hoveredCube.matrixRow + 1}, ` + + // `Column: ${hoveredCube.matrixCol + 1}`+ + `Value: ${res.value}` + ); + + updateSideMenu(hoveredCube.matrixName, hoveredCube.matrixRow, hoveredCube.matrixCol); + } + } else { + updateSideMenu(null); + } + } + + + + + + const CUBE_SIZE = 0.2; + const GAP = 0.05; + + const COLOR_A = new THREE.Color(0.53, 0.81, 0.98); + const COLOR_B = new THREE.Color(1.0, 0.65, 0.0); + const COLOR_C = new THREE.Color(1.0, 1.0, 1.0); + const COLOR_HIGHLIGHT = new THREE.Color(0.0, 0.0, 1.0); + const COLOR_FILLED = new THREE.Color(0.0, 0.0, 1.0); + const COLOR_BACKGROUND = new THREE.Color(0.0, 0.0, 0.0); + const COLOR_EDGE = new THREE.Color(0.3, 0.3, 0.3); + const COLOR_HOVER = new THREE.Color(1.0, 1.0, 0.0); + + const scene = new THREE.Scene(); + scene.background = COLOR_BACKGROUND; + const camera = new THREE.PerspectiveCamera(45, containerElement.clientWidth / containerElement.clientHeight, 0.1, 1000); + + const renderer = new THREE.WebGLRenderer({ antialias: true }); + renderer.setSize(containerElement.clientWidth, containerElement.clientHeight); + containerElement.appendChild(renderer.domElement); + + const ambientLight = new THREE.AmbientLight(0xffffff, 0.5); + scene.add(ambientLight); + const directionalLight = new THREE.DirectionalLight(0xffffff, 0.5); + directionalLight.position.set(10, 10, 10); + scene.add(directionalLight); + + const cubeGeometry = new THREE.BoxGeometry(CUBE_SIZE, CUBE_SIZE, CUBE_SIZE); + const edgesGeometry = new THREE.EdgesGeometry(cubeGeometry); + const lineMaterial = new THREE.LineBasicMaterial({ color: COLOR_EDGE }); + + function createCube(color, matrixName, i, j) { + const cubeMaterial = new THREE.MeshPhongMaterial({ color: color }); + const cube = new THREE.Mesh(cubeGeometry, cubeMaterial); + const edges = new THREE.LineSegments(edgesGeometry, lineMaterial); + cube.add(edges); + + const hoverGeometry = new THREE.BoxGeometry(CUBE_SIZE * 1.05, CUBE_SIZE * 1.05, CUBE_SIZE * 1.05); + const hoverEdgesGeometry = new THREE.EdgesGeometry(hoverGeometry); + const hoverOutline = new THREE.LineSegments(hoverEdgesGeometry, new THREE.LineBasicMaterial({ color: COLOR_HOVER })); + hoverOutline.visible = false; + hoverOutline.name = 'hoverOutline'; + cube.add(hoverOutline); + + cube.name = `${matrixName}_cube_${i}_${j}`; + cube.matrixName = matrixName; + cube.matrixRow = i; + cube.matrixCol = j; + + return cube; + } + + function createMatrix(dimensions, position, color, matrixName) { + const matrix = new THREE.Group(); + matrix.userData.dimensions = dimensions; + for (let i = 0; i < dimensions[0]; i++) { + for (let j = 0; j < dimensions[1]; j++) { + const cube = createCube(color, matrixName, i, j); + cube.position.set( + position.x + j * (CUBE_SIZE + GAP), + position.y - i * (CUBE_SIZE + GAP), + position.z + ); + matrix.add(cube); + } + } + return matrix; + } + + const matrixA = createMatrix(input_shape, new THREE.Vector3(-10, 10, 0), COLOR_A, 'A'); + const matrixB = createMatrix(other_shape, new THREE.Vector3(0, 10, 0), COLOR_B, 'B'); + const matrixC = createMatrix(output_shape, new THREE.Vector3(-5, -4, 0), COLOR_C, 'C'); + + scene.add(matrixA); + scene.add(matrixB); + scene.add(matrixC); + + const center = new THREE.Vector3(); + const size = new THREE.Vector3(); + const box = new THREE.Box3().setFromObject(scene); + box.getCenter(center); + box.getSize(size); + const maxDim = Math.max(size.x, size.y, size.z); + const fov = camera.fov * (Math.PI / 180); + let cameraZ = Math.abs(maxDim / 2 / Math.tan(fov / 2)); + cameraZ *= 1.5; + + camera.position.set(center.x, center.y, center.z + cameraZ); + camera.lookAt(center); + + let isPaused = false; + + const totalFrames = input_shape[0] * other_shape[1]; + + function highlightCubes(matrix, indices, highlightColor) { + indices.forEach(([i, j]) => { + if (i >= 0 && i < matrix.userData.dimensions[0] && j >= 0 && j < matrix.userData.dimensions[1]) { + const index = i * matrix.userData.dimensions[1] + j; + if (index < matrix.children.length) { + matrix.children[index].material.color.copy(highlightColor); + + } + } + }); + } + + function resetColors() { + matrixA.children.forEach(cube => cube.material.color.copy(COLOR_A)); + matrixB.children.forEach(cube => cube.material.color.copy(COLOR_B)); + } + + function animate() { + requestAnimationFrame(animate); + + if (!isPaused && frame < totalFrames) { + resetColors(); + + const row = Math.floor(frame / other_shape[1]); + const col = frame % other_shape[1]; + currentStep = frame % totalSteps + 1; + + const highlightA = Array.from({ length: input_shape[1] }, (_, i) => [row, i]); + const highlightB = Array.from({ length: other_shape[0] }, (_, i) => [i, col]); + const highlightC = [[row, col]]; + + highlightCubes(matrixA, highlightA, COLOR_HIGHLIGHT); + highlightCubes(matrixB, highlightB, COLOR_HIGHLIGHT); + highlightCubes(matrixC, highlightC, COLOR_FILLED); + + frame++; + } + + renderer.render(scene, camera); + } + + function onResize() { + camera.aspect = containerElement.clientWidth / containerElement.clientHeight; + camera.updateProjectionMatrix(); + renderer.setSize(containerElement.clientWidth, containerElement.clientHeight); + } + + const controlPanel = document.createElement('div'); + controlPanel.style.position = 'absolute'; + controlPanel.style.bottom = '10px'; + controlPanel.style.left = '10px'; + controlPanel.style.display = 'flex'; + controlPanel.style.gap = '10px'; + + const playPauseButton = document.createElement('button'); + playPauseButton.textContent = 'Play/Pause'; + playPauseButton.addEventListener('click', () => { + isPaused = !isPaused; + }); + + const resetButton = document.createElement('button'); + resetButton.textContent = 'Reset'; + resetButton.addEventListener('click', () => { + frame = 0; + resetColors(); + }); + + controlPanel.appendChild(playPauseButton); + controlPanel.appendChild(resetButton); + + containerElement.appendChild(controlPanel); + + const PAN_SPEED = 0.1; + const TILT_SPEED = 0.02; + const ZOOM_SPEED = 0.5; + let cameraRotation = new THREE.Euler(0, 0, 0, 'YXZ'); + + function onKeyDown(event) { + switch (event.key.toLowerCase()) { + case 'w': + camera.position.y += PAN_SPEED; + break; + case 's': + camera.position.y -= PAN_SPEED; + break; + case 'a': + camera.position.x -= PAN_SPEED; + break; + case 'd': + camera.position.x += PAN_SPEED; + break; + case 'arrowup': + cameraRotation.x -= TILT_SPEED; + break; + case 'arrowdown': + cameraRotation.x += TILT_SPEED; + break; + case 'arrowleft': + cameraRotation.y -= TILT_SPEED; + break; + case 'arrowright': + cameraRotation.y += TILT_SPEED; + break; + case 'o': + camera.position.z += ZOOM_SPEED; + break; + case 'p': + camera.position.z -= ZOOM_SPEED; + break; + } + camera.setRotationFromEuler(cameraRotation); + camera.updateProjectionMatrix(); + } + + window.addEventListener('resize', onResize); + window.addEventListener('keydown', onKeyDown); + containerElement.addEventListener('mousemove', onMouseMove); + + animate(); + + + function cleanup() { + window.removeEventListener('resize', onResize); + window.removeEventListener('keydown', onKeyDown); + containerElement.removeEventListener('mousemove', onMouseMove); + containerElement.innerHTML = ''; + renderer.dispose(); + scene.clear(); + } + + return cleanup; +} diff --git a/triton_viz/static/store.js b/triton_viz/static/store.js new file mode 100644 index 0000000..470b327 --- /dev/null +++ b/triton_viz/static/store.js @@ -0,0 +1,221 @@ +import * as THREE from 'https://cdn.jsdelivr.net/npm/three@0.155.0/build/three.module.js'; +import { + setupScene, + setupGeometries, + createTensor, + calculateTensorSize, + updateCubeColor, + setupCamera, + setupEventListeners, + cameraControls +} from './store-utils.js'; + +export function createStoreVisualization(containerElement, op) { + + console.log(op.uuid); + fetch('/api/setop', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ uuid: op.uuid }), + }) + .then(response => response.json()) + .then(data => console.log('Set current op:', data)) + .catch((error) => console.error('Error:', error)); + + let currentStep = 0; + let frame = 0; + let isPaused = false; + + const sideMenu = createSideMenu(containerElement); + let hoveredCube = null; + + const COLOR_GLOBAL = new THREE.Color(0.2, 0.2, 0.2); // Dark Gray + const COLOR_SLICE = new THREE.Color(0.0, 0.7, 1.0); // Cyan (starting color for global slice) + const COLOR_LEFT_SLICE = new THREE.Color(1.0, 0.0, 1.0); // Magenta (starting color for left slice) + const COLOR_LOADED = new THREE.Color(1.0, 0.8, 0.0); // Gold (final color for both slices) + const COLOR_BACKGROUND = new THREE.Color(0.0, 0.0, 0.0); // Black + + const { scene, camera, renderer } = setupScene(containerElement, COLOR_BACKGROUND); + const { cubeGeometry, edgesGeometry, lineMaterial } = setupGeometries(); + + const globalTensor = createTensor(op.global_shape, op.global_coords, COLOR_GLOBAL, 'Global', cubeGeometry, edgesGeometry, lineMaterial); + const sliceTensor = createTensor(op.slice_shape, op.slice_coords, COLOR_LEFT_SLICE, 'Slice', cubeGeometry, edgesGeometry, lineMaterial); + + // Position slice tensor + const globalSize = calculateTensorSize(op.global_shape); + sliceTensor.position.set(globalSize.x + 5, 0, 0); // Adjusted tensor spacing + + scene.add(globalTensor); + scene.add(sliceTensor); + + addLabels(scene, globalTensor, sliceTensor); + setupCamera(scene, camera); + + const totalFrames = op.global_coords.length * 2 + 30; + + const raycaster = new THREE.Raycaster(); + const mouse = new THREE.Vector2(); + + const onKeyDown = cameraControls(camera, new THREE.Euler(0, 0, 0, 'YXZ')); + setupEventListeners(containerElement, camera, renderer, onMouseMove, onKeyDown); + animate(); + + async function onMouseMove(event) { + mouse.x = (event.clientX / containerElement.clientWidth) * 2 - 1; + mouse.y = -(event.clientY / containerElement.clientHeight) * 2 + 1; + + raycaster.setFromCamera(mouse, camera); + + const allTensorChildren = [ + ...globalTensor.children, + ...sliceTensor.children + ]; + + const intersects = raycaster.intersectObjects(allTensorChildren, true); + + if (hoveredCube) { + hoveredCube.getObjectByName('hoverOutline').visible = false; + hoveredCube = null; + } + + if (intersects.length > 0) { + hoveredCube = intersects[0].object; + while (hoveredCube && !hoveredCube.tensorName) { + hoveredCube = hoveredCube.parent; + } + + if (hoveredCube) { + const hoverOutline = hoveredCube.getObjectByName('hoverOutline'); + if (hoverOutline) { + hoverOutline.visible = true; + } + + updateSideMenu(hoveredCube.tensorName, hoveredCube.tensor0, hoveredCube.tensor1, hoveredCube.tensor2, undefined); + + const res = await getElementValue(hoveredCube.tensorName, hoveredCube.tensor0, hoveredCube.tensor1, hoveredCube.tensor2); + + updateSideMenu(hoveredCube.tensorName, hoveredCube.tensor0, hoveredCube.tensor1, hoveredCube.tensor2, res.value); + + console.log(`Value: ${res.value}`); + } + } else { + updateSideMenu(null); + } + } + + function animate() { + requestAnimationFrame(animate); + + if (!isPaused && frame < totalFrames) { + const index = Math.floor(frame / 2); + const factor = (frame % 2) / 1.0; + + if (index < op.global_coords.length) { + const globalCoord = op.global_coords[index]; + const sliceCoord = op.slice_coords[index]; + + updateCubeColor(globalTensor, globalCoord, COLOR_GLOBAL, COLOR_SLICE, factor); + updateCubeColor(sliceTensor, sliceCoord, COLOR_LEFT_SLICE, COLOR_LOADED, factor); + + highlightCurrentOperation(globalTensor, globalCoord, sliceTensor, sliceCoord); + updateInfoPanel(globalCoord, sliceCoord, index); + } + + frame++; + } + + renderer.render(scene, camera); + } + + function highlightCurrentOperation(globalTensor, globalCoord, sliceTensor, sliceCoord) { + globalTensor.children.forEach(cube => cube.material.emissive.setHex(0x000000)); + sliceTensor.children.forEach(cube => cube.material.emissive.setHex(0x000000)); + + const globalCube = globalTensor.children.find(c => + c.tensor0 === globalCoord[0] && c.tensor1 === globalCoord[1] && c.tensor2 === globalCoord[2] + ); + const sliceCube = sliceTensor.children.find(c => + c.tensor0 === sliceCoord[0] && c.tensor1 === sliceCoord[1] && c.tensor2 === sliceCoord[2] + ); + + if (globalCube) globalCube.material.emissive.setHex(0x444444); + if (sliceCube) sliceCube.material.emissive.setHex(0x444444); + } + + async function getElementValue(tensorName, x, y, z) { + let uuid = op.uuid; + const response = await fetch('/api/getStoreValue', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ uuid, tensorName, x, y, z }), + }); + return await response.json(); + } + + function updateSideMenu(tensorName, x, y, z, value) { + if (!tensorName) { + sideMenu.innerHTML = ''; + return; + } + + let dims = tensorName === 'Global' ? op.global_shape : op.slice_shape; + sideMenu.innerHTML = ` +X: ${x + 1}
+Y: ${y + 1}
+Z: ${z + 1}
+Dimensions: ${dims.join(' x ')}
+Value: ${value !== undefined ? value : 'Storeing...'}
+ `; + } + + function updateInfoPanel(globalCoord, sliceCoord, index) { + sideMenu.innerHTML = ` +Global Coords: (${globalCoord.join(', ')})
+Slice Coords: (${sliceCoord.join(', ')})
+Progress: ${index + 1}/${op.global_coords.length}
+ `; + } + + function createSideMenu(container) { + const menu = document.createElement('div'); + menu.style.position = 'absolute'; + menu.style.top = '10px'; + menu.style.right = '10px'; + menu.style.width = '200px'; + menu.style.padding = '10px'; + menu.style.backgroundColor = 'rgba(0, 0, 0, 0.7)'; + menu.style.color = 'white'; + menu.style.fontFamily = 'Arial, sans-serif'; + menu.style.fontSize = '14px'; + menu.style.borderRadius = '5px'; + container.appendChild(menu); + return menu; + } + + function addLabels(scene, globalTensor, sliceTensor) { + addLabel(scene, "Global Tensor", globalTensor.position); + addLabel(scene, "Slice Tensor", sliceTensor.position); + } + + function addLabel(scene, text, position) { + const canvas = document.createElement('canvas'); + const context = canvas.getContext('2d'); + context.font = 'Bold 24px Arial'; + context.fillStyle = 'white'; + context.fillText(text, 0, 24); + + const texture = new THREE.CanvasTexture(canvas); + const material = new THREE.SpriteMaterial({ map: texture }); + const sprite = new THREE.Sprite(material); + sprite.position.set(position.x, position.y + 2, position.z); + sprite.scale.set(4, 2, 1); + scene.add(sprite); + } + +} diff --git a/triton_viz/static/visualization.js b/triton_viz/static/visualization.js new file mode 100644 index 0000000..e0d157c --- /dev/null +++ b/triton_viz/static/visualization.js @@ -0,0 +1,389 @@ +import { GridBlock } from './gridblock.js'; +import { createInfoPopup, showInfoPopup } from './infoPopup.js'; +let globalData; +let currentView = 'main'; +let canvas, ctx; +let maxX = 0, maxY = 0, maxZ = 0; +let sliders = [], zSlider, precomputeButton, kernelGrid; +let backButton; +let currentBlockData = null; +let isInitialized = false; +let containerElement; +let infoPopup; +let infoButton; +function switchToMainView() { + currentView = 'main'; + if (currentBlockData) { + currentBlockData.hideDetailedView(); + currentBlockData = null; + } + containerElement.style.pointerEvents = 'none'; + containerElement.style.display = 'none'; + containerElement.innerHTML = ''; + + canvas.style.display = 'block'; + draw(); +} + +function initializeApp() { + canvas = document.getElementById('canvas'); + if (!canvas) { + console.error('Canvas element not found'); + return; + } + ctx = canvas.getContext('2d'); + if (!ctx) { + console.error('Unable to get 2D context from canvas'); + return; + } + canvas.width = window.innerWidth; + canvas.height = window.innerHeight; + + canvas.addEventListener('mousedown', handleMouseEvent); + canvas.addEventListener('mouseup', handleMouseEvent); + canvas.addEventListener('mousemove', handleMouseEvent); + + containerElement = document.getElementById('visualization-container'); + if (!containerElement) { + console.error('Visualization container element not found'); + return; + } + + containerElement.style.pointerEvents = 'none'; + containerElement.style.display = 'none'; + + fetchData(); +} + +class Slider { + constructor(x, y, width, height, label, min_value = -1, max_value = 100) { + this.rect = { x, y, width, height }; + this.label = label; + this.min = min_value; + this.max = max_value; + this.value = min_value; + this.grabbed = false; + this.enabled = true; + } + + draw(ctx) { + if (!this.enabled) return; + ctx.fillStyle = '#3c3c46'; + ctx.fillRect(this.rect.x, this.rect.y, this.rect.width, this.rect.height); + const buttonX = this.rect.x + (this.value - this.min) / (this.max - this.min) * this.rect.width; + ctx.fillStyle = '#c8c8c8'; + ctx.fillRect(buttonX - 5, this.rect.y - 2, 10, this.rect.height + 4); + + ctx.fillStyle = '#c8c8c8'; + ctx.font = '18px Arial'; + ctx.fillText(this.label, this.rect.x, this.rect.y - 10); + ctx.fillText(this.value.toString(), this.rect.x + this.rect.width + 10, this.rect.y + this.rect.height / 2 + 5); + } + + handleEvent(event) { + if (!this.enabled) return; + if (event.type === 'mousedown') { + if (this.isPointInside(event.offsetX, event.offsetY)) { + this.grabbed = true; + } + } else if (event.type === 'mouseup') { + this.grabbed = false; + } else if (event.type === 'mousemove' && this.grabbed) { + const mouseX = event.offsetX; + this.value = Math.round((mouseX - this.rect.x) / this.rect.width * (this.max - this.min) + this.min); + this.value = Math.max(this.min, Math.min(this.max, this.value)); + } + } + + isPointInside(x, y) { + return x >= this.rect.x && x <= this.rect.x + this.rect.width && + y >= this.rect.y && y <= this.rect.y + this.rect.height; + } +} + +class Button { + constructor(x, y, width, height, text, isIcon = false) { + this.rect = { x, y, width, height }; + this.text = text; + this.isIcon = isIcon; + this.color = '#3c3c46'; + this.hoverColor = '#50505a'; + this.clickColor = '#64646e'; + this.isHovered = false; + this.isClicked = false; + this.clickTime = 0; + } + + draw(ctx) { + let color = this.color; + if (this.isClicked && Date.now() - this.clickTime < 100) { + color = this.clickColor; + } else if (this.isHovered) { + color = this.hoverColor; + } + + ctx.fillStyle = color; + ctx.fillRect(this.rect.x, this.rect.y, this.rect.width, this.rect.height); + ctx.strokeStyle = '#c8c8c8'; + ctx.strokeRect(this.rect.x, this.rect.y, this.rect.width, this.rect.height); + + ctx.fillStyle = '#c8c8c8'; + ctx.font = this.isIcon ? 'bold 24px Arial' : '18px Arial'; + ctx.textAlign = 'center'; + ctx.textBaseline = 'middle'; + ctx.fillText(this.text, this.rect.x + this.rect.width / 2, this.rect.y + this.rect.height / 2); + } + + handleEvent(event) { + const { offsetX, offsetY } = event; + this.isHovered = this.isPointInside(offsetX, offsetY); + if (event.type === 'mousedown' && this.isHovered) { + this.isClicked = true; + this.clickTime = Date.now(); + console.log(`Button '${this.text}' clicked!`); + } else if (event.type === 'mouseup') { + this.isClicked = false; + } + } + + isPointInside(x, y) { + return x >= this.rect.x && x <= this.rect.x + this.rect.width && + y >= this.rect.y && y <= this.rect.y + this.rect.height; + } +} + +class KernelGrid { + constructor(x, y, width, height, gridSize, visualizationData) { + this.rect = { x, y, width, height }; + this.gridSize = gridSize; + this.visualizationData = visualizationData; + this.currentZ = 0; + this.blocks = []; + this.calculateBlockSize(); + this.createBlocks(); + this.selectedBlock = null; + this.filterValues = [-1, -1, -1]; // Default filter values for x, y, z + } + + calculateBlockSize() { + this.blockWidth = Math.floor(this.rect.width / this.gridSize[0]) - 1; + this.blockHeight = Math.floor(this.rect.height / this.gridSize[1]) - 1; + } + + createBlocks() { + this.blocks = []; + for (let y = 0; y < this.gridSize[1]; y++) { + for (let x = 0; x < this.gridSize[0]; x++) { + const blockX = this.rect.x + x * (this.blockWidth + 1); + const blockY = this.rect.y + y * (this.blockHeight + 1); + const gridKey = `(${x}, ${y}, ${this.currentZ})`; + const blockData = this.visualizationData[gridKey] || []; + const block = new GridBlock( + blockX, blockY, this.blockWidth, this.blockHeight, + x, y, this.currentZ, blockData, + switchToMainView, + containerElement, + canvas, + draw + ); + this.blocks.push(block); + } + } + } + + draw(ctx) { + ctx.fillStyle = '#F0F0F0'; + ctx.fillRect(this.rect.x, this.rect.y, this.rect.width, this.rect.height); + this.blocks.forEach(block => { + if (this.shouldDrawBlock(block)) { + block.draw(ctx); + } + }); + } + + shouldDrawBlock(block) { + return (this.filterValues[0] === -1 || block.gridPosition.x === this.filterValues[0]) && + (this.filterValues[1] === -1 || block.gridPosition.y === this.filterValues[1]) && + (this.filterValues[2] === -1 || block.gridPosition.z === this.filterValues[2]); + } + + updateZ(z) { + this.currentZ = z; + this.filterValues[2] = z; + this.blocks.forEach(block => { + block.gridPosition.z = z; + const gridKey = `(${block.gridPosition.x}, ${block.gridPosition.y}, ${z})`; + block.blockData = this.visualizationData[gridKey] || []; + }); + } + + handleClick(x, y) { + const clickedBlock = this.blocks.find(block => + block.isPointInside(x, y) && this.shouldDrawBlock(block) + ); + if (clickedBlock) { + console.log(`Clicked block at (${clickedBlock.gridPosition.x}, ${clickedBlock.gridPosition.y}, ${clickedBlock.gridPosition.z})`); + if (this.selectedBlock) { + this.selectedBlock.hideDetailedView(); + } + this.selectedBlock = clickedBlock; + clickedBlock.showDetailedView(); + return clickedBlock; + } + return null; + } + + handleMouseMove(x, y) { + this.blocks.forEach(block => { + if (this.shouldDrawBlock(block)) { + block.handleMouseMove(x, y); + } else { + block.isHovered = false; + } + }); + } + + updateFilter(dimension, value) { + this.filterValues[dimension] = value; + } +} + +function determineMaxValues(visualizationData) { + maxX = 0; + maxY = 0; + maxZ = 0; + const keys = Object.keys(visualizationData); + keys.forEach(key => { + const [x, y, z] = key.replace(/[()]/g, '').split(', ').map(Number); + if (x > maxX) maxX = x; + if (y > maxY) maxY = y; + if (z > maxZ) maxZ = z; + }); +} + +function initializeUIElements() { + sliders = [ + new Slider(1300, 50, 250, 20, "Program Id 0", -1, maxX), + new Slider(1300, 120, 250, 20, "Program Id 1", -1, maxY), + new Slider(1300, 190, 250, 20, "Program Id 2", -1, maxZ) + ]; + + zSlider = new Slider(50, 860, 1200, 20, "Z-axis", 0, maxZ); + zSlider.enabled = maxZ > 0; + + precomputeButton = new Button(1300, 260, 250, 40, "Precompute"); + kernelGrid = new KernelGrid(50, 50, 1200, 800, [maxX + 1, maxY + 1, maxZ + 1], globalData.ops.visualization_data); + backButton = new Button(50, 50, 100, 40, "Back"); + const buttonSize = 40; + const margin = 10; + infoButton = new Button( + canvas.width - buttonSize - margin, + margin, + buttonSize, + buttonSize, + "i", + true + ); + + isInitialized = true; + + infoPopup = createInfoPopup(); +} + +function switchToTensorView(clickedBlock) { + currentView = 'tensor'; + currentBlockData = clickedBlock; + console.log("Switched to tensor view. Block data:", clickedBlock); + + containerElement.style.pointerEvents = 'auto'; + containerElement.style.display = 'block'; + clickedBlock.showDetailedView(); + + canvas.style.display = 'none'; +} + +function handleMouseEvent(event) { + if (!isInitialized) { + console.warn('UI elements not initialized yet'); + return; + } + if (infoButton) { + infoButton.handleEvent(event); + if (event.type === 'mousedown' && infoButton.isHovered) { + showInfoPopup(infoPopup); + } + } + const { offsetX, offsetY } = event; + if (currentView === 'main') { + sliders.forEach((slider, index) => { + slider.handleEvent(event); + if (kernelGrid) { + kernelGrid.updateFilter(index, slider.value); + } + }); + if (zSlider && zSlider.enabled) { + zSlider.handleEvent(event); + if (kernelGrid) { + kernelGrid.updateZ(zSlider.value); + } + } + if (precomputeButton) { + precomputeButton.handleEvent(event); + } + if (kernelGrid) { + kernelGrid.handleMouseMove(offsetX, offsetY); + if (event.type === 'mousedown') { + const clickedBlock = kernelGrid.handleClick(offsetX, offsetY); + if (clickedBlock) { + switchToTensorView(clickedBlock); + } + } + } + } else if (currentView === 'tensor') { + if (backButton) { + backButton.handleEvent(event); + if (event.type === 'mousedown' && backButton.isHovered) { + switchToMainView(); + } + } + } + draw(); +} + +function draw() { + if (!ctx) { + console.error('Canvas context not available'); + return; + } + + ctx.fillStyle = '#1e1e28'; + ctx.fillRect(0, 0, canvas.width, canvas.height); + + if (currentView === 'main' || currentView === 'main') { + if (kernelGrid) kernelGrid.draw(ctx); + sliders.forEach(slider => slider.draw(ctx)); + if (zSlider && zSlider.enabled) { + zSlider.draw(ctx); + } + if (precomputeButton) precomputeButton.draw(ctx); + if (infoButton) { + infoButton.draw(ctx); + } + } +} + +async function fetchData() { + try { + const response = await fetch('/api/data'); + globalData = await response.json(); + console.log(globalData); + + determineMaxValues(globalData.ops.visualization_data); + initializeUIElements(); + draw(); + } catch (error) { + console.error('Error fetching data:', error); + } +} + +initializeApp(); \ No newline at end of file diff --git a/triton_viz/templates/index.html b/triton_viz/templates/index.html new file mode 100644 index 0000000..b8eaebe --- /dev/null +++ b/triton_viz/templates/index.html @@ -0,0 +1,47 @@ + + + + + +