diff --git a/setup.py b/setup.py index 6898fed..264defc 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="triton-viz", - version="0.1", + version="0.2", packages=find_packages(), description="A visualization tool for Triton", author="Deep Learning Profiling Tools Team", @@ -11,8 +11,7 @@ install_requires=[ "setuptools", "triton", - "gradio", - "chalk-diagrams @ git+https://github.com/chalk-diagrams/chalk.git", + "flask", "pyarrow", "pre-commit", "pytest", diff --git a/triton_viz/__init__.py b/triton_viz/__init__.py index 737dd36..7c5bab8 100644 --- a/triton_viz/__init__.py +++ b/triton_viz/__init__.py @@ -1,5 +1,5 @@ from .trace import trace, dump, sample -from .draw import collect_grid, draw_record +from .draw import collect_grid from .interface import launch -__all__ = ["trace", "launch", "dump", "sample", "collect_grid", "draw_record"] +__all__ = ["trace", "launch", "dump", "sample", "collect_grid"] diff --git a/triton_viz/data.py b/triton_viz/data.py index 819a398..1d411bb 100644 --- a/triton_viz/data.py +++ b/triton_viz/data.py @@ -1,10 +1,10 @@ from dataclasses import dataclass, field -from typing import List, Tuple, Any +from typing import List, Tuple, Any, Dict import traceback import numpy.typing as npt import numpy as np - +import torch @dataclass class Op: call_path: List[traceback.StackSummary] = field(init=False, default_factory=list) @@ -68,11 +68,21 @@ class ExpandDims(Op): output_shape: Tuple + @dataclass class Dot(Op): input_shape: Tuple other_shape: Tuple output_shape: Tuple + input_data: List[List[float]] + other_data: List[List[float]] + intermediate_results: Dict[Tuple[int, int], float] = field(default_factory=dict) # Only storing the result now + + def update_intermediate(self, row: int, col: int, result: float): + # Store only the result as a float + self.intermediate_results[(row, col)] = result + + @dataclass @@ -91,6 +101,7 @@ class Tensor: stride: Tuple shape: Tuple element_size: int + data: torch.Tensor @dataclass diff --git a/triton_viz/draw.py b/triton_viz/draw.py index fda62c0..552de6a 100644 --- a/triton_viz/draw.py +++ b/triton_viz/draw.py @@ -1,69 +1,34 @@ -from colour import Color from triton_viz.data import ( Tensor, Grid, Store, Load, - Op, - MakeRange, - Reduce, Dot, + BinaryOp, + ExpandDims, + MakeRange ) +import uuid +import re +from triton.compiler import ASTSource + +import sys +from .trace import Trace from .interpreter import record_builder import numpy as np +import torch +from typing import Dict, Tuple, List +import ctypes import numpy.typing as npt -import planar -import math -import chalk -from typing import Tuple, Union, Optional, List, Dict -from chalk import Diagram, rectangle, text, hcat, vcat, empty, Path, Trail, V2, concat -from dataclasses import dataclass -from numpy.typing import ArrayLike -import sys - -sys.setrecursionlimit(100000) - - -planar.EPSILON = 0.0 -chalk.set_svg_draw_height(500) -BG = Color("white") -WHITE = Color("white") -DEFAULT = Color("grey") -BLACK = Color("black") -GREY = Color("grey") -palette = [ - "#f29f05", - "#f25c05", - "#d6568c", - "#4d8584", - "#a62f03", - "#400d01", - "#274001", - "#828a00", -] -ACTIVE = [Color(p) for p in palette] - -MRATIO = 1 / 3 - - -# Generic render helpers +import numpy as np +from typing import Tuple, Dict, List -def box(d: Diagram, width: float, height: float, outer=0.2) -> Diagram: - "Put diagram in a box of shape height, width" - h, w = d.get_envelope().height, d.get_envelope().width - m = max(h, w) - back = rectangle(outer + m, outer + m).line_width(0).fill_color(BG).center_xy() - d = (back + d.center_xy()).with_envelope(back) - return d.scale_x(width / m).scale_y(height / m) +import numpy as np +from typing import Union, Dict, List +import json -def reshape(d: Diagram) -> Diagram: - "Use log-scale if ratio is too sharp" - h, w = d.get_envelope().height, d.get_envelope().width - if (h / w > MRATIO) or (w / h > MRATIO): - d = d.scale_y(math.log(h + 1, 2) / h).scale_x(math.log(w + 1, 2) / w) - return d def collect_grid(): @@ -71,7 +36,6 @@ def collect_grid(): records, tensor_table, failures = collect_launch(launch) return records, tensor_table, failures - def collect_launch(launch): tensor_table = {} for i, t in enumerate(launch.tensors): @@ -96,375 +60,149 @@ def collect_launch(launch): return all_grids, tensor_table, failures -def draw_record(program_record, tensor_table, output): - return draw_launch(program_record, tensor_table, output) - - -def draw_launch(program_records, tensor_table, base) -> Diagram: - def draw(x): - "Dispatch" - if isinstance(x, Tensor): - return draw_tensor(x) - if isinstance(x, Grid): - return draw_grid(x) - if isinstance(x, Store): - return draw_store(x, tensor_table) - if isinstance(x, Load): - return draw_load(x, tensor_table) - if isinstance(x, Op): - return draw_op(x) - if isinstance(x, MakeRange): - return draw_make_range(x) - if isinstance(x, Reduce): - return draw_reduce(x) - if isinstance(x, Dot): - return None # draw_dot(x) - - def draw_record(x): - "Render one record" - y = draw(x) - if y is None: - return empty() - - return (chalk.vstrut(0.2) / y).center_xy() - - records = [] - for r in program_records: - dr = draw_record(r) - # env = dr.get_envelope() - # dr = dr.center_xy().with_envelope(rectangle(env.width, env.height).center_xy()) - records.append(dr) - - dr = vcat(records) - dr = dr.center_xy() - env = dr.get_envelope() - dr = rectangle(env.width + 1, env.height + 1).fill_color(BG).center_xy() + dr - dr.render_svg(base, 2500) - return env.width, env.height - - -def delinearize(shape: Tuple, x: npt.NDArray, dtype, mask) -> List[npt.NDArray]: - if len(shape) == 1: - shape = (1, 1, shape[0]) - x = x.copy() // (dtype.element_ty.primitive_bitwidth // 8) - vals = [] - for s in list(reversed(shape[1:])) + [10000]: - vals.append(((x % s) * mask - (1 - mask)).ravel()) - x = x // s - return vals -trail = Trail.from_offsets([V2(0, 1), V2(1, 0), V2(0, -1), V2(-1, 0)], closed=True) - - -def cover( - shape: Tuple, dtype, load: Tensor, mask: npt.NDArray, color: Color -) -> Diagram: - shape = make_3d(shape) - "Draw the values from load on top of the loading tensor" - x, y, z = delinearize(shape, load, dtype, mask) - return draw_tensor_3d(shape, z, y, x, color) - - -def pair_draw(x: Diagram, y: Diagram, command: str) -> Diagram: - "Draw two diagrams next to each other with a command in the middle." - return hcat([box(x, 3, 2.5), box(y, 3, 2.5)], 1).center_xy() + text( - command, 0.2 - ).fill_color(BLACK).line_width(0).translate(0, -1) - - -# Individual renderers - - -def draw_tensor(x: Tensor) -> Optional[Diagram]: - return None - - -def draw_grid(x: Grid) -> Optional[Diagram]: - return None - - -def draw_make_range(x: MakeRange) -> Optional[Diagram]: - return None - - -def draw_reduce(x: Reduce) -> Optional[Diagram]: - color = ACTIVE[0] - inp = draw_tensor_3d(make_3d(x.input_shape), None, None, None, color) - if x.index == 0 and len(x.input_shape) == 2: - inp = hcat( - [ - rectangle(0.1, inp.get_envelope().height) - .align_t() - .line_width(0) - .fill_color(BLACK), - inp, - ], - 0.5, - ) - else: - inp = vcat( - [ - rectangle(inp.get_envelope().width, 0.1) - .align_l() - .line_width(0) - .fill_color(BLACK), - inp, - ], - 0.5, - ) - out = draw_tensor_3d(x.output_shape, None, None, None, color) - return pair_draw(reshape(inp), reshape(out), x.op) - - -def draw_load(x, tensor_table) -> Optional[Diagram]: - inp, out = store_load(x, tensor_table) - out = reshape(out) - return pair_draw(inp, out, "load") - - -def draw_store(x, tensor_table) -> Optional[Diagram]: - inp, out = store_load(x, tensor_table) - out = reshape(out) - return pair_draw(out, inp, "store") +def extract_load_coords(record: Load, global_tensor: Tensor) -> Tuple[List[Tuple[float, float, float]], List[Tuple[float, float, float]]]: + # Extract coordinates for the global tensor + global_shape = make_3d(global_tensor.shape) + global_z, global_y, global_x = delinearized(global_shape, record.original_offsets, global_tensor.dtype, record.original_masks) + + global_coords = [ + (float(xi), float(yi), float(zi)) + for xi, yi, zi in zip(global_z, global_y, global_x) + if xi != -1 and yi != -1 and zi != -1 + ] + + + # Extract coordinates for the slice tensor + slice_shape = make_3d(record.shape) + slice_z, slice_y, slice_x = record.original_masks.reshape(*slice_shape).nonzero() + + slice_coords = [ + (float(xi), float(yi), float(zi)) + for xi, yi, zi in zip(slice_x, slice_y, slice_z) + ] + return global_coords, slice_coords -def make_3d(shape): - "Make a 3d shape" +def make_3d(shape: Tuple[int, ...]) -> Tuple[int, int, int]: if len(shape) == 1: return (1, 1, shape[0]) if len(shape) == 2: return (1, shape[0], shape[1]) return shape - -def store_load( - x: Union[Store, Load], tensor_table: Dict[int, Tuple[Tensor, int]] -) -> Tuple[Diagram, Diagram]: - tensor, tensor_id = tensor_table[x.ptr] - # inp = base_tensor(tensor.shape, DEFAULT) - color = ACTIVE[tensor_id] - invalid = x.invalid_access_masks.any() - if invalid: - color = Color("red") - inp = cover(tensor.shape, tensor.dtype, x.original_offsets, x.original_masks, color) - inp = reshape(inp) - s = make_3d(x.original_offsets.shape) - a, b, c = x.original_masks.reshape(*s).nonzero() - out = draw_tensor_3d(s, a, b, c, color) - return inp, out - - -def draw_op(x: Op) -> Optional[Diagram]: - return None - - -def draw_dot(x: Dot) -> Optional[Diagram]: - if x.input_shape == (1,): - return None - inp = draw_tensor_3d(x.input_shape[0], None, None, None) - # inp = reshape(base_tensor(x.input_shape[0], color=ACTIVE)) - # inp = add_whiskers(inp, x.input_shape[0]) - inp2 = draw_tensor_3d(x.input_shape[1], None, None, None) - # inp2 = reshape(base_tensor(x.input_shape[0], color=ACTIVE)) - # inp2 = add_whiskers(inp2, x.input_shape[0]) - out = draw_tensor_3d(x.output_shape, None, None, None) - # out = reshape(base_tensor(x.output_shape, color=ACTIVE)) - # out = add_whiskers(out, x.output_shape) - return hcat( - [box(inp, 1.5, 2), box(inp2, 1.5, 2), box(out, 1.5, 2)], 1 - ).center_xy() + text("dot", 0.2).fill_color(BLACK).line_width(0).translate(0, -1) - - -# For 3d - - -def lookAt(eye: ArrayLike, center: ArrayLike, up: ArrayLike): - "Python version of the haskell lookAt function in linear.projections" - f = (center - eye) / np.linalg.norm(center - eye) - s = np.cross(f, up) / np.linalg.norm(np.cross(f, up)) - u = np.cross(s, f) - return np.array([[*s, 0], [*u, 0], [*-f, 0], [0, 0, 0, 1]]) - - -def scale3(x, y, z): - return np.array([[x, 0, 0, 0], [0, y, 0, 0], [0, 0, z, 0], [0, 0, 0, 1]]) - - -@dataclass -class D3: - x: float - y: float - z: float - - def to_np(self): - return np.array([self.x, self.y, self.z]) - - -V3 = D3 - - -def homogeneous(trails: List[List[D3]]): - "Convert list of directions to a np.array of homogeneous coordinates" - return np.array([[[*o.to_np(), 1] for o in offsets] for offsets in trails]) - - -def cube(): - "3 faces of a cube drawn as offsets from the origin." - return homogeneous( - [ - [D3(*v) for v in offset] - for offset in [ - [(1, 0, 0), (0, 1, 0), (-1, 0, 0), (0, -1, 0)], - [(1, 0, 0), (0, 0, 1), (-1, 0, 0), (0, 0, -1)], - [(0, 0, 1), (0, 1, 0), (0, 0, -1), (0, -1, 0)], - ] - ] - ) - - -def to_trail(trail: ArrayLike, locations: ArrayLike): - return [ - ( - Path( - [ - Trail.from_offsets([V2(*v[:2]) for v in trail]) - .close() - .at(V2(*loc[:2])) - ] - ), - loc[2], - ) - for loc in locations - ] - - -def project(projection, shape3, positions): - p = homogeneous([positions for _ in range(shape3.shape[0])]) - locations = p @ projection.T - trails = shape3 @ projection.T - return [out for t, loc in zip(trails, locations) for out in to_trail(t, loc)] - - -def draw_tensor_3d(shape, a, b, c, color=WHITE): - shape = make_3d(shape) - - # Big Cube - s = scale3(*shape) - big_cube = cube() @ s.T - back = scale3(0, shape[1], shape[2]) - back_cube = cube() @ back.T - - # Isometric projection of tensor - projection = lookAt( - V3(-1.0, -0.3, -0.15).to_np(), - V3(0, 0, 0).to_np(), - V3(0, 1, 0).to_np(), - ) - outer = project(projection, big_cube, [V3(0, 0, 0)]) - outer2 = project(projection, back_cube, [V3(shape[0], 0, 0)]) - d = ( - concat([p.stroke().fill_color(GREY).fill_opacity(0.1) for p, _ in outer2]) - .line_width(0.005) - .line_color(GREY) - ) - d += ( - concat([p.stroke().fill_color(GREY).fill_opacity(0.05) for p, _ in outer]) - .line_width(0.01) - .line_color(BLACK) - ) - if a is not None: - out = group(a, b, c) - d2 = [ - (b, loc) - for i in range(len(out)) - for b, loc in make_cube(projection, out[i][0], out[i][1], color) - ] - d2.sort(key=lambda x: x[1], reverse=True) - d2 = concat([b.with_envelope(empty()) for b, _ in d2]) - d = d2.with_envelope(d) + d - return d - - -def lines(s): - "Draw lines to mimic a cube of cubes" - bs = [np.array([1, 0, 0]), np.array([0, 1, 0]), np.array([0, 0, 1])] - return [ - homogeneous( - [ - [D3(*p) for _ in range(s[i]) for p in [a / s[i], b, -b]] - for b in bs - if not np.all(a == b) - ] - ) - for i, a in enumerate(bs) - ] - - -def make_cube(projection, start, end, color): - "Draws a cube from start position to end position." - start = np.array(start).astype(int) - end = np.array(end).astype(int) - s2 = end - start + 1 - s = scale3(*s2) - small_cube = cube() @ s.T - loc = [ - project(projection, l2 @ s.T, [V3(*start)]) - for l2 in lines(s2) - if l2.shape[1] > 0 - ] - outer2 = project(projection, small_cube, [V3(*start)]) - ls = loc - box = [ - (p.stroke().fill_color(color).fill_opacity(0.4).line_width(0), loc) - for p, loc in outer2 - ] - line = [ - (p.stroke().line_width(0.001).line_color(BLACK), l_) - for loc in ls - for p, l_ in loc - ] - return [(b, loc) for b, loc in box + line] - - -def group( - x: ArrayLike, y: ArrayLike, z: ArrayLike -) -> List[Tuple[Tuple[float, float, float], Tuple[float, float, float]]]: - "Groups together cubes into bigger cubes" - x = list(zip(zip(x, y, z), zip(x, y, z))) - x = [(a, b) for a, b in x if not (a[0] == -1 and a[1] == -1 and a[2] == -1)] - - start = x - - def remove_dups(ls): - "Remove duplicates" - out = [] - for y in ls: - if not out or y != out[-1]: - out.append(y) - return out - - for j in range(2, -1, -1): - x = remove_dups(start) - start = [] - while True: - if len(x) <= 1: - break - _, _, rest = x[0], x[1], x[2:] - m = 0 - for k in range(2): - a = x[0][k] - b = x[1][k] - if ( - (k == 0 or a[j % 3] == b[j % 3] - 1) - and a[(j + 1) % 3] == b[(j + 1) % 3] - and a[(j + 2) % 3] == b[(j + 2) % 3] - ): - m += 1 - if m == 2: - x = [[x[0][0], x[1][1]]] + rest - else: - start.append(x[0]) - x = [x[1]] + rest - start += x - return start +def delinearized(shape: Tuple[int, int, int], x: np.ndarray, dtype, mask) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + x = x.copy() // (dtype.element_ty.primitive_bitwidth // 8) + z = ((x // (shape[1] * shape[2])) * mask - (1 - mask)).ravel() + y = (((x // shape[2]) % shape[1]) * mask - (1 - mask)).ravel() + x = ((x % shape[2]) * mask - (1 - mask)).ravel() + return z, y, x + + + +def prepare_visualization_data(program_records, tensor_table): + + """Prepare visualization data for the frntend and raw tensor data for the server.""" + # global idx + visualization_data = [] + raw_tensor_data = {} + for record in program_records: + record_uuid = str(uuid.uuid4())[:8] + + if isinstance(record, ExpandDims): + print(record.input_shape, record.output_shape, record.index) + if isinstance(record, Dot): + visualization_data.append({ + 'type': 'Dot', + 'input_shape': record.input_shape, + 'other_shape': record.other_shape, + 'output_shape': record.output_shape, + 'uuid': record_uuid + }) + + raw_tensor_data[record_uuid] = { + 'input_data': torch.tensor(record.input_data), + 'other_data': torch.tensor(record.other_data), + 'intermediate_results': record.intermediate_results + } + + + + elif isinstance(record, Load): + global_tensor, slice_tensor = tensor_table[record.ptr] + print(global_tensor) + global_coords ,slice_coords = extract_load_coords(record,global_tensor) + # printc(global_coords,'green') + + visualization_data.append({ + 'type': 'Load', + 'global_shape': global_tensor.shape, + 'slice_shape':record.shape, + 'global_coords':global_coords, + 'slice_coords':slice_coords, + 'uuid': record_uuid + }) + + + raw_tensor_data[record_uuid] = { + 'global_tensor': global_tensor.data.cpu(), # Ensure it's on CPU + 'dims':len(global_tensor.data.cpu().shape) + } + print(record.shape) + + + elif isinstance(record, Store): + global_tensor, slice_tensor = tensor_table[record.ptr] + + global_coords ,slice_coords = extract_load_coords(record,global_tensor) + + + visualization_data.append({ + 'type': 'Store', + 'global_shape': global_tensor.shape, + 'slice_shape':record.shape, + 'global_coords':global_coords, + 'slice_coords':slice_coords, + 'uuid': record_uuid + }) + + + + + return visualization_data, raw_tensor_data, "" + + +def get_visualization_data(): + """Return the visualization data and raw tensor data.""" + records, tensor_table, failures = collect_grid() + visualization_data = {} + raw_tensor_data = {} + + for grid_idx, program_records in records.items(): + viz_data, raw_data, kernel_src = prepare_visualization_data(program_records, tensor_table) + visualization_data[str(grid_idx)] = viz_data + raw_tensor_data.update(raw_data) + + # Get the kernel source code + + + return { + "visualization_data": visualization_data, + "raw_tensor_data": raw_tensor_data, + "failures": failures, + "kernel_src": kernel_src + } + + + +def serialize_for_json(obj): + if isinstance(obj, np.ndarray): + return obj.tolist() + elif isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + elif isinstance(obj, torch.Tensor): + return obj.cpu().numpy().tolist() + raise TypeError(f'Object of type {obj.__class__.__name__} is not JSON serializable') \ No newline at end of file diff --git a/triton_viz/interface.py b/triton_viz/interface.py index 472f01d..9456c1b 100644 --- a/triton_viz/interface.py +++ b/triton_viz/interface.py @@ -1,99 +1,180 @@ -import gradio as gr -import triton_viz -import tempfile +import threading +import webbrowser +import json +from flask import Flask, render_template, jsonify, request from .analysis import analyze_records -from .tooltip import create_tooltip +from .draw import get_visualization_data +from .tooltip import get_tooltip_data import pandas as pd +import os +import torch -def launch(share=True): - cache = {} + +app = Flask(__name__, + template_folder=os.path.join(os.path.dirname(__file__), 'templates'), + static_folder=os.path.join(os.path.dirname(__file__), 'static')) + +# Global variables to store the data +global_data = None +raw_tensor_data = None +precomputed_c_values = {} +current_fullscreen_op = None + +def precompute_c_values(op_data): + input_data = op_data['input_data'] + other_data = op_data['other_data'] + rows, inner_dim = input_data.shape + cols = other_data.shape[1] + + precomputed = {} + for i in range(rows): + for j in range(cols): + precomputed[(i, j)] = [0] * (inner_dim + 1) + for k in range(1, inner_dim + 1): + precomputed[(i, j)][k] = torch.dot(input_data[i, :k], other_data[:k, j]).item() + + return precomputed + +def update_global_data(): + global global_data, raw_tensor_data, precomputed_c_values analysis_data = analyze_records() - program_records, tt, failures = triton_viz.collect_grid() - m = [0, 0, 0] - size = [0, 0] - for k in program_records.keys(): - m[0] = max(k[0] + 1, m[0]) - m[1] = max(k[1] + 1, m[1]) - m[2] = max(k[2] + 1, m[2]) - w, h = triton_viz.draw_record(program_records[(0, 0, 0)], tt, "tmp.svg") - size[0] = w - size[1] = h - height = 600 * size[1] / size[0] - with gr.Blocks( - css=".gradio-container button {overflow: auto} img.with-caption {height: %fpx !important; } .thumbnails { display: none; } " - % height - ) as demo: - with gr.Row(): - with gr.Column(scale=3, min_width=500): - img = gr.Gallery( - height=500, - min_width=500, - show_label=False, - selected_index=0, - preview=True, - object_fit="cover", - ) - with gr.Column(scale=1): - s1 = gr.Slider(0, m[0] - 1, value=0, step=1, label="Program Id 0") - s2 = gr.Slider(0, m[1] - 1, value=0, step=1, label="Program Id 1") - s3 = gr.Slider(0, m[2] - 1, value=0, step=1, label="Program Id 2") - b1 = gr.Button("Precompute") - gr.Markdown("## Analysis") - df = pd.DataFrame(analysis_data, columns=["Metric", "Value"]) - analysis_with_tooltip = create_tooltip(df) - gr.HTML(analysis_with_tooltip) - if failures: - gr.Markdown( - show_label=False, - value="## Invalid memory access in " - + "\n * " - + "\n* ".join(list(map(str, failures.keys()))), - ) - - def cache_block(idx): - name = tempfile.NamedTemporaryFile(suffix=".svg") - w, h = triton_viz.draw_record(program_records[idx], tt, name.name) - size[0] = w - size[1] = h - cache[idx] = (name, len(cache)) - - def update(inp): - a = inp[s1] - b = inp[s2] - c = inp[s3] - idx = (a, b, c) - - if idx not in cache: - cache_block(idx) - return gr.Gallery( - value=[(cache[k][0].name, str(k)) for k in cache.keys()], - selected_index=cache[idx][1], - height=700, - ), gr.Slider() - # * size[1]/size[0] - return gr.Gallery(selected_index=cache[idx][1]), gr.Slider() - - def precompute(inp): - a = inp[s1] - b = inp[s2] - c = inp[s3] - idx = (a, b, c) - for i in range(m[0]): - for j in range(m[1]): - for k in range(m[2]): - if (i, j, k) not in cache: - cache_block((i, j, k)) - return gr.Gallery( - value=[(cache[k][0].name, str(k)) for k in cache.keys()], - selected_index=cache[idx][1], - ) - - s1.change(update, inputs={s1, s2, s3}, outputs=[img, b1], show_progress=False) - s2.change(update, inputs={s1, s2, s3}, outputs=[img, b1], show_progress=False) - s3.change(update, inputs={s1, s2, s3}, outputs=[img, b1], show_progress=False) - b1.click(precompute, inputs={s1, s2, s3}, outputs=img, show_progress=True) - demo.load(update, inputs={s1, s2, s3}, outputs=[img, b1]) - - demo.launch(share=share, debug=False, height=800, quiet=True, show_api=False) - return failures + viz_data = get_visualization_data() + global_data = { + "ops": { + "visualization_data": viz_data["visualization_data"], + "failures": viz_data["failures"], + "kernel_src": viz_data["kernel_src"] + } + + } + raw_tensor_data = viz_data["raw_tensor_data"] + + # Precompute C values for each Dot operation + precomputed_c_values = {} + for uuid, op_data in raw_tensor_data.items(): + if 'input_data' in op_data and 'other_data' in op_data: + precomputed_c_values[uuid] = precompute_c_values(op_data) + + df = pd.DataFrame(analysis_data, columns=["Metric", "Value"]) + analysis_with_tooltip = get_tooltip_data(df) + global_data["analysis"] = analysis_with_tooltip + +@app.route('/') +def index(): + return render_template('index.html') + +@app.route('/api/data') +def get_data(): + global global_data + if global_data is None: + update_global_data() + return jsonify(global_data) + +@app.route('/api/update_data') +def update_data(): + update_global_data() + return jsonify({"status": "Data updated successfully"}) + +@app.route('/api/setop', methods=['POST']) +def set_current_op(): + global current_fullscreen_op + data = request.json + current_fullscreen_op = data.get('uuid') + return jsonify({"status": "Current op set successfully", "uuid": current_fullscreen_op}) + +@app.route('/api/getValue', methods=['POST']) +def get_value(): + global raw_tensor_data, precomputed_c_values, current_fullscreen_op + print(current_fullscreen_op) + data = request.json + uuid = data.get('uuid') + matrix_name = data.get('matrixName') + row = data.get('row') + col = data.get('col') + + if uuid not in raw_tensor_data: + return jsonify({"error": "Operation not found"}), 404 + + op_data = raw_tensor_data[uuid] + + if matrix_name == 'A': + value = op_data['input_data'][row, col].item() if 'input_data' in op_data else None + return jsonify({"value": value}) + elif matrix_name == 'B': + value = op_data['other_data'][row, col].item() if 'other_data' in op_data else None + return jsonify({"value": value}) + elif matrix_name == 'C': + current_step = data.get('currentStep', 0) + print(current_step) + total_steps = op_data['input_data'].shape[1] + + if uuid not in precomputed_c_values: + return jsonify({"error": "Precomputed values not found"}), 404 + + precomputed = precomputed_c_values[uuid] + current_value = precomputed[(row, col)][current_step] + print(current_value) + + return jsonify({ + "value": current_value, + + }) + else: + return jsonify({"error": "Invalid matrix name"}), 400 + + + +@app.route('/api/getLoadValue', methods=['POST']) +def get_load_value(): + global raw_tensor_data, current_fullscreen_op + + data = request.json + uuid = data.get('uuid') + x = data.get('x') + y = data.get('y') + z = data.get('z') + print(x,y,z) + if uuid is None or uuid not in raw_tensor_data: + return jsonify({"error": "Operation not found"}), 404 + + op_data = raw_tensor_data[uuid] + + if 'global_tensor' in op_data and (x is not None and y is not None and z is not None): + try: + + + printc(op_data['dims']) + value = 0.0 + if op_data['dims'] == 3: + value = op_data['global_tensor'][x, y,z].item() + elif op_data['dims'] == 2: + value = op_data['global_tensor'][x, y].item() + elif op_data['dims'] == 1: + value = op_data['global_tensor'][x].item() + + return jsonify({"value": value}) + except IndexError: + return jsonify({"error": "Coordinates out of bounds"}), 200 + else: + return jsonify({"error": "Global tensor data not found"}), 200 + + + + + + + +def run_flask(): + app.run(port=5000,host='127.0.0.1') + +def launch(share=True): + print("Launching Triton viz tool") + flask_thread = threading.Thread(target=run_flask) + flask_thread.start() + return flask_thread + +# This function can be called to stop the Flask server if needed +def stop_server(flask_thread): + # Implement a way to stop the Flask server + pass \ No newline at end of file diff --git a/triton_viz/interpreter.py b/triton_viz/interpreter.py index f447bbf..2fc02e5 100644 --- a/triton_viz/interpreter.py +++ b/triton_viz/interpreter.py @@ -44,6 +44,20 @@ def _unpatch_lang(): class RecordBuilder: + + def to_dict(self): + """Convert the recorded data to a dictionary format for JSON serialization.""" + return { + "launches": [ + { + "grid": launch.grid, + "tensors": [tensor.__dict__ for tensor in launch.tensors], + "records": [record.__dict__ for record in launch.records] + } + for launch in self._launches + ] + } + def __init__(self) -> None: self.reset() @@ -136,7 +150,11 @@ def _check_storage_contiguous(tensor): def _grid_executor_call(self, *args_dev, **kwargs): # Removes reserved keywords from kwargs + #kwargs has src_map and src src is the raw straight spurce code, src_map is a mapping between source code line and an op kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS} + src_map = kwargs.pop("src_map") + src = kwargs.pop("src") + if kwargs.pop("warmup", False): return args_hst, kwargs_hst = self._init_args_hst(args_dev, kwargs) @@ -152,6 +170,7 @@ def _grid_executor_call(self, *args_dev, **kwargs): else: ret = _implicit_cvt(arg) if hasattr(arg, "data_ptr"): + assert _check_storage_contiguous( arg ), "triton-viz only supports contiguouly stored tensors for now" @@ -162,6 +181,7 @@ def _grid_executor_call(self, *args_dev, **kwargs): arg.stride(), arg.shape, arg.element_size(), + arg ) ) call_args[name] = ret @@ -291,18 +311,33 @@ def wrapper(lhs, rhs, op): def _create_dot(fn): @wraps(fn) - def wrapper(a, b, d, allow_tf32, maxNumImpreciseAcc): - ret = fn(a, b, d, allow_tf32, maxNumImpreciseAcc) + def wrapper(a, b, c, allow_tf32, max_num_imprecise_acc): dot_record = Dot( - input_shape=(a.data.shape, b.data.shape), - other_shape=d.data.shape, - output_shape=ret.data.shape, + input_shape=a.data.shape, + other_shape=b.data.shape, + output_shape=c.data.shape, + input_data=a.data.tolist(), + other_data=b.data.tolist() ) record_builder.add_record(dot_record) - return ret - return wrapper + def capture_intermediate(row, col, result): + dot_record.update_intermediate(row, col, float(result)) + # Modify the original function to call capture_intermediate at each step + def modified_fn(a, b, c, allow_tf32, max_num_imprecise_acc): + for i in range(a.data.shape[0]): + for j in range(b.data.shape[1]): + A_row = a.data[i, :] + B_column = b.data[:, j] + result = np.dot(A_row, B_column) + capture_intermediate(i, j, result) + c.data[i, j] = result + + modified_fn(a, b, c, allow_tf32, max_num_imprecise_acc) + return c + + return wrapper def _create_expand_dims(fn): @wraps(fn) @@ -344,7 +379,7 @@ def patch(): old_grid_executor_call = GridExecutor.__call__ old_jit_function_call = JITFunction.__call__ # XXX(Keren): Temporarily disable rewriting of AST - old_rewrite_ast = InterpretedFunction._rewrite_ast + # old_rewrite_ast = InterpretedFunction._rewrite_ast old_create_make_range = interpreter_builder.create_make_range old_create_masked_load = interpreter_builder.create_masked_load old_create_expand_dims = interpreter_builder.create_expand_dims @@ -373,10 +408,16 @@ def patch(): finally: GridExecutor.__call__ = old_grid_executor_call JITFunction.__call__ = old_jit_function_call - InterpretedFunction._rewrite_ast = old_rewrite_ast + # InterpretedFunction._rewrite_ast = old_rewrite_ast interpreter_builder.create_make_range = old_create_make_range interpreter_builder.create_masked_load = old_create_masked_load interpreter_builder.create_expand_dims = old_create_expand_dims interpreter_builder.binary_op = old_binary_op interpreter_builder.create_dot = old_create_dot interpreter_builder.create_masked_store = old_create_masked_store + + + +def get_recorded_data(): + """Return the recorded data in a format suitable for JSON serialization.""" + return record_builder.to_dict() \ No newline at end of file diff --git a/triton_viz/static/gridblock.js b/triton_viz/static/gridblock.js new file mode 100644 index 0000000..8d3000e --- /dev/null +++ b/triton_viz/static/gridblock.js @@ -0,0 +1,204 @@ +import { createMatMulVisualization } from './matmul.js' +import { createLoadVisualization } from './load.js'; +import { createStoreVisualization } from './store.js'; + +export class GridBlock { + constructor(x, y, width, height, gridX, gridY, gridZ, blockData) { + this.rect = { x, y, width, height }; + this.gridPosition = { x: gridX, y: gridY, z: gridZ }; + this.blockData = blockData; + this.isHovered = false; + this.visualizationContainer = null; + this.cleanupFunction = null; + this.fullScreenCleanupFunction = null; + } + + draw(ctx) { + // Draw background + ctx.fillStyle = this.isHovered ? '#4a4a4a' : '#323232'; + ctx.fillRect(this.rect.x, this.rect.y, this.rect.width, this.rect.height); + + // Draw border + ctx.strokeStyle = '#000000'; + ctx.lineWidth = 1; + ctx.strokeRect(this.rect.x, this.rect.y, this.rect.width, this.rect.height); + + // Draw grid position + ctx.fillStyle = '#c8c8c8'; + ctx.font = '12px Arial'; + ctx.textAlign = 'center'; + ctx.textBaseline = 'top'; + const posText = `${this.gridPosition.x},${this.gridPosition.y},${this.gridPosition.z}`; + ctx.fillText(posText, this.rect.x + this.rect.width / 2, this.rect.y + 2); + + // Draw operation types + ctx.font = '10px Arial'; + ctx.textAlign = 'left'; + ctx.textBaseline = 'top'; + this.blockData.forEach((op, index) => { + ctx.fillText(op.type, this.rect.x + 5, this.rect.y + 20 + index * 15); + }); + } + + isPointInside(x, y) { + return x >= this.rect.x && x <= this.rect.x + this.rect.width && + y >= this.rect.y && y <= this.rect.y + this.rect.height; + } + + handleMouseMove(x, y) { + const wasHovered = this.isHovered; + this.isHovered = this.isPointInside(x, y); + return wasHovered !== this.isHovered; // Return true if hover state changed + } + + showDetailedView(containerElement) { + if (this.isDetailedViewVisible) return; + + this.visualizationContainer = containerElement; + this.visualizationContainer.innerHTML = ''; + this.visualizationContainer.style.pointerEvents = 'auto'; + + const wrapper = document.createElement('div'); + wrapper.style.padding = '10px'; + wrapper.style.backgroundColor = '#1e1e28'; + wrapper.style.color = '#ffffff'; + wrapper.style.height = '100%'; + wrapper.style.overflowY = 'auto'; + wrapper.style.display = 'flex'; + wrapper.style.flexDirection = 'column'; + wrapper.style.alignItems = 'stretch'; + wrapper.style.width = '100%'; + + const title = document.createElement('h2'); + title.textContent = `Operations at (${this.gridPosition.x}, ${this.gridPosition.y}, ${this.gridPosition.z})`; + title.style.textAlign = 'center'; + title.style.margin = '10px 0'; + wrapper.appendChild(title); + + const operationsContainer = document.createElement('div'); + operationsContainer.style.display = 'flex'; + operationsContainer.style.flexDirection = 'column'; + operationsContainer.style.gap = '10px'; + operationsContainer.style.flex = '1'; + + this.blockData.forEach((op, index) => { + const opContainer = document.createElement('div'); + opContainer.style.padding = '10px'; + opContainer.style.border = '1px solid #444'; + opContainer.style.borderRadius = '5px'; + opContainer.style.display = 'flex'; + opContainer.style.flexDirection = 'column'; + opContainer.style.flex = '1'; + + const opTitle = document.createElement('h3'); + opTitle.textContent = `Operation ${index + 1}: ${op.type}`; + opTitle.style.margin = '0 0 10px 0'; + opContainer.appendChild(opTitle); + + if (op.type === 'Dot' || op.type === 'Load' || op.type === 'Store') { + const viewButton = document.createElement('button'); + viewButton.textContent = 'View Full Screen'; + viewButton.style.marginTop = '10px'; + viewButton.addEventListener('click', () => this.showFullScreenOperation(op)); + opContainer.appendChild(viewButton); + } else { + const unsupportedMsg = document.createElement('p'); + unsupportedMsg.textContent = `Full-screen view not supported for ${op.type} operation`; + opContainer.appendChild(unsupportedMsg); + } + + operationsContainer.appendChild(opContainer); + }); + + wrapper.appendChild(operationsContainer); + + if (this.blockData.length === 0) { + const noDataMsg = document.createElement('p'); + noDataMsg.textContent = 'No operation data available'; + noDataMsg.style.textAlign = 'center'; + wrapper.appendChild(noDataMsg); + } + + const closeButton = document.createElement('button'); + closeButton.textContent = 'Close'; + closeButton.style.position = 'fixed'; + closeButton.style.top = '10px'; + closeButton.style.right = '10px'; + closeButton.addEventListener('click', () => this.hideDetailedView()); + + this.visualizationContainer.appendChild(wrapper); + this.visualizationContainer.appendChild(closeButton); + + this.visualizationContainer.tabIndex = 0; + this.visualizationContainer.focus(); + + this.isDetailedViewVisible = true; + this.cleanupFunction = this.hideDetailedView.bind(this); + } + + hideDetailedView() { + if (!this.isDetailedViewVisible) return; + + if (this.visualizationContainer) { + this.visualizationContainer.innerHTML = ''; + this.visualizationContainer.style.pointerEvents = 'none'; + } + + this.isDetailedViewVisible = false; + this.cleanupFunction = null; + } + + showFullScreenOperation(op) { + if (op.type !== 'Dot' && op.type !== 'Load' && op.type !== 'Store') { + console.warn(`Full-screen view is not supported for ${op.type} operations`); + return; + } + + this.hideDetailedView(); + + const fullScreenContainer = document.createElement('div'); + fullScreenContainer.style.position = 'fixed'; + fullScreenContainer.style.top = '0'; + fullScreenContainer.style.left = '0'; + fullScreenContainer.style.width = '100vw'; + fullScreenContainer.style.height = '100vh'; + fullScreenContainer.style.backgroundColor = '#1e1e28'; + fullScreenContainer.style.zIndex = '1000'; + + document.body.appendChild(fullScreenContainer); + + if (op.type === 'Dot') { + this.fullScreenCleanupFunction = createMatMulVisualization(fullScreenContainer, op); + } else if (op.type === 'Load') { + this.fullScreenCleanupFunction = createLoadVisualization(fullScreenContainer, op); + } + else if (op.type === 'Store') { + console.log("store op",op) + this.fullScreenCleanupFunction = createStoreVisualization(fullScreenContainer, op); + } + + const closeButton = document.createElement('button'); + closeButton.textContent = 'Close Full Screen'; + closeButton.style.position = 'fixed'; + closeButton.style.top = '10px'; + closeButton.style.right = '10px'; + closeButton.style.zIndex = '1001'; + closeButton.addEventListener('click', () => this.closeFullScreenOperation()); + + fullScreenContainer.appendChild(closeButton); + } + + closeFullScreenOperation() { + if (this.fullScreenCleanupFunction) { + this.fullScreenCleanupFunction(); + this.fullScreenCleanupFunction = null; + } + + const fullScreenContainer = document.querySelector('div[style*="position: fixed"]'); + if (fullScreenContainer) { + document.body.removeChild(fullScreenContainer); + } + + this.showDetailedView(this.visualizationContainer); + } +} \ No newline at end of file diff --git a/triton_viz/static/load.js b/triton_viz/static/load.js new file mode 100644 index 0000000..ced95ec --- /dev/null +++ b/triton_viz/static/load.js @@ -0,0 +1,221 @@ +import * as THREE from 'https://cdn.jsdelivr.net/npm/three@0.155.0/build/three.module.js'; +import { + setupScene, + setupGeometries, + createTensor, + calculateTensorSize, + updateCubeColor, + setupCamera, + setupEventListeners, + cameraControls +} from './load_utils.js'; + +export function createLoadVisualization(containerElement, op) { + + console.log(op.uuid); + fetch('/api/setop', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ uuid: op.uuid }), + }) + .then(response => response.json()) + .then(data => console.log('Set current op:', data)) + .catch((error) => console.error('Error:', error)); + + let currentStep = 0; + let frame = 0; + let isPaused = false; + + const sideMenu = createSideMenu(containerElement); + let hoveredCube = null; + + const COLOR_GLOBAL = new THREE.Color(0.2, 0.2, 0.2); // Dark Gray + const COLOR_SLICE = new THREE.Color(0.0, 0.7, 1.0); // Cyan (starting color for global slice) + const COLOR_LEFT_SLICE = new THREE.Color(1.0, 0.0, 1.0); // Magenta (starting color for left slice) + const COLOR_LOADED = new THREE.Color(1.0, 0.8, 0.0); // Gold (final color for both slices) + const COLOR_BACKGROUND = new THREE.Color(0.0, 0.0, 0.0); // Black + + const { scene, camera, renderer } = setupScene(containerElement, COLOR_BACKGROUND); + const { cubeGeometry, edgesGeometry, lineMaterial } = setupGeometries(); + + const globalTensor = createTensor(op.global_shape, op.global_coords, COLOR_GLOBAL, 'Global', cubeGeometry, edgesGeometry, lineMaterial); + const sliceTensor = createTensor(op.slice_shape, op.slice_coords, COLOR_LEFT_SLICE, 'Slice', cubeGeometry, edgesGeometry, lineMaterial); + + // Position slice tensor + const globalSize = calculateTensorSize(op.global_shape); + sliceTensor.position.set(globalSize.x + 5, 0, 0); // Adjusted tensor spacing + + scene.add(globalTensor); + scene.add(sliceTensor); + + addLabels(scene, globalTensor, sliceTensor); + setupCamera(scene, camera); + + const totalFrames = op.global_coords.length * 2 + 30; + + const raycaster = new THREE.Raycaster(); + const mouse = new THREE.Vector2(); + + const onKeyDown = cameraControls(camera, new THREE.Euler(0, 0, 0, 'YXZ')); + setupEventListeners(containerElement, camera, renderer, onMouseMove, onKeyDown); + animate(); + + async function onMouseMove(event) { + mouse.x = (event.clientX / containerElement.clientWidth) * 2 - 1; + mouse.y = -(event.clientY / containerElement.clientHeight) * 2 + 1; + + raycaster.setFromCamera(mouse, camera); + + const allTensorChildren = [ + ...globalTensor.children, + ...sliceTensor.children + ]; + + const intersects = raycaster.intersectObjects(allTensorChildren, true); + + if (hoveredCube) { + hoveredCube.getObjectByName('hoverOutline').visible = false; + hoveredCube = null; + } + + if (intersects.length > 0) { + hoveredCube = intersects[0].object; + while (hoveredCube && !hoveredCube.tensorName) { + hoveredCube = hoveredCube.parent; + } + + if (hoveredCube) { + const hoverOutline = hoveredCube.getObjectByName('hoverOutline'); + if (hoverOutline) { + hoverOutline.visible = true; + } + + updateSideMenu(hoveredCube.tensorName, hoveredCube.tensorX, hoveredCube.tensorY, hoveredCube.tensorZ, undefined); + + const res = await getElementValue(hoveredCube.tensorName, hoveredCube.tensorX, hoveredCube.tensorY, hoveredCube.tensorZ); + + updateSideMenu(hoveredCube.tensorName, hoveredCube.tensorX, hoveredCube.tensorY, hoveredCube.tensorZ, res.value); + + console.log(`Value: ${res.value}`); + } + } else { + updateSideMenu(null); + } + } + + function animate() { + requestAnimationFrame(animate); + + if (!isPaused && frame < totalFrames) { + const index = Math.floor(frame / 2); + const factor = (frame % 2) / 1.0; + + if (index < op.global_coords.length) { + const globalCoord = op.global_coords[index]; + const sliceCoord = op.slice_coords[index]; + + updateCubeColor(globalTensor, globalCoord, COLOR_GLOBAL, COLOR_SLICE, factor); + updateCubeColor(sliceTensor, sliceCoord, COLOR_LEFT_SLICE, COLOR_LOADED, factor); + + highlightCurrentOperation(globalTensor, globalCoord, sliceTensor, sliceCoord); + updateInfoPanel(globalCoord, sliceCoord, index); + } + + frame++; + } + + renderer.render(scene, camera); + } + + function highlightCurrentOperation(globalTensor, globalCoord, sliceTensor, sliceCoord) { + globalTensor.children.forEach(cube => cube.material.emissive.setHex(0x000000)); + sliceTensor.children.forEach(cube => cube.material.emissive.setHex(0x000000)); + + const globalCube = globalTensor.children.find(c => + c.tensorX === globalCoord[0] && c.tensorY === globalCoord[1] && c.tensorZ === globalCoord[2] + ); + const sliceCube = sliceTensor.children.find(c => + c.tensorX === sliceCoord[0] && c.tensorY === sliceCoord[1] && c.tensorZ === sliceCoord[2] + ); + + if (globalCube) globalCube.material.emissive.setHex(0x444444); + if (sliceCube) sliceCube.material.emissive.setHex(0x444444); + } + + async function getElementValue(tensorName, x, y, z) { + let uuid = op.uuid; + const response = await fetch('/api/getLoadValue', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ uuid, tensorName, x, y, z }), + }); + return await response.json(); + } + + function updateSideMenu(tensorName, x, y, z, value) { + if (!tensorName) { + sideMenu.innerHTML = ''; + return; + } + + let dims = tensorName === 'Global' ? op.global_shape : op.slice_shape; + sideMenu.innerHTML = ` +
X: ${x + 1}
+Y: ${y + 1}
+Z: ${z + 1}
+Dimensions: ${dims.join(' x ')}
+Value: ${value !== undefined ? value : 'Loading...'}
+ `; + } + + function updateInfoPanel(globalCoord, sliceCoord, index) { + sideMenu.innerHTML = ` +Global Coords: (${globalCoord.join(', ')})
+Slice Coords: (${sliceCoord.join(', ')})
+Progress: ${index + 1}/${op.global_coords.length}
+ `; + } + + function createSideMenu(container) { + const menu = document.createElement('div'); + menu.style.position = 'absolute'; + menu.style.top = '10px'; + menu.style.right = '10px'; + menu.style.width = '200px'; + menu.style.padding = '10px'; + menu.style.backgroundColor = 'rgba(0, 0, 0, 0.7)'; + menu.style.color = 'white'; + menu.style.fontFamily = 'Arial, sans-serif'; + menu.style.fontSize = '14px'; + menu.style.borderRadius = '5px'; + container.appendChild(menu); + return menu; + } + + function addLabels(scene, globalTensor, sliceTensor) { + addLabel(scene, "Global Tensor", globalTensor.position); + addLabel(scene, "Slice Tensor", sliceTensor.position); + } + + function addLabel(scene, text, position) { + const canvas = document.createElement('canvas'); + const context = canvas.getContext('2d'); + context.font = 'Bold 24px Arial'; + context.fillStyle = 'white'; + context.fillText(text, 0, 24); + + const texture = new THREE.CanvasTexture(canvas); + const material = new THREE.SpriteMaterial({ map: texture }); + const sprite = new THREE.Sprite(material); + sprite.position.set(position.x, position.y + 2, position.z); + sprite.scale.set(4, 2, 1); + scene.add(sprite); + } + +} diff --git a/triton_viz/static/load_utils.js b/triton_viz/static/load_utils.js new file mode 100644 index 0000000..c139082 --- /dev/null +++ b/triton_viz/static/load_utils.js @@ -0,0 +1,200 @@ +import * as THREE from 'https://cdn.jsdelivr.net/npm/three@0.155.0/build/three.module.js'; + +export const CUBE_SIZE = 0.2; +export const GAP = 0.05; +export const COLOR_HOVER = new THREE.Color(1.0, 1.0, 0.0); // Yellow for hover effect +export const COLOR_EDGE = new THREE.Color(0.5, 0.5, 0.5); // Gray (for cube edges) + +const COLOR_SLICE = new THREE.Color(0.0, 0.7, 1.0); // Cyan (starting color for global slice) + +export function setupScene(container, backgroundColor = 0x000000) { + const scene = new THREE.Scene(); + scene.background = new THREE.Color(backgroundColor); + const camera = new THREE.PerspectiveCamera(45, container.clientWidth / container.clientHeight, 0.1, 1000); + const renderer = new THREE.WebGLRenderer({ antialias: true }); + renderer.setSize(container.clientWidth, container.clientHeight); + container.appendChild(renderer.domElement); + + const ambientLight = new THREE.AmbientLight(0xffffff, 0.5); + scene.add(ambientLight); + const directionalLight = new THREE.DirectionalLight(0xffffff, 0.5); + directionalLight.position.set(10, 10, 10); + scene.add(directionalLight); + + return { scene, camera, renderer }; +} + +export function setupGeometries() { + const cubeGeometry = new THREE.BoxGeometry(CUBE_SIZE, CUBE_SIZE, CUBE_SIZE); + const edgesGeometry = new THREE.EdgesGeometry(cubeGeometry); + const lineMaterial = new THREE.LineBasicMaterial({ color: COLOR_EDGE }); + return { cubeGeometry, edgesGeometry, lineMaterial }; +} + +export function createCube(color, tensorName, x, y, z, cubeGeometry, edgesGeometry, lineMaterial) { + const cubeMaterial = new THREE.MeshPhongMaterial({ color: color }); + const cube = new THREE.Mesh(cubeGeometry, cubeMaterial); + const edges = new THREE.LineSegments(edgesGeometry, lineMaterial); + cube.add(edges); + + const hoverGeometry = new THREE.BoxGeometry(CUBE_SIZE * 1.05, CUBE_SIZE * 1.05, CUBE_SIZE * 1.05); + const hoverEdgesGeometry = new THREE.EdgesGeometry(hoverGeometry); + const hoverOutline = new THREE.LineSegments(hoverEdgesGeometry, new THREE.LineBasicMaterial({ color: COLOR_HOVER })); + hoverOutline.visible = false; + hoverOutline.name = 'hoverOutline'; + cube.add(hoverOutline); + + // Add custom properties to store tensor coordinates + cube.userData.tensorX = z; + cube.userData.tensorY = y; + cube.userData.tensorZ = x; + cube.userData.tensorName = tensorName; + + cube.name = `${tensorName}_cube_${x}_${y}_${z}`; + + return cube; +} + +export function createTensor(shape, coords, color, tensorName, cubeGeometry, edgesGeometry, lineMaterial) { + console.log(`Creating ${tensorName} tensor:`, shape, coords); + const tensor = new THREE.Group(); + let [width, height, depth] = shape; + depth = depth || 1; + height = height || 1; + + if (tensorName === 'Global') { + console.log(`Creating global tensor with dimensions: ${width}x${height}x${depth}`); + for (let z = 0; z < depth; z++) { + for (let y = 0; y < height; y++) { + for (let x = 0; x < width; x++) { + const cube = createCube(color, tensorName, x, y, z, cubeGeometry, edgesGeometry, lineMaterial); + cube.position.set( + x * (CUBE_SIZE + GAP), + -y * (CUBE_SIZE + GAP), + -z * (CUBE_SIZE + GAP) + ); + tensor.add(cube); + } + } + } + + console.log(`Highlighting ${coords.length} coordinates in global tensor`); + coords.forEach(([x, y, z]) => { + const cube = tensor.children.find(c => + c.userData.tensorX === x && c.userData.tensorY === y && c.userData.tensorZ === z + ); + if (cube) { + cube.material.color.set(COLOR_SLICE); + console.log(`Highlighted cube at (${x}, ${y}, ${z})`); + } else { + console.warn(`Could not find cube at (${x}, ${y}, ${z})`); + } + }); + } else { + console.log(`Creating slice tensor with ${coords.length} coordinates`); + coords.forEach(([x, y, z]) => { + const cube = createCube(color, tensorName, x, y, z, cubeGeometry, edgesGeometry, lineMaterial); + cube.position.set( + x * (CUBE_SIZE + GAP), + -y * (CUBE_SIZE + GAP), + -z * (CUBE_SIZE + GAP) + ); + tensor.add(cube); + }); + } + + console.log(`Created ${tensorName} tensor with ${tensor.children.length} cubes`); + return tensor; +} + +export function calculateTensorSize(shape) { + const [width, height, depth] = shape; + return new THREE.Vector3( + width * (CUBE_SIZE + GAP), + height * (CUBE_SIZE + GAP), + depth * (CUBE_SIZE + GAP) + ); +} + +export function interpolateColor(color1, color2, factor) { + return new THREE.Color().lerpColors(color1, color2, factor); +} + +export function updateCubeColor(tensor, coord, startColor, endColor, factor) { + const cube = tensor.children.find(c => + c.tensorX === coord[0] && c.tensorY === coord[1] && c.tensorZ === coord[2] + ); + if (cube) { + cube.material.color.copy(interpolateColor(startColor, endColor, factor)); + } +} + +export function setupCamera(scene, camera) { + const box = new THREE.Box3().setFromObject(scene); + const center = box.getCenter(new THREE.Vector3()); + const size = box.getSize(new THREE.Vector3()); + const maxDim = Math.max(size.x, size.y, size.z); + const fov = camera.fov * (Math.PI / 180); + let cameraZ = Math.abs(maxDim / 2 / Math.tan(fov / 2)); + cameraZ *= 1.5; + + camera.position.set(center.x, center.y, center.z + cameraZ); + camera.lookAt(center); + + return { center, cameraZ }; +} + +export function setupEventListeners(containerElement, camera, renderer, onMouseMove, onKeyDown) { + window.addEventListener('resize', () => { + camera.aspect = containerElement.clientWidth / containerElement.clientHeight; + camera.updateProjectionMatrix(); + renderer.setSize(containerElement.clientWidth, containerElement.clientHeight); + }); + containerElement.addEventListener('mousemove', onMouseMove); + window.addEventListener('keydown', onKeyDown); +} + +export function cameraControls(camera, cameraRotation) { + const PAN_SPEED = 0.1; + const TILT_SPEED = 0.02; + const ZOOM_SPEED = 0.5; + + return function onKeyDown(event) { + switch (event.key.toLowerCase()) { + case 'w': + camera.position.y += PAN_SPEED; + break; + case 's': + camera.position.y -= PAN_SPEED; + break; + case 'a': + camera.position.x -= PAN_SPEED; + break; + case 'd': + camera.position.x += PAN_SPEED; + break; + case 'arrowup': + cameraRotation.x -= TILT_SPEED; + break; + case 'arrowdown': + cameraRotation.x += TILT_SPEED; + break; + case 'arrowleft': + cameraRotation.y -= TILT_SPEED; + break; + case 'arrowright': + cameraRotation.y += TILT_SPEED; + break; + case 'o': + camera.position.z += ZOOM_SPEED; + break; + case 'p': + camera.position.z -= ZOOM_SPEED; + break; + case ' ': + break; + } + camera.setRotationFromEuler(cameraRotation); + camera.updateProjectionMatrix(); + }; +} diff --git a/triton_viz/static/matmul.js b/triton_viz/static/matmul.js new file mode 100644 index 0000000..834df9d --- /dev/null +++ b/triton_viz/static/matmul.js @@ -0,0 +1,360 @@ +import * as THREE from 'https://cdn.jsdelivr.net/npm/three@0.155.0/build/three.module.js'; + +export function createMatMulVisualization(containerElement, op) { + const { input_shape, other_shape, output_shape } = op; + console.log(op.uuid) + fetch('/api/setop', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ uuid: op.uuid }), + }) + .then(response => response.json()) + .then(data => console.log('Set current op:', data)) + .catch((error) => console.error('Error:', error)); + let currentStep = 0; + const totalSteps = input_shape[1]; + let frame = 0; + + + const sideMenu = document.createElement('div'); + sideMenu.style.position = 'absolute'; + sideMenu.style.top = '10px'; + sideMenu.style.right = '10px'; + sideMenu.style.width = '200px'; + sideMenu.style.padding = '10px'; + sideMenu.style.backgroundColor = 'rgba(0, 0, 0, 0.7)'; + sideMenu.style.color = 'white'; + sideMenu.style.fontFamily = 'Arial, sans-serif'; + sideMenu.style.fontSize = '14px'; + sideMenu.style.borderRadius = '5px'; + containerElement.appendChild(sideMenu); + let hoveredCube = null; + + + + + async function getElementValue( matrixName, row, col) { + let uuid = op.uuid; + const response = await fetch('/api/getValue', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ uuid, matrixName, row, col, currentStep }), + }); + return await response.json(); + } + + + function updateSideMenu(matrix, x, y) { + if (!matrix) { + sideMenu.innerHTML = ''; + return; + } + + let matrixName; + let dims; + if (matrix === matrixA) { + matrixName = 'A'; + dims = input_shape; + } else if (matrix === matrixB) { + matrixName = 'B'; + dims = other_shape; + } else if (matrix === matrixC) { + matrixName = 'C'; + dims = output_shape; + } else { + sideMenu.innerHTML = ''; + return; + } + console.log(matrixName, "x:", (x + 1), "y:", (y + 1)); + sideMenu.innerHTML = ` +Row: ${y + 1}
+Column: ${x + 1}
+Dimensions: ${dims[0]} x ${dims[1]}
+ `; + } + + const raycaster = new THREE.Raycaster(); + const mouse = new THREE.Vector2(); + + + async function onMouseMove(event) { + mouse.x = (event.clientX / containerElement.clientWidth) * 2 - 1; + mouse.y = -(event.clientY / containerElement.clientHeight) * 2 + 1; + + raycaster.setFromCamera(mouse, camera); + + const allMatrixChildren = [ + ...(matrixA ? matrixA.children : []), + ...(matrixB ? matrixB.children : []), + ...(matrixC ? matrixC.children : []) + ]; + + const intersects = raycaster.intersectObjects(allMatrixChildren, true); + + if (hoveredCube) { + hoveredCube.getObjectByName('hoverOutline').visible = false; + hoveredCube = null; + } + + if (intersects.length > 0) { + // Find the actual cube (parent of the intersected object) + hoveredCube = intersects[0].object; + while (hoveredCube && !hoveredCube.matrixName) { + hoveredCube = hoveredCube.parent; + } + + if (hoveredCube) { + const hoverOutline = hoveredCube.getObjectByName('hoverOutline'); + if (hoverOutline) { + hoverOutline.visible = true; + } + const res = await getElementValue(hoveredCube.matrixName, hoveredCube.matrixRow, hoveredCube.matrixCol); + // Log the matrix name, row, and column of the hovered cube + console.log( + // `Matrix: ${hoveredCube.matrixName}, ` + + // `Row: ${hoveredCube.matrixRow + 1}, ` + + // `Column: ${hoveredCube.matrixCol + 1}`+ + `Value: ${res.value}` + ); + + updateSideMenu(hoveredCube.matrixName, hoveredCube.matrixRow, hoveredCube.matrixCol); + } + } else { + updateSideMenu(null); + } + } + + + + + + const CUBE_SIZE = 0.2; + const GAP = 0.05; + + const COLOR_A = new THREE.Color(0.53, 0.81, 0.98); + const COLOR_B = new THREE.Color(1.0, 0.65, 0.0); + const COLOR_C = new THREE.Color(1.0, 1.0, 1.0); + const COLOR_HIGHLIGHT = new THREE.Color(0.0, 0.0, 1.0); + const COLOR_FILLED = new THREE.Color(0.0, 0.0, 1.0); + const COLOR_BACKGROUND = new THREE.Color(0.0, 0.0, 0.0); + const COLOR_EDGE = new THREE.Color(0.3, 0.3, 0.3); + const COLOR_HOVER = new THREE.Color(1.0, 1.0, 0.0); + + const scene = new THREE.Scene(); + scene.background = COLOR_BACKGROUND; + const camera = new THREE.PerspectiveCamera(45, containerElement.clientWidth / containerElement.clientHeight, 0.1, 1000); + + const renderer = new THREE.WebGLRenderer({ antialias: true }); + renderer.setSize(containerElement.clientWidth, containerElement.clientHeight); + containerElement.appendChild(renderer.domElement); + + const ambientLight = new THREE.AmbientLight(0xffffff, 0.5); + scene.add(ambientLight); + const directionalLight = new THREE.DirectionalLight(0xffffff, 0.5); + directionalLight.position.set(10, 10, 10); + scene.add(directionalLight); + + const cubeGeometry = new THREE.BoxGeometry(CUBE_SIZE, CUBE_SIZE, CUBE_SIZE); + const edgesGeometry = new THREE.EdgesGeometry(cubeGeometry); + const lineMaterial = new THREE.LineBasicMaterial({ color: COLOR_EDGE }); + + function createCube(color, matrixName, i, j) { + const cubeMaterial = new THREE.MeshPhongMaterial({ color: color }); + const cube = new THREE.Mesh(cubeGeometry, cubeMaterial); + const edges = new THREE.LineSegments(edgesGeometry, lineMaterial); + cube.add(edges); + + const hoverGeometry = new THREE.BoxGeometry(CUBE_SIZE * 1.05, CUBE_SIZE * 1.05, CUBE_SIZE * 1.05); + const hoverEdgesGeometry = new THREE.EdgesGeometry(hoverGeometry); + const hoverOutline = new THREE.LineSegments(hoverEdgesGeometry, new THREE.LineBasicMaterial({ color: COLOR_HOVER })); + hoverOutline.visible = false; + hoverOutline.name = 'hoverOutline'; + cube.add(hoverOutline); + + cube.name = `${matrixName}_cube_${i}_${j}`; + cube.matrixName = matrixName; + cube.matrixRow = i; + cube.matrixCol = j; + + return cube; + } + + function createMatrix(dimensions, position, color, matrixName) { + const matrix = new THREE.Group(); + matrix.userData.dimensions = dimensions; + for (let i = 0; i < dimensions[0]; i++) { + for (let j = 0; j < dimensions[1]; j++) { + const cube = createCube(color, matrixName, i, j); + cube.position.set( + position.x + j * (CUBE_SIZE + GAP), + position.y - i * (CUBE_SIZE + GAP), + position.z + ); + matrix.add(cube); + } + } + return matrix; + } + + const matrixA = createMatrix(input_shape, new THREE.Vector3(-10, 10, 0), COLOR_A, 'A'); + const matrixB = createMatrix(other_shape, new THREE.Vector3(0, 10, 0), COLOR_B, 'B'); + const matrixC = createMatrix(output_shape, new THREE.Vector3(-5, -4, 0), COLOR_C, 'C'); + + scene.add(matrixA); + scene.add(matrixB); + scene.add(matrixC); + + const center = new THREE.Vector3(); + const size = new THREE.Vector3(); + const box = new THREE.Box3().setFromObject(scene); + box.getCenter(center); + box.getSize(size); + const maxDim = Math.max(size.x, size.y, size.z); + const fov = camera.fov * (Math.PI / 180); + let cameraZ = Math.abs(maxDim / 2 / Math.tan(fov / 2)); + cameraZ *= 1.5; + + camera.position.set(center.x, center.y, center.z + cameraZ); + camera.lookAt(center); + + let isPaused = false; + + const totalFrames = input_shape[0] * other_shape[1]; + + function highlightCubes(matrix, indices, highlightColor) { + indices.forEach(([i, j]) => { + if (i >= 0 && i < matrix.userData.dimensions[0] && j >= 0 && j < matrix.userData.dimensions[1]) { + const index = i * matrix.userData.dimensions[1] + j; + if (index < matrix.children.length) { + matrix.children[index].material.color.copy(highlightColor); + + } + } + }); + } + + function resetColors() { + matrixA.children.forEach(cube => cube.material.color.copy(COLOR_A)); + matrixB.children.forEach(cube => cube.material.color.copy(COLOR_B)); + } + + function animate() { + requestAnimationFrame(animate); + + if (!isPaused && frame < totalFrames) { + resetColors(); + + const row = Math.floor(frame / other_shape[1]); + const col = frame % other_shape[1]; + currentStep = frame % totalSteps + 1; + + const highlightA = Array.from({ length: input_shape[1] }, (_, i) => [row, i]); + const highlightB = Array.from({ length: other_shape[0] }, (_, i) => [i, col]); + const highlightC = [[row, col]]; + + highlightCubes(matrixA, highlightA, COLOR_HIGHLIGHT); + highlightCubes(matrixB, highlightB, COLOR_HIGHLIGHT); + highlightCubes(matrixC, highlightC, COLOR_FILLED); + + frame++; + } + + renderer.render(scene, camera); + } + + function onResize() { + camera.aspect = containerElement.clientWidth / containerElement.clientHeight; + camera.updateProjectionMatrix(); + renderer.setSize(containerElement.clientWidth, containerElement.clientHeight); + } + + const controlPanel = document.createElement('div'); + controlPanel.style.position = 'absolute'; + controlPanel.style.bottom = '10px'; + controlPanel.style.left = '10px'; + controlPanel.style.display = 'flex'; + controlPanel.style.gap = '10px'; + + const playPauseButton = document.createElement('button'); + playPauseButton.textContent = 'Play/Pause'; + playPauseButton.addEventListener('click', () => { + isPaused = !isPaused; + }); + + const resetButton = document.createElement('button'); + resetButton.textContent = 'Reset'; + resetButton.addEventListener('click', () => { + frame = 0; + resetColors(); + }); + + controlPanel.appendChild(playPauseButton); + controlPanel.appendChild(resetButton); + + containerElement.appendChild(controlPanel); + + const PAN_SPEED = 0.1; + const TILT_SPEED = 0.02; + const ZOOM_SPEED = 0.5; + let cameraRotation = new THREE.Euler(0, 0, 0, 'YXZ'); + + function onKeyDown(event) { + switch (event.key.toLowerCase()) { + case 'w': + camera.position.y += PAN_SPEED; + break; + case 's': + camera.position.y -= PAN_SPEED; + break; + case 'a': + camera.position.x -= PAN_SPEED; + break; + case 'd': + camera.position.x += PAN_SPEED; + break; + case 'arrowup': + cameraRotation.x -= TILT_SPEED; + break; + case 'arrowdown': + cameraRotation.x += TILT_SPEED; + break; + case 'arrowleft': + cameraRotation.y -= TILT_SPEED; + break; + case 'arrowright': + cameraRotation.y += TILT_SPEED; + break; + case 'o': + camera.position.z += ZOOM_SPEED; + break; + case 'p': + camera.position.z -= ZOOM_SPEED; + break; + } + camera.setRotationFromEuler(cameraRotation); + camera.updateProjectionMatrix(); + } + + window.addEventListener('resize', onResize); + window.addEventListener('keydown', onKeyDown); + containerElement.addEventListener('mousemove', onMouseMove); + + animate(); + + + function cleanup() { + window.removeEventListener('resize', onResize); + window.removeEventListener('keydown', onKeyDown); + containerElement.removeEventListener('mousemove', onMouseMove); + containerElement.innerHTML = ''; + renderer.dispose(); + scene.clear(); + } + + return cleanup; +} diff --git a/triton_viz/static/store-utils.js b/triton_viz/static/store-utils.js new file mode 100644 index 0000000..c139082 --- /dev/null +++ b/triton_viz/static/store-utils.js @@ -0,0 +1,200 @@ +import * as THREE from 'https://cdn.jsdelivr.net/npm/three@0.155.0/build/three.module.js'; + +export const CUBE_SIZE = 0.2; +export const GAP = 0.05; +export const COLOR_HOVER = new THREE.Color(1.0, 1.0, 0.0); // Yellow for hover effect +export const COLOR_EDGE = new THREE.Color(0.5, 0.5, 0.5); // Gray (for cube edges) + +const COLOR_SLICE = new THREE.Color(0.0, 0.7, 1.0); // Cyan (starting color for global slice) + +export function setupScene(container, backgroundColor = 0x000000) { + const scene = new THREE.Scene(); + scene.background = new THREE.Color(backgroundColor); + const camera = new THREE.PerspectiveCamera(45, container.clientWidth / container.clientHeight, 0.1, 1000); + const renderer = new THREE.WebGLRenderer({ antialias: true }); + renderer.setSize(container.clientWidth, container.clientHeight); + container.appendChild(renderer.domElement); + + const ambientLight = new THREE.AmbientLight(0xffffff, 0.5); + scene.add(ambientLight); + const directionalLight = new THREE.DirectionalLight(0xffffff, 0.5); + directionalLight.position.set(10, 10, 10); + scene.add(directionalLight); + + return { scene, camera, renderer }; +} + +export function setupGeometries() { + const cubeGeometry = new THREE.BoxGeometry(CUBE_SIZE, CUBE_SIZE, CUBE_SIZE); + const edgesGeometry = new THREE.EdgesGeometry(cubeGeometry); + const lineMaterial = new THREE.LineBasicMaterial({ color: COLOR_EDGE }); + return { cubeGeometry, edgesGeometry, lineMaterial }; +} + +export function createCube(color, tensorName, x, y, z, cubeGeometry, edgesGeometry, lineMaterial) { + const cubeMaterial = new THREE.MeshPhongMaterial({ color: color }); + const cube = new THREE.Mesh(cubeGeometry, cubeMaterial); + const edges = new THREE.LineSegments(edgesGeometry, lineMaterial); + cube.add(edges); + + const hoverGeometry = new THREE.BoxGeometry(CUBE_SIZE * 1.05, CUBE_SIZE * 1.05, CUBE_SIZE * 1.05); + const hoverEdgesGeometry = new THREE.EdgesGeometry(hoverGeometry); + const hoverOutline = new THREE.LineSegments(hoverEdgesGeometry, new THREE.LineBasicMaterial({ color: COLOR_HOVER })); + hoverOutline.visible = false; + hoverOutline.name = 'hoverOutline'; + cube.add(hoverOutline); + + // Add custom properties to store tensor coordinates + cube.userData.tensorX = z; + cube.userData.tensorY = y; + cube.userData.tensorZ = x; + cube.userData.tensorName = tensorName; + + cube.name = `${tensorName}_cube_${x}_${y}_${z}`; + + return cube; +} + +export function createTensor(shape, coords, color, tensorName, cubeGeometry, edgesGeometry, lineMaterial) { + console.log(`Creating ${tensorName} tensor:`, shape, coords); + const tensor = new THREE.Group(); + let [width, height, depth] = shape; + depth = depth || 1; + height = height || 1; + + if (tensorName === 'Global') { + console.log(`Creating global tensor with dimensions: ${width}x${height}x${depth}`); + for (let z = 0; z < depth; z++) { + for (let y = 0; y < height; y++) { + for (let x = 0; x < width; x++) { + const cube = createCube(color, tensorName, x, y, z, cubeGeometry, edgesGeometry, lineMaterial); + cube.position.set( + x * (CUBE_SIZE + GAP), + -y * (CUBE_SIZE + GAP), + -z * (CUBE_SIZE + GAP) + ); + tensor.add(cube); + } + } + } + + console.log(`Highlighting ${coords.length} coordinates in global tensor`); + coords.forEach(([x, y, z]) => { + const cube = tensor.children.find(c => + c.userData.tensorX === x && c.userData.tensorY === y && c.userData.tensorZ === z + ); + if (cube) { + cube.material.color.set(COLOR_SLICE); + console.log(`Highlighted cube at (${x}, ${y}, ${z})`); + } else { + console.warn(`Could not find cube at (${x}, ${y}, ${z})`); + } + }); + } else { + console.log(`Creating slice tensor with ${coords.length} coordinates`); + coords.forEach(([x, y, z]) => { + const cube = createCube(color, tensorName, x, y, z, cubeGeometry, edgesGeometry, lineMaterial); + cube.position.set( + x * (CUBE_SIZE + GAP), + -y * (CUBE_SIZE + GAP), + -z * (CUBE_SIZE + GAP) + ); + tensor.add(cube); + }); + } + + console.log(`Created ${tensorName} tensor with ${tensor.children.length} cubes`); + return tensor; +} + +export function calculateTensorSize(shape) { + const [width, height, depth] = shape; + return new THREE.Vector3( + width * (CUBE_SIZE + GAP), + height * (CUBE_SIZE + GAP), + depth * (CUBE_SIZE + GAP) + ); +} + +export function interpolateColor(color1, color2, factor) { + return new THREE.Color().lerpColors(color1, color2, factor); +} + +export function updateCubeColor(tensor, coord, startColor, endColor, factor) { + const cube = tensor.children.find(c => + c.tensorX === coord[0] && c.tensorY === coord[1] && c.tensorZ === coord[2] + ); + if (cube) { + cube.material.color.copy(interpolateColor(startColor, endColor, factor)); + } +} + +export function setupCamera(scene, camera) { + const box = new THREE.Box3().setFromObject(scene); + const center = box.getCenter(new THREE.Vector3()); + const size = box.getSize(new THREE.Vector3()); + const maxDim = Math.max(size.x, size.y, size.z); + const fov = camera.fov * (Math.PI / 180); + let cameraZ = Math.abs(maxDim / 2 / Math.tan(fov / 2)); + cameraZ *= 1.5; + + camera.position.set(center.x, center.y, center.z + cameraZ); + camera.lookAt(center); + + return { center, cameraZ }; +} + +export function setupEventListeners(containerElement, camera, renderer, onMouseMove, onKeyDown) { + window.addEventListener('resize', () => { + camera.aspect = containerElement.clientWidth / containerElement.clientHeight; + camera.updateProjectionMatrix(); + renderer.setSize(containerElement.clientWidth, containerElement.clientHeight); + }); + containerElement.addEventListener('mousemove', onMouseMove); + window.addEventListener('keydown', onKeyDown); +} + +export function cameraControls(camera, cameraRotation) { + const PAN_SPEED = 0.1; + const TILT_SPEED = 0.02; + const ZOOM_SPEED = 0.5; + + return function onKeyDown(event) { + switch (event.key.toLowerCase()) { + case 'w': + camera.position.y += PAN_SPEED; + break; + case 's': + camera.position.y -= PAN_SPEED; + break; + case 'a': + camera.position.x -= PAN_SPEED; + break; + case 'd': + camera.position.x += PAN_SPEED; + break; + case 'arrowup': + cameraRotation.x -= TILT_SPEED; + break; + case 'arrowdown': + cameraRotation.x += TILT_SPEED; + break; + case 'arrowleft': + cameraRotation.y -= TILT_SPEED; + break; + case 'arrowright': + cameraRotation.y += TILT_SPEED; + break; + case 'o': + camera.position.z += ZOOM_SPEED; + break; + case 'p': + camera.position.z -= ZOOM_SPEED; + break; + case ' ': + break; + } + camera.setRotationFromEuler(cameraRotation); + camera.updateProjectionMatrix(); + }; +} diff --git a/triton_viz/static/store.js b/triton_viz/static/store.js new file mode 100644 index 0000000..e35df29 --- /dev/null +++ b/triton_viz/static/store.js @@ -0,0 +1,221 @@ +import * as THREE from 'https://cdn.jsdelivr.net/npm/three@0.155.0/build/three.module.js'; +import { + setupScene, + setupGeometries, + createTensor, + calculateTensorSize, + updateCubeColor, + setupCamera, + setupEventListeners, + cameraControls +} from './store-utils.js'; + +export function createStoreVisualization(containerElement, op) { + + console.log(op.uuid); + fetch('/api/setop', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ uuid: op.uuid }), + }) + .then(response => response.json()) + .then(data => console.log('Set current op:', data)) + .catch((error) => console.error('Error:', error)); + + let currentStep = 0; + let frame = 0; + let isPaused = false; + + const sideMenu = createSideMenu(containerElement); + let hoveredCube = null; + + const COLOR_GLOBAL = new THREE.Color(0.2, 0.2, 0.2); // Dark Gray + const COLOR_SLICE = new THREE.Color(0.0, 0.7, 1.0); // Cyan (starting color for global slice) + const COLOR_LEFT_SLICE = new THREE.Color(1.0, 0.0, 1.0); // Magenta (starting color for left slice) + const COLOR_LOADED = new THREE.Color(1.0, 0.8, 0.0); // Gold (final color for both slices) + const COLOR_BACKGROUND = new THREE.Color(0.0, 0.0, 0.0); // Black + + const { scene, camera, renderer } = setupScene(containerElement, COLOR_BACKGROUND); + const { cubeGeometry, edgesGeometry, lineMaterial } = setupGeometries(); + + const globalTensor = createTensor(op.global_shape, op.global_coords, COLOR_GLOBAL, 'Global', cubeGeometry, edgesGeometry, lineMaterial); + const sliceTensor = createTensor(op.slice_shape, op.slice_coords, COLOR_LEFT_SLICE, 'Slice', cubeGeometry, edgesGeometry, lineMaterial); + + // Position slice tensor + const globalSize = calculateTensorSize(op.global_shape); + sliceTensor.position.set(globalSize.x + 5, 0, 0); // Adjusted tensor spacing + + scene.add(globalTensor); + scene.add(sliceTensor); + + addLabels(scene, globalTensor, sliceTensor); + setupCamera(scene, camera); + + const totalFrames = op.global_coords.length * 2 + 30; + + const raycaster = new THREE.Raycaster(); + const mouse = new THREE.Vector2(); + + const onKeyDown = cameraControls(camera, new THREE.Euler(0, 0, 0, 'YXZ')); + setupEventListeners(containerElement, camera, renderer, onMouseMove, onKeyDown); + animate(); + + async function onMouseMove(event) { + mouse.x = (event.clientX / containerElement.clientWidth) * 2 - 1; + mouse.y = -(event.clientY / containerElement.clientHeight) * 2 + 1; + + raycaster.setFromCamera(mouse, camera); + + const allTensorChildren = [ + ...globalTensor.children, + ...sliceTensor.children + ]; + + const intersects = raycaster.intersectObjects(allTensorChildren, true); + + if (hoveredCube) { + hoveredCube.getObjectByName('hoverOutline').visible = false; + hoveredCube = null; + } + + if (intersects.length > 0) { + hoveredCube = intersects[0].object; + while (hoveredCube && !hoveredCube.tensorName) { + hoveredCube = hoveredCube.parent; + } + + if (hoveredCube) { + const hoverOutline = hoveredCube.getObjectByName('hoverOutline'); + if (hoverOutline) { + hoverOutline.visible = true; + } + + updateSideMenu(hoveredCube.tensorName, hoveredCube.tensorX, hoveredCube.tensorY, hoveredCube.tensorZ, undefined); + + const res = await getElementValue(hoveredCube.tensorName, hoveredCube.tensorX, hoveredCube.tensorY, hoveredCube.tensorZ); + + updateSideMenu(hoveredCube.tensorName, hoveredCube.tensorX, hoveredCube.tensorY, hoveredCube.tensorZ, res.value); + + console.log(`Value: ${res.value}`); + } + } else { + updateSideMenu(null); + } + } + + function animate() { + requestAnimationFrame(animate); + + if (!isPaused && frame < totalFrames) { + const index = Math.floor(frame / 2); + const factor = (frame % 2) / 1.0; + + if (index < op.global_coords.length) { + const globalCoord = op.global_coords[index]; + const sliceCoord = op.slice_coords[index]; + + updateCubeColor(globalTensor, globalCoord, COLOR_GLOBAL, COLOR_SLICE, factor); + updateCubeColor(sliceTensor, sliceCoord, COLOR_LEFT_SLICE, COLOR_LOADED, factor); + + highlightCurrentOperation(globalTensor, globalCoord, sliceTensor, sliceCoord); + updateInfoPanel(globalCoord, sliceCoord, index); + } + + frame++; + } + + renderer.render(scene, camera); + } + + function highlightCurrentOperation(globalTensor, globalCoord, sliceTensor, sliceCoord) { + globalTensor.children.forEach(cube => cube.material.emissive.setHex(0x000000)); + sliceTensor.children.forEach(cube => cube.material.emissive.setHex(0x000000)); + + const globalCube = globalTensor.children.find(c => + c.tensorX === globalCoord[0] && c.tensorY === globalCoord[1] && c.tensorZ === globalCoord[2] + ); + const sliceCube = sliceTensor.children.find(c => + c.tensorX === sliceCoord[0] && c.tensorY === sliceCoord[1] && c.tensorZ === sliceCoord[2] + ); + + if (globalCube) globalCube.material.emissive.setHex(0x444444); + if (sliceCube) sliceCube.material.emissive.setHex(0x444444); + } + + async function getElementValue(tensorName, x, y, z) { + let uuid = op.uuid; + const response = await fetch('/api/getStoreValue', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ uuid, tensorName, x, y, z }), + }); + return await response.json(); + } + + function updateSideMenu(tensorName, x, y, z, value) { + if (!tensorName) { + sideMenu.innerHTML = ''; + return; + } + + let dims = tensorName === 'Global' ? op.global_shape : op.slice_shape; + sideMenu.innerHTML = ` +X: ${x + 1}
+Y: ${y + 1}
+Z: ${z + 1}
+Dimensions: ${dims.join(' x ')}
+Value: ${value !== undefined ? value : 'Storeing...'}
+ `; + } + + function updateInfoPanel(globalCoord, sliceCoord, index) { + sideMenu.innerHTML = ` +Global Coords: (${globalCoord.join(', ')})
+Slice Coords: (${sliceCoord.join(', ')})
+Progress: ${index + 1}/${op.global_coords.length}
+ `; + } + + function createSideMenu(container) { + const menu = document.createElement('div'); + menu.style.position = 'absolute'; + menu.style.top = '10px'; + menu.style.right = '10px'; + menu.style.width = '200px'; + menu.style.padding = '10px'; + menu.style.backgroundColor = 'rgba(0, 0, 0, 0.7)'; + menu.style.color = 'white'; + menu.style.fontFamily = 'Arial, sans-serif'; + menu.style.fontSize = '14px'; + menu.style.borderRadius = '5px'; + container.appendChild(menu); + return menu; + } + + function addLabels(scene, globalTensor, sliceTensor) { + addLabel(scene, "Global Tensor", globalTensor.position); + addLabel(scene, "Slice Tensor", sliceTensor.position); + } + + function addLabel(scene, text, position) { + const canvas = document.createElement('canvas'); + const context = canvas.getContext('2d'); + context.font = 'Bold 24px Arial'; + context.fillStyle = 'white'; + context.fillText(text, 0, 24); + + const texture = new THREE.CanvasTexture(canvas); + const material = new THREE.SpriteMaterial({ map: texture }); + const sprite = new THREE.Sprite(material); + sprite.position.set(position.x, position.y + 2, position.z); + sprite.scale.set(4, 2, 1); + scene.add(sprite); + } + +} diff --git a/triton_viz/static/three.js b/triton_viz/static/three.js new file mode 100644 index 0000000..e69de29 diff --git a/triton_viz/static/visualization.js b/triton_viz/static/visualization.js new file mode 100644 index 0000000..735558e --- /dev/null +++ b/triton_viz/static/visualization.js @@ -0,0 +1,358 @@ +import * as THREE from 'https://cdn.jsdelivr.net/npm/three@0.155.0/build/three.module.js'; +import { GridBlock } from './gridblock.js'; +import { createLoadVisualization } from './load.js'; +import { createStoreVisualization } from './store.js'; +import { createMatMulVisualization } from './matmul.js'; + +let globalData; +let currentView = 'main'; +let canvas, ctx; +let maxX = 0, maxY = 0, maxZ = 0; +let sliders = [], zSlider, precomputeButton, kernelGrid; +let backButton; +let currentBlockData = null; +let isInitialized = false; +let containerElement; + +// Initialize the application +function switchToMainView() { + currentView = 'main'; + if (currentBlockData) { + currentBlockData.hideDetailedView(); + currentBlockData = null; + } + containerElement.style.pointerEvents = 'none'; + containerElement.style.display = 'none'; // Hide the container + containerElement.innerHTML = ''; // Clear the container + + // Show the main canvas + canvas.style.display = 'block'; + draw(); +} + + +function initializeApp() { + canvas = document.getElementById('canvas'); + if (!canvas) { + console.error('Canvas element not found'); + return; + } + ctx = canvas.getContext('2d'); + if (!ctx) { + console.error('Unable to get 2D context from canvas'); + return; + } + canvas.width = window.innerWidth; + canvas.height = window.innerHeight; + + // Add event listeners + canvas.addEventListener('mousedown', handleMouseEvent); + canvas.addEventListener('mouseup', handleMouseEvent); + canvas.addEventListener('mousemove', handleMouseEvent); + + containerElement = document.getElementById('visualization-container'); + if (!containerElement) { + console.error('Visualization container element not found'); + return; + } + + // Set up the visualization container + containerElement.style.pointerEvents = 'none'; + containerElement.style.display = 'none'; // Initially hide the container + + fetchData(); +} + +// Slider class +class Slider { + constructor(x, y, width, height, label, min_value = 0, max_value = 100) { + this.rect = { x, y, width, height }; + this.label = label; + this.min = min_value; + this.max = max_value; + this.value = min_value; + this.grabbed = false; + this.enabled = true; + } + + draw(ctx) { + if (!this.enabled) return; + ctx.fillStyle = '#3c3c46'; + ctx.fillRect(this.rect.x, this.rect.y, this.rect.width, this.rect.height); + const buttonX = this.rect.x + (this.value - this.min) / (this.max - this.min) * this.rect.width; + ctx.fillStyle = '#c8c8c8'; + ctx.fillRect(buttonX - 5, this.rect.y - 2, 10, this.rect.height + 4); + + ctx.fillStyle = '#c8c8c8'; + ctx.font = '18px Arial'; + ctx.fillText(this.label, this.rect.x, this.rect.y - 10); + ctx.fillText(this.value.toString(), this.rect.x + this.rect.width + 10, this.rect.y + this.rect.height / 2 + 5); + } + + handleEvent(event) { + if (!this.enabled) return; + if (event.type === 'mousedown') { + if (this.isPointInside(event.offsetX, event.offsetY)) { + this.grabbed = true; + } + } else if (event.type === 'mouseup') { + this.grabbed = false; + } else if (event.type === 'mousemove' && this.grabbed) { + const mouseX = event.offsetX; + this.value = Math.round((mouseX - this.rect.x) / this.rect.width * (this.max - this.min) + this.min); + this.value = Math.max(this.min, Math.min(this.max, this.value)); + } + } + + isPointInside(x, y) { + return x >= this.rect.x && x <= this.rect.x + this.rect.width && + y >= this.rect.y && y <= this.rect.y + this.rect.height; + } +} + +// Button class +class Button { + constructor(x, y, width, height, text) { + this.rect = { x, y, width, height }; + this.text = text; + this.color = '#3c3c46'; + this.hoverColor = '#50505a'; + this.clickColor = '#64646e'; + this.isHovered = false; + this.isClicked = false; + this.clickTime = 0; + } + + draw(ctx) { + let color = this.color; + if (this.isClicked && Date.now() - this.clickTime < 100) { + color = this.clickColor; + } else if (this.isHovered) { + color = this.hoverColor; + } + + ctx.fillStyle = color; + ctx.fillRect(this.rect.x, this.rect.y, this.rect.width, this.rect.height); + ctx.strokeStyle = '#c8c8c8'; + ctx.strokeRect(this.rect.x, this.rect.y, this.rect.width, this.rect.height); + + ctx.fillStyle = '#c8c8c8'; + ctx.font = '18px Arial'; + ctx.textAlign = 'center'; + ctx.textBaseline = 'middle'; + ctx.fillText(this.text, this.rect.x + this.rect.width / 2, this.rect.y + this.rect.height / 2); + ctx.textAlign = 'left'; + ctx.textBaseline = 'alphabetic'; + } + + handleEvent(event) { + const { offsetX, offsetY } = event; + this.isHovered = this.isPointInside(offsetX, offsetY); + if (event.type === 'mousedown' && this.isHovered) { + this.isClicked = true; + this.clickTime = Date.now(); + console.log(`Button '${this.text}' clicked!`); + } else if (event.type === 'mouseup') { + this.isClicked = false; + } + } + + isPointInside(x, y) { + return x >= this.rect.x && x <= this.rect.x + this.rect.width && + y >= this.rect.y && y <= this.rect.y + this.rect.height; + } +} + +// KernelGrid class +class KernelGrid { + constructor(x, y, width, height, gridSize, visualizationData) { + this.rect = { x, y, width, height }; + this.gridSize = gridSize; + this.visualizationData = visualizationData; + this.currentZ = 0; + this.blocks = []; + this.calculateBlockSize(); + this.createBlocks(); + this.selectedBlock = null; + } + + calculateBlockSize() { + this.blockWidth = Math.floor(this.rect.width / this.gridSize[0]) - 1; + this.blockHeight = Math.floor(this.rect.height / this.gridSize[1]) - 1; + } + + createBlocks() { + this.blocks = []; + for (let y = 0; y < this.gridSize[1]; y++) { + for (let x = 0; x < this.gridSize[0]; x++) { + const blockX = this.rect.x + x * (this.blockWidth + 1); + const blockY = this.rect.y + y * (this.blockHeight + 1); + const gridKey = `(${x}, ${y}, ${this.currentZ})`; + const blockData = this.visualizationData[gridKey] || []; + const block = new GridBlock(blockX, blockY, this.blockWidth, this.blockHeight, x, y, this.currentZ, blockData); + this.blocks.push(block); + } + } + } + + draw(ctx) { + ctx.fillStyle = '#F0F0F0'; + ctx.fillRect(this.rect.x, this.rect.y, this.rect.width, this.rect.height); + this.blocks.forEach(block => block.draw(ctx)); + } + + updateZ(z) { + this.currentZ = z; + this.blocks.forEach(block => { + block.gridPosition.z = z; + const gridKey = `(${block.gridPosition.x}, ${block.gridPosition.y}, ${z})`; + block.blockData = this.visualizationData[gridKey] || []; + }); + } + + handleClick(x, y, containerElement) { + const clickedBlock = this.blocks.find(block => block.isPointInside(x, y)); + if (clickedBlock) { + console.log(`Clicked block at (${clickedBlock.gridPosition.x}, ${clickedBlock.gridPosition.y}, ${clickedBlock.gridPosition.z})`); + if (this.selectedBlock) { + this.selectedBlock.hideDetailedView(); + } + this.selectedBlock = clickedBlock; + clickedBlock.showDetailedView(containerElement); + return clickedBlock; + } + return null; + } + + handleMouseMove(x, y) { + this.blocks.forEach(block => block.handleMouseMove(x, y)); + } +} + +// Function to determine the max values for sliders based on visualization data +function determineMaxValues(visualizationData) { + maxX = 0; + maxY = 0; + maxZ = 0; + const keys = Object.keys(visualizationData); + keys.forEach(key => { + const [x, y, z] = key.replace(/[()]/g, '').split(', ').map(Number); + if (x > maxX) maxX = x; + if (y > maxY) maxY = y; + if (z > maxZ) maxZ = z; + }); +} + +// Function to initialize UI elements +function initializeUIElements() { + sliders = [ + new Slider(1300, 50, 250, 20, "Program Id 0", 0, maxX), + new Slider(1300, 120, 250, 20, "Program Id 1", 0, maxY), + new Slider(1300, 190, 250, 20, "Program Id 2", 0, maxZ) + ]; + + zSlider = new Slider(50, 860, 1200, 20, "Z-axis", 0, maxZ); + zSlider.enabled = maxZ > 0; + + precomputeButton = new Button(1300, 260, 250, 40, "Precompute"); + kernelGrid = new KernelGrid(50, 50, 1200, 800, [maxX + 1, maxY + 1, maxZ + 1], globalData.ops.visualization_data); + backButton = new Button(50, 50, 100, 40, "Back"); + isInitialized = true; +} +function switchToTensorView(clickedBlock) { + currentView = 'tensor'; + currentBlockData = clickedBlock; + console.log("Switched to tensor view. Block data:", clickedBlock); + + containerElement.style.pointerEvents = 'auto'; + containerElement.style.display = 'block'; // Make sure the container is visible + clickedBlock.showDetailedView(containerElement); + + // Hide the main canvas + canvas.style.display = 'none'; +} + + + + +function handleMouseEvent(event) { + if (!isInitialized) { + console.warn('UI elements not initialized yet'); + return; + } + + const { offsetX, offsetY } = event; + if (currentView === 'main') { + sliders.forEach(slider => slider.handleEvent(event)); + if (zSlider && zSlider.enabled) { + zSlider.handleEvent(event); + if (kernelGrid) { + kernelGrid.updateZ(zSlider.value); + } + } + if (precomputeButton) { + precomputeButton.handleEvent(event); + } + if (kernelGrid) { + kernelGrid.handleMouseMove(offsetX, offsetY); + if (event.type === 'mousedown') { + const clickedBlock = kernelGrid.handleClick(offsetX, offsetY, containerElement); + if (clickedBlock) { + switchToTensorView(clickedBlock); + } + } + } + } else if (currentView === 'tensor') { + if (backButton) { + backButton.handleEvent(event); + if (event.type === 'mousedown' && backButton.isHovered) { + switchToMainView(); + } + } + } + draw(); +} + +function draw() { + if (!ctx) { + console.error('Canvas context not available'); + return; + } + + ctx.fillStyle = '#1e1e28'; + ctx.fillRect(0, 0, canvas.width, canvas.height); + + if (currentView === 'main') { + if (kernelGrid) kernelGrid.draw(ctx); + sliders.forEach(slider => slider.draw(ctx)); + if (zSlider && zSlider.enabled) { + zSlider.draw(ctx); + } + if (precomputeButton) precomputeButton.draw(ctx); + } + // Remove the else clause for 'tensor' view, as it's now handled by Three.js +} + + + + +// Fetch data from the API and initialize the visualization +async function fetchData() { + + const response = await fetch('/api/data'); + globalData = await response.json(); + + + // Determine max values for sliders based on fetched data + determineMaxValues(globalData.ops.visualization_data); + + // Initialize UI elements + initializeUIElements(); + + // Initial draw + draw(); + +} + +// Start the application +initializeApp(); \ No newline at end of file diff --git a/triton_viz/templates/index.html b/triton_viz/templates/index.html new file mode 100644 index 0000000..3a6db82 --- /dev/null +++ b/triton_viz/templates/index.html @@ -0,0 +1,47 @@ + + + + + +