Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DEV] Fix and rename op overrider #60

Open
wants to merge 2 commits into
base: keren/v2.0
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions triton_viz/clients/profiler/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,11 @@ def pre_store_callback(ptr, value, mask, cache_modifier, eviction_policy):
self._report_load_store_bytes("store", ptr, mask)

if isinstance(op, Load):
return pre_load_callback, None
return pre_load_callback, None, None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Type hints also have to be modified.

Maybe we should add mypy checks for this.

def register_op_callback(self, op: Type[Op]) -> Tuple[Optional[Callable], Optional[Callable], Optional[Callable]]

A cleaner way is to define an explicit structure to store callbacks

elif isinstance(op, Store):
return pre_store_callback, None
return pre_store_callback, None, None

return None, None
return None, None, None

def finalize(self) -> list:
return [self.load_bytes, self.store_bytes]
8 changes: 4 additions & 4 deletions triton_viz/clients/sanitizer/sanitizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,12 +216,12 @@ def pre_load_callback(ptr, mask, other, cache_modifier, eviction_policy, is_vola
upper_bound = valid_addresses.max()
self._check_if_range_statisfy_constraints(lower_bound, upper_bound)

def op_load_callback(ptr, mask, other, cache_modifier, eviction_policy, is_volatile):
def op_load_overrider(ptr, mask, other, cache_modifier, eviction_policy, is_volatile):
dtype_tt = ptr.get_element_ty()
dtype_np = _get_np_dtype(dtype_tt)
return TensorHandle(np.zeros_like(ptr.data, dtype=dtype_np), dtype_tt)

def op_store_callback(ptr, value, mask, cache_modifier, eviction_policy):
def op_store_overrider(ptr, value, mask, cache_modifier, eviction_policy):
pass

def pre_store_callback(ptr, value, mask, cache_modifier, eviction_policy):
Expand All @@ -235,9 +235,9 @@ def pre_store_callback(ptr, value, mask, cache_modifier, eviction_policy):
self._check_if_range_statisfy_constraints(lower_bound, upper_bound)

if op_type is Load:
return pre_load_callback, None, op_load_callback
return pre_load_callback, None, op_load_overrider
elif op_type is Store:
return pre_store_callback, None, op_store_callback
return pre_store_callback, None, op_store_overrider
else:
return None, None, None

Expand Down
10 changes: 5 additions & 5 deletions triton_viz/clients/tracer/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,15 @@ def post_dot_callback(ret, input, other, *args):
self.records.append(Dot(input_shape, other_shape, ret_shape))

if op_type is Load:
return pre_load_callback, None
return pre_load_callback, None, None
elif op_type is Store:
return pre_store_callback, None
return pre_store_callback, None, None
elif op_type is ReduceSum:
return None, post_reduce_sum_callback
return None, post_reduce_sum_callback, None
elif op_type is Dot:
return None, post_dot_callback
return None, post_dot_callback, None

return None, None
return None, None, None

def finalize(self) -> list:
return self.records
4 changes: 2 additions & 2 deletions triton_viz/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def patch(self):
with patch_calls():
for client in self.clients:
for op in op_list:
before_callback, after_callback, op_callback = client.register_op_callback(op)
patch_op(op, before_callback, after_callback, op_callback)
before_callback, after_callback, op_overrider = client.register_op_callback(op)
patch_op(op, before_callback, after_callback, op_overrider)
try:
yield
finally:
Expand Down
14 changes: 7 additions & 7 deletions triton_viz/core/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,17 @@


class PatchOp:
def __init__(self, op, before_callback, after_callback, op_callback):
def __init__(self, op, before_callback, after_callback, op_overrider):
self.op = op
self.before_callback = before_callback
self.after_callback = after_callback
self.op_callback = op_callback
self.op_overrider = op_overrider

def __call__(self, *args, **kwargs):
if self.before_callback:
self.before_callback(*args, **kwargs)
if self.op_callback:
ret = self.op_callback(*args, **kwargs)
if self.op_overrider:
ret = self.op_overrider(*args, **kwargs)
else:
ret = self.op(*args, **kwargs)
if self.after_callback:
Expand All @@ -50,7 +50,7 @@ def __call__(self, *args, **kwargs):
return ret


def patch_op(op_type: Type[Op], before_callback: Callable, after_callback: Callable, op_callback: Callable):
def patch_op(op_type: Type[Op], before_callback: Callable, after_callback: Callable, op_overrider: Callable):
"""
Register a callback to be called before and after an operator is executed.

Expand All @@ -62,12 +62,12 @@ def patch_op(op_type: Type[Op], before_callback: Callable, after_callback: Calla
# create a new function that calls the before_callback, the original op and the after_callback
op_name = original_ops[op_type].__name__
current_op = getattr(interpreter_builder, op_name)
patched_op = PatchOp(current_op, before_callback, after_callback, op_callback)
patched_op = PatchOp(current_op, before_callback, after_callback, op_overrider)
setattr(interpreter_builder, op_name, lambda *args, **kwargs: patched_op(*args, **kwargs))
elif op_type in [ReduceMax, ReduceMin, ReduceSum]:
op_name = reduce_map[op_type].__name__
current_op = getattr(tl, op_name)
patched_op = PatchOp(current_op, before_callback, after_callback, op_callback)
patched_op = PatchOp(current_op, before_callback, after_callback, op_overrider)
setattr(tl, op_name, lambda *args, **kwargs: patched_op(*args, **kwargs))
else:
raise ValueError(f"Patching operator {op_type} not supported")
Expand Down