[DEV] Fix masks and update reduce interfaces with triton/head (#14)
Jokeren authored Mar 20, 2024
1 parent 11b34a8 commit 3d0e1b3
Showing 4 changed files with 47 additions and 41 deletions.
examples/vec_add.py (6 changes: 3 additions & 3 deletions)
@@ -79,7 +79,7 @@ def test_add():
         if isinstance(op, Load):
             result_offsets = op.offsets.tolist()
             result_offsets_len = len(result_offsets)
-            result_masks = op.masks
+            result_masks = op.access_masks
             result_invalid_masks = op.invalid_access_masks
             break
     assert torch.allclose(result, expected)
@@ -99,13 +99,13 @@ def test_out_of_bounds_add():
     expected_offsets = [(i * t_size) if i < size else 0 for i in range(BLOCK_SIZE)]
     expected_offsets_len = len(expected_offsets)
     expected = input_vector1 + input_vector2
-    expected_masks = [True if i < size else False for i in range(BLOCK_SIZE)]
+    expected_masks = [i < size for i in range(BLOCK_SIZE)]
     expected_invalid_masks = np.logical_not(expected_masks)
     for op in record_builder.launches[0].records:
         if isinstance(op, Load):
             result_offsets = op.offsets.tolist()
             result_offsets_len = len(result_offsets)
-            result_masks = op.masks
+            result_masks = op.access_masks
             result_invalid_masks = op.invalid_access_masks
             break
     assert torch.allclose(result, expected)
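Note on the test change above: the simplified comprehension yields the same boolean mask as the old ternary form, and its logical negation is what the test compares against invalid_access_masks. A quick standalone check (the BLOCK_SIZE and size values below are made up for illustration, not the test's actual values):

import numpy as np

BLOCK_SIZE, size = 8, 5
old_style = [True if i < size else False for i in range(BLOCK_SIZE)]
new_style = [i < size for i in range(BLOCK_SIZE)]
assert old_style == new_style                             # [True]*5 + [False]*3
assert np.logical_not(new_style).tolist() == [i >= size for i in range(BLOCK_SIZE)]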
triton_viz/data.py (8 changes: 4 additions & 4 deletions)
@@ -31,21 +31,21 @@ class Store(Op):
     ptr: int
     shape: Tuple
     offsets: npt.NDArray[np.int_]
-    masks: npt.NDArray[np.bool_]
+    access_masks: npt.NDArray[np.bool_]
     invalid_access_masks: npt.NDArray[np.bool_]
     original_offsets: npt.NDArray[np.int_]
-    original_mask: npt.NDArray[np.bool_]
+    original_masks: npt.NDArray[np.bool_]
 
 
 @dataclass
 class Load(Op):
     ptr: int
     shape: Tuple
     offsets: npt.NDArray[np.int_]
-    masks: npt.NDArray[np.bool_]
+    access_masks: npt.NDArray[np.bool_]
     invalid_access_masks: npt.NDArray[np.bool_]
     original_offsets: npt.NDArray[np.int_]
-    original_mask: npt.NDArray[np.bool_]
+    original_masks: npt.NDArray[np.bool_]
 
 
 @dataclass
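The renames make the record attributes consistently plural and separate the three masks more clearly. A hypothetical consumer of a recorded op (summarize_access below is illustrative, not part of triton_viz) would now read:

def summarize_access(op):
    # op is a triton_viz Load or Store record; attribute names follow the diff above.
    return {
        "requested_lanes": int(op.original_masks.sum()),       # the mask exactly as passed to tl.load/tl.store (was original_mask)
        "recorded_accesses": int(op.access_masks.sum()),       # requested accesses that landed inside the tensor (was masks)
        "flagged_invalid": int(op.invalid_access_masks.sum()),  # lanes the interpreter flagged as invalid
    }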
triton_viz/draw.py (12 changes: 7 additions & 5 deletions)
@@ -80,9 +80,11 @@ def collect_launch(launch):
             program_records = []
             last_grid = r
         program_records.append(r)
-        if isinstance(r, Store) or isinstance(r, Load):
-            if (r.invalid_access_masks & r.original_mask).any():
-                failures[last_grid.idx] = True
+        if (
+            isinstance(r, (Store, Load))
+            and (r.invalid_access_masks & r.original_masks).any()
+        ):
+            failures[last_grid.idx] = True
     all_grids[last_grid.idx] = program_records
     return all_grids, tensor_table, failures
 
@@ -238,10 +240,10 @@ def store_load(
     invalid = x.invalid_access_masks.any()
     if invalid:
         color = Color("red")
-    inp = cover(tensor.shape, tensor.dtype, x.original_offsets, x.original_mask, color)
+    inp = cover(tensor.shape, tensor.dtype, x.original_offsets, x.original_masks, color)
     inp = reshape(inp)
     s = make_3d(x.original_offsets.shape)
-    a, b, c = x.original_mask.reshape(*s).nonzero()
+    a, b, c = x.original_masks.reshape(*s).nonzero()
     out = draw_tensor_3d(s, a, b, c, color)
     return inp, out
 
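With the pluralized field, the failure test collapses into a single expression: a program instance is flagged when any lane is both marked invalid by the interpreter and enabled in the original mask. On toy arrays (values made up for illustration):

import numpy as np

invalid_access_masks = np.array([False, False, True, True])   # lanes flagged by the interpreter
original_masks = np.array([True, True, True, False])          # lanes the kernel actually enabled
assert bool((invalid_access_masks & original_masks).any())    # lane 2 trips failures[last_grid.idx]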
triton_viz/interpreter.py (62 changes: 33 additions & 29 deletions)
@@ -29,9 +29,9 @@ def _patch_lang(fn):
     from triton.runtime.interpreter import _patch_lang as patch_lang
 
     patch_lang(fn)
-    tl.sum = _create_reduce(tl.sum, "sum")
-    tl.min = _create_reduce(tl.min, "min")
-    tl.max = _create_reduce(tl.max, "max")
+    tl.sum = _create_reduce(tl.reduce, "sum")
+    tl.min = _create_reduce(tl.reduce, "min")
+    tl.max = _create_reduce(tl.reduce, "max")
 
 
 def _unpatch_lang():
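Aside on the hunk above: _patch_lang swaps the tl reduction entry points for wrappers built by _create_reduce, which record the call and then delegate. The wrap-record-delegate pattern, reduced to a minimal standalone sketch (the names _record_calls, log, and calls are illustrative, not triton_viz's API):

from functools import wraps

def _record_calls(fn, op_name, log):
    # Wrap fn so every call is logged before delegating to the original function.
    @wraps(fn)
    def wrapper(*args, **kwargs):
        log.append((op_name, args, kwargs))
        return fn(*args, **kwargs)
    return wrapper

calls = []
patched_sum = _record_calls(sum, "sum", calls)
patched_sum([1, 2, 3])   # returns 6; calls now holds ("sum", ([1, 2, 3],), {})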
@@ -184,47 +184,47 @@ def _grid_executor_call(self, *args_dev, **kwargs):
     _unpatch_lang()
 
 
-def check_out_of_bounds_access(ptrs):
+def check_out_of_bounds_access(ptrs, masks):
     first_ptr = np.reshape(ptrs.data, (-1))[0]
     tensor_ptr = record_builder.get_tensor_ptr(first_ptr)
     offsets = ptrs.data - tensor_ptr.ptr
     max_valid_offset = np.prod(tensor_ptr.shape) * tensor_ptr.element_size
-    valid_access_mask = (offsets >= 0) & (offsets < max_valid_offset)
-    invalid_access_mask = np.logical_not(valid_access_mask)
-    corrected_offsets = np.where(valid_access_mask, offsets, 0)
+    valid_access_masks = (offsets >= 0) & (offsets < max_valid_offset)
+    invalid_access_masks = (~valid_access_masks) & (~masks.data)
+    corrected_offsets = np.where(valid_access_masks, offsets, 0)
     return (
         tensor_ptr,
-        valid_access_mask,
-        invalid_access_mask,
+        valid_access_masks & masks.data,
+        invalid_access_masks,
         corrected_offsets,
         offsets,
     )
 
 
 def _create_masked_load(fn):
     @wraps(fn)
-    def wrapper(ptrs, mask, other, cache_modifier, eviction_policy, is_volatile):
+    def wrapper(ptrs, masks, other, cache_modifier, eviction_policy, is_volatile):
         (
             tensor_ptr,
-            valid_access_mask,
-            invalid_access_mask,
+            valid_access_masks,
+            invalid_access_masks,
             corrected_offsets,
             original_offsets,
-        ) = check_out_of_bounds_access(ptrs)
+        ) = check_out_of_bounds_access(ptrs, masks)
         load_record = Load(
             ptr=tensor_ptr.ptr,
             shape=ptrs.data.shape,
             offsets=corrected_offsets,
-            masks=valid_access_mask & mask.data,
-            invalid_access_masks=invalid_access_mask,
+            access_masks=valid_access_masks,
+            invalid_access_masks=invalid_access_masks,
             original_offsets=original_offsets,
-            original_mask=mask.data,
+            original_masks=masks.data,
         )
         record_builder.add_record(load_record)
 
         return fn(
             ptrs,
-            mask,
+            masks,
             other,
             cache_modifier,
             eviction_policy,
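Replaying the reworked bounds check on a toy example (the numbers below are made up, and plain NumPy arrays stand in for the TensorHandle/tensor_ptr bookkeeping used by the interpreter):

import numpy as np

base_ptr, n_elements, element_size = 1000, 6, 4        # a 6-element tensor of 4-byte values at address 1000
ptrs = base_ptr + element_size * np.arange(8)          # the last two pointers run past the end
masks = np.array([True] * 6 + [False] * 2)             # the kernel masked the overrun lanes off

offsets = ptrs - base_ptr
max_valid_offset = n_elements * element_size
valid_access_masks = (offsets >= 0) & (offsets < max_valid_offset)
invalid_access_masks = (~valid_access_masks) & (~masks)
corrected_offsets = np.where(valid_access_masks, offsets, 0)

assert valid_access_masks.tolist() == [True] * 6 + [False] * 2
assert invalid_access_masks.tolist() == [False] * 6 + [True] * 2
assert (valid_access_masks & masks).tolist() == [True] * 6 + [False] * 2   # recorded as access_masks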
@@ -236,26 +236,26 @@ def wrapper(ptrs, mask, other, cache_modifier, eviction_policy, is_volatile):
 
 def _create_masked_store(fn):
     @wraps(fn)
-    def wrapper(ptrs, value, mask, cache_modifier, eviction_policy):
+    def wrapper(ptrs, value, masks, cache_modifier, eviction_policy):
         (
             tensor_ptr,
-            valid_access_mask,
-            invalid_access_mask,
+            valid_access_masks,
+            invalid_access_masks,
             corrected_offsets,
             original_offsets,
-        ) = check_out_of_bounds_access(ptrs)
+        ) = check_out_of_bounds_access(ptrs, masks)
         store_record = Store(
             ptr=tensor_ptr.ptr,
             shape=ptrs.data.shape,
             offsets=corrected_offsets,
-            masks=valid_access_mask & mask.data,
-            invalid_access_masks=invalid_access_mask,
+            access_masks=valid_access_masks,
+            invalid_access_masks=invalid_access_masks,
             original_offsets=original_offsets,
-            original_mask=mask.data,
+            original_masks=masks.data,
         )
         record_builder.add_record(store_record)
 
-        return fn(ptrs, value, mask, cache_modifier, eviction_policy)
+        return fn(ptrs, value, valid_access_masks, cache_modifier, eviction_policy)
 
     return wrapper
 
@@ -311,11 +311,15 @@ def wrapper(arg, axis):
     return wrapper
 
 
-def _create_reduce(fn, op_name):
+def _create_reduce(fn, op_name: str):
     @wraps(fn)
-    def wrapper(input, axis=None, **kwargs):
-        ret = fn(input, axis=axis, **kwargs)
-        keep_dims = kwargs.get("keep_dims", False)
+    def wrapper(input, axis=None, keep_dims=False):
+        mapping = {
+            "max": tl.standard._elementwise_max,
+            "min": tl.standard._elementwise_min,
+            "sum": tl.standard._sum_combine,
+        }
+        ret = fn(input, axis=axis, combine_fn=mapping[op_name], keep_dims=keep_dims)
         reduce_record = Reduce(
             input_shape=input.handle.data.shape,
             index=axis,
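As the commit title says, this follows triton/head, where at the time of this commit tl.sum, tl.min, and tl.max are thin layers over tl.reduce plus a combine function; the patched wrapper therefore receives tl.reduce and supplies the matching helpers from triton.language.standard. A plain-NumPy analogue of that combine-function dispatch (illustrative only, not Triton code; the mapping lambdas stand in for the tl.standard helpers):

from functools import reduce as fold
import numpy as np

def toy_reduce(x, axis, combine_fn, keep_dims=False):
    # Fold combine_fn along one axis, mimicking a combine-function-driven reduce.
    out = np.apply_along_axis(lambda lane: fold(combine_fn, lane), axis, x)
    return np.expand_dims(out, axis) if keep_dims else out

mapping = {"sum": lambda a, b: a + b, "min": min, "max": max}
x = np.array([[3, 1, 2], [0, 5, 4]])
assert toy_reduce(x, axis=1, combine_fn=mapping["max"]).tolist() == [3, 5]
assert toy_reduce(x, axis=0, combine_fn=mapping["sum"], keep_dims=True).tolist() == [[3, 6, 6]]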
