From 1760d9fd00ec59b009ee4f514a09b1c28eff4a1e Mon Sep 17 00:00:00 2001 From: Sasha Date: Tue, 19 Mar 2024 10:39:04 -0400 Subject: [PATCH 1/2] add oob examples --- triton_viz/data.py | 4 ++++ triton_viz/draw.py | 21 +++++++++++++-------- triton_viz/interface.py | 9 ++++++++- triton_viz/interpreter.py | 22 ++++++++++++++++------ 4 files changed, 41 insertions(+), 15 deletions(-) diff --git a/triton_viz/data.py b/triton_viz/data.py index 76b9e8e..fb5dc72 100644 --- a/triton_viz/data.py +++ b/triton_viz/data.py @@ -33,6 +33,8 @@ class Store(Op): offsets: npt.NDArray[np.int_] masks: npt.NDArray[np.bool_] invalid_access_masks: npt.NDArray[np.bool_] + original_offsets: npt.NDArray[np.int_] + original_mask: npt.NDArray[np.bool_] @dataclass @@ -42,6 +44,8 @@ class Load(Op): offsets: npt.NDArray[np.int_] masks: npt.NDArray[np.bool_] invalid_access_masks: npt.NDArray[np.bool_] + original_offsets: npt.NDArray[np.int_] + original_mask: npt.NDArray[np.bool_] @dataclass diff --git a/triton_viz/draw.py b/triton_viz/draw.py index 1ccdb53..f010196 100644 --- a/triton_viz/draw.py +++ b/triton_viz/draw.py @@ -61,15 +61,15 @@ def reshape(d: Diagram) -> Diagram: def collect_grid(): for launch in record_builder.launches[-1:]: - records, tensor_table = collect_launch(launch) - return records, tensor_table + records, tensor_table, failures = collect_launch(launch) + return records, tensor_table, failures def collect_launch(launch): tensor_table = {} for i, t in enumerate(launch.tensors): tensor_table[t.ptr] = (t, i) - + failures = {} all_grids = {} last_grid = None program_records = [] @@ -80,8 +80,11 @@ def collect_launch(launch): program_records = [] last_grid = r program_records.append(r) + if isinstance(r, Store) or isinstance(r, Load): + if (r.invalid_access_masks & r.original_mask).any(): + failures[last_grid.idx] = True all_grids[last_grid.idx] = program_records - return all_grids, tensor_table + return all_grids, tensor_table, failures def draw_record(program_record, tensor_table, output): @@ -230,13 +233,15 @@ def store_load( x: Union[Store, Load], tensor_table: Dict[int, Tuple[Tensor, int]] ) -> Tuple[Diagram, Diagram]: tensor, tensor_id = tensor_table[x.ptr] - # inp = base_tensor(tensor.shape, DEFAULT) color = ACTIVE[tensor_id] - inp = cover(tensor.shape, tensor.dtype, x.offsets, x.masks, color) + invalid = x.invalid_access_masks.any() + if invalid: + color = Color("red") + inp = cover(tensor.shape, tensor.dtype, x.original_offsets, x.original_mask, color) inp = reshape(inp) - s = make_3d(x.offsets.shape) - a, b, c = x.masks.reshape(*s).nonzero() + s = make_3d(x.original_offsets.shape) + a, b, c = x.original_mask.reshape(*s).nonzero() out = draw_tensor_3d(s, a, b, c, color) return inp, out diff --git a/triton_viz/interface.py b/triton_viz/interface.py index b07785a..ce8364f 100644 --- a/triton_viz/interface.py +++ b/triton_viz/interface.py @@ -5,7 +5,7 @@ def launch(share=True): cache = {} - program_records, tt = triton_viz.collect_grid() + program_records, tt, failures = triton_viz.collect_grid() m = [0, 0, 0] size = [0, 0] for k in program_records.keys(): @@ -35,6 +35,12 @@ def launch(share=True): s2 = gr.Slider(0, m[1] - 1, value=0, step=1, label="Program Id 1") s3 = gr.Slider(0, m[2] - 1, value=0, step=1, label="Program Id 2") b1 = gr.Button("Precompute") + if failures: + gr.Label( + show_label=False, + value="Invalid memory access in " + + " ".join(str(list(failures.keys()))), + ) def cache_block(idx): name = tempfile.NamedTemporaryFile(suffix=".svg") @@ -81,3 +87,4 @@ def precompute(inp): demo.load(update, inputs={s1, s2, s3}, outputs=[img, b1]) demo.launch(share=share, debug=False, height=800, quiet=True, show_api=False) + return failures diff --git a/triton_viz/interpreter.py b/triton_viz/interpreter.py index 6059153..8872743 100644 --- a/triton_viz/interpreter.py +++ b/triton_viz/interpreter.py @@ -120,7 +120,6 @@ def _check_storage_contiguous(tensor): # Note that this is different from if a tensor is accessed contiguously, so we cannot use tensor.is_contiguous() # 1. Sort strides from smallest to largest # 2. If the tensor is contiguous, the stride product should be the same of the shape product of all previous dimensions - stride_prod = 1 shape_prod = 1 indices = sorted(range(len(tensor.stride())), key=tensor.stride().__getitem__) for i, index in enumerate(indices): @@ -128,8 +127,7 @@ def _check_storage_contiguous(tensor): shape = tensor.shape[index] if i == 0 and stride != 1: return False - stride_prod *= stride - if i != 0 and stride_prod != shape_prod: + if i != 0 and stride != shape_prod: return False shape_prod *= shape return True @@ -194,7 +192,13 @@ def check_out_of_bounds_access(ptrs): valid_access_mask = (offsets >= 0) & (offsets < max_valid_offset) invalid_access_mask = np.logical_not(valid_access_mask) corrected_offsets = np.where(valid_access_mask, offsets, 0) - return tensor_ptr, valid_access_mask, invalid_access_mask, corrected_offsets + return ( + tensor_ptr, + valid_access_mask, + invalid_access_mask, + corrected_offsets, + offsets, + ) def _create_masked_load(fn): @@ -205,13 +209,16 @@ def wrapper(ptrs, mask, other, cache_modifier, eviction_policy, is_volatile): valid_access_mask, invalid_access_mask, corrected_offsets, + original_offsets, ) = check_out_of_bounds_access(ptrs) load_record = Load( ptr=tensor_ptr.ptr, shape=ptrs.data.shape, offsets=corrected_offsets, - masks=valid_access_mask, + masks=valid_access_mask & mask.data, invalid_access_masks=invalid_access_mask, + original_offsets=original_offsets, + original_mask=mask.data, ) record_builder.add_record(load_record) @@ -235,13 +242,16 @@ def wrapper(ptrs, mask, other, cache_modifier, eviction_policy): valid_access_mask, invalid_access_mask, corrected_offsets, + original_offsets, ) = check_out_of_bounds_access(ptrs) store_record = Store( ptr=tensor_ptr.ptr, shape=ptrs.data.shape, offsets=corrected_offsets, - masks=valid_access_mask, + masks=valid_access_mask & (mask.data == 1), invalid_access_masks=invalid_access_mask, + original_offsets=original_offsets, + original_mask=(mask.data == 1), ) record_builder.add_record(store_record) From d4f107c4b549865a4d8435a71beafd141269c958 Mon Sep 17 00:00:00 2001 From: Sasha Date: Tue, 19 Mar 2024 10:52:15 -0400 Subject: [PATCH 2/2] . --- triton_viz/draw.py | 2 +- triton_viz/interface.py | 9 ++++++--- triton_viz/interpreter.py | 8 ++++---- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/triton_viz/draw.py b/triton_viz/draw.py index f010196..c0f114d 100644 --- a/triton_viz/draw.py +++ b/triton_viz/draw.py @@ -411,7 +411,7 @@ def make_cube(projection, start, end, color): for p, loc in outer2 ] line = [ - (p.stroke().line_width(0.001).line_color(GREY), l_) + (p.stroke().line_width(0.001).line_color(BLACK), l_) for loc in ls for p, l_ in loc ] diff --git a/triton_viz/interface.py b/triton_viz/interface.py index ce8364f..0552c0a 100644 --- a/triton_viz/interface.py +++ b/triton_viz/interface.py @@ -35,11 +35,14 @@ def launch(share=True): s2 = gr.Slider(0, m[1] - 1, value=0, step=1, label="Program Id 1") s3 = gr.Slider(0, m[2] - 1, value=0, step=1, label="Program Id 2") b1 = gr.Button("Precompute") + gr.Markdown(f"## Program Ids: {tuple(m)}") + if failures: - gr.Label( + gr.Markdown( show_label=False, - value="Invalid memory access in " - + " ".join(str(list(failures.keys()))), + value="## Invalid memory access in " + + "\n * " + + "\n* ".join(list(map(str, failures.keys()))), ) def cache_block(idx): diff --git a/triton_viz/interpreter.py b/triton_viz/interpreter.py index 8872743..2944b06 100644 --- a/triton_viz/interpreter.py +++ b/triton_viz/interpreter.py @@ -236,7 +236,7 @@ def wrapper(ptrs, mask, other, cache_modifier, eviction_policy, is_volatile): def _create_masked_store(fn): @wraps(fn) - def wrapper(ptrs, mask, other, cache_modifier, eviction_policy): + def wrapper(ptrs, value, mask, cache_modifier, eviction_policy): ( tensor_ptr, valid_access_mask, @@ -248,14 +248,14 @@ def wrapper(ptrs, mask, other, cache_modifier, eviction_policy): ptr=tensor_ptr.ptr, shape=ptrs.data.shape, offsets=corrected_offsets, - masks=valid_access_mask & (mask.data == 1), + masks=valid_access_mask & mask.data, invalid_access_masks=invalid_access_mask, original_offsets=original_offsets, - original_mask=(mask.data == 1), + original_mask=mask.data, ) record_builder.add_record(store_record) - return fn(ptrs, mask, other, cache_modifier, eviction_policy) + return fn(ptrs, value, mask, cache_modifier, eviction_policy) return wrapper