Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
Jokeren committed Mar 27, 2024
1 parent 5f7e4a2 commit d43bd5d
Showing 1 changed file with 21 additions and 21 deletions.
42 changes: 21 additions & 21 deletions triton_viz/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
GridExecutor,
_implicit_cvt,
RESERVED_KWS,
builder,
interpreter_builder,
)
from typing import Tuple, List, Optional
from contextlib import contextmanager
Expand Down Expand Up @@ -169,14 +169,14 @@ def _grid_executor_call(self, *args_dev, **kwargs):
grid = self.grid(call_args) if callable(self.grid) else self.grid
assert len(grid) <= 3
grid = grid + (1,) * (3 - len(grid))
builder.set_grid_dim(*grid)
interpreter_builder.set_grid_dim(*grid)
record_builder.set_grid_dim(*grid)
record_builder.add_tensors(tensors)
record_builder.sort_tensor_handles()
for x in range(grid[0]):
for y in range(grid[1]):
for z in range(grid[2]):
builder.set_grid_idx(x, y, z)
interpreter_builder.set_grid_idx(x, y, z)
record_builder.set_grid_idx(x, y, z)
self.fn(**call_args)
# Copy arguments back to propagate side-effects
Expand Down Expand Up @@ -336,26 +336,26 @@ def wrapper(input, axis=None, keep_dims=False):
@contextmanager
def patch():
old_grid_executor_call = GridExecutor.__call__
old_create_make_range = builder.create_make_range
old_create_masked_load = builder.create_masked_load
old_create_expand_dims = builder.create_expand_dims
old_binary_op = builder.binary_op
old_create_dot = builder.create_dot
old_create_masked_store = builder.create_masked_store
old_create_make_range = interpreter_builder.create_make_range
old_create_masked_load = interpreter_builder.create_masked_load
old_create_expand_dims = interpreter_builder.create_expand_dims
old_binary_op = interpreter_builder.binary_op
old_create_dot = interpreter_builder.create_dot
old_create_masked_store = interpreter_builder.create_masked_store
GridExecutor.__call__ = _grid_executor_call
builder.create_make_range = _create_make_range(builder.create_make_range)
builder.create_masked_load = _create_masked_load(builder.create_masked_load)
builder.create_expand_dims = _create_expand_dims(builder.create_expand_dims)
builder.binary_op = _create_binary_op(builder.binary_op)
builder.create_dot = _create_dot(builder.create_dot)
builder.create_masked_store = _create_masked_store(builder.create_masked_store)
interpreter_builder.create_make_range = _create_make_range(interpreter_builder.create_make_range)
interpreter_builder.create_masked_load = _create_masked_load(interpreter_builder.create_masked_load)
interpreter_builder.create_expand_dims = _create_expand_dims(interpreter_builder.create_expand_dims)
interpreter_builder.binary_op = _create_binary_op(interpreter_builder.binary_op)
interpreter_builder.create_dot = _create_dot(interpreter_builder.create_dot)
interpreter_builder.create_masked_store = _create_masked_store(interpreter_builder.create_masked_store)
try:
yield
finally:
GridExecutor.__call__ = old_grid_executor_call
builder.create_make_range = old_create_make_range
builder.create_masked_load = old_create_masked_load
builder.create_expand_dims = old_create_expand_dims
builder.binary_op = old_binary_op
builder.create_dot = old_create_dot
builder.create_masked_store = old_create_masked_store
interpreter_builder.create_make_range = old_create_make_range
interpreter_builder.create_masked_load = old_create_masked_load
interpreter_builder.create_expand_dims = old_create_expand_dims
interpreter_builder.binary_op = old_binary_op
interpreter_builder.create_dot = old_create_dot
interpreter_builder.create_masked_store = old_create_masked_store

0 comments on commit d43bd5d

Please sign in to comment.