From b7661838408e4c7ebd6151944b03dc5672f110ec Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Wed, 11 Dec 2024 22:40:28 -0500 Subject: [PATCH] fix and rename op_overrider --- triton_viz/clients/profiler/profiler.py | 6 +++--- triton_viz/clients/sanitizer/sanitizer.py | 8 ++++---- triton_viz/clients/tracer/tracer.py | 10 +++++----- triton_viz/core/client.py | 4 ++-- triton_viz/core/patch.py | 14 +++++++------- 5 files changed, 21 insertions(+), 21 deletions(-) diff --git a/triton_viz/clients/profiler/profiler.py b/triton_viz/clients/profiler/profiler.py index 5eeda75..5b9a3d1 100644 --- a/triton_viz/clients/profiler/profiler.py +++ b/triton_viz/clients/profiler/profiler.py @@ -43,11 +43,11 @@ def pre_store_callback(ptr, value, mask, cache_modifier, eviction_policy): self._report_load_store_bytes("store", ptr, mask) if isinstance(op, Load): - return pre_load_callback, None + return pre_load_callback, None, None elif isinstance(op, Store): - return pre_store_callback, None + return pre_store_callback, None, None - return None, None + return None, None, None def finalize(self) -> list: return [self.load_bytes, self.store_bytes] diff --git a/triton_viz/clients/sanitizer/sanitizer.py b/triton_viz/clients/sanitizer/sanitizer.py index e55ad5e..0640772 100644 --- a/triton_viz/clients/sanitizer/sanitizer.py +++ b/triton_viz/clients/sanitizer/sanitizer.py @@ -216,12 +216,12 @@ def pre_load_callback(ptr, mask, other, cache_modifier, eviction_policy, is_vola upper_bound = valid_addresses.max() self._check_if_range_statisfy_constraints(lower_bound, upper_bound) - def op_load_callback(ptr, mask, other, cache_modifier, eviction_policy, is_volatile): + def op_load_overrider(ptr, mask, other, cache_modifier, eviction_policy, is_volatile): dtype_tt = ptr.get_element_ty() dtype_np = _get_np_dtype(dtype_tt) return TensorHandle(np.zeros_like(ptr.data, dtype=dtype_np), dtype_tt) - def op_store_callback(ptr, value, mask, cache_modifier, eviction_policy): + def op_store_overrider(ptr, value, mask, cache_modifier, eviction_policy): pass def pre_store_callback(ptr, value, mask, cache_modifier, eviction_policy): @@ -235,9 +235,9 @@ def pre_store_callback(ptr, value, mask, cache_modifier, eviction_policy): self._check_if_range_statisfy_constraints(lower_bound, upper_bound) if op_type is Load: - return pre_load_callback, None, op_load_callback + return pre_load_callback, None, op_load_overrider elif op_type is Store: - return pre_store_callback, None, op_store_callback + return pre_store_callback, None, op_store_overrider else: return None, None, None diff --git a/triton_viz/clients/tracer/tracer.py b/triton_viz/clients/tracer/tracer.py index d95d80b..c5f9595 100644 --- a/triton_viz/clients/tracer/tracer.py +++ b/triton_viz/clients/tracer/tracer.py @@ -78,15 +78,15 @@ def post_dot_callback(ret, input, other, *args): self.records.append(Dot(input_shape, other_shape, ret_shape)) if op_type is Load: - return pre_load_callback, None + return pre_load_callback, None, None elif op_type is Store: - return pre_store_callback, None + return pre_store_callback, None, None elif op_type is ReduceSum: - return None, post_reduce_sum_callback + return None, post_reduce_sum_callback, None elif op_type is Dot: - return None, post_dot_callback + return None, post_dot_callback, None - return None, None + return None, None, None def finalize(self) -> list: return self.records diff --git a/triton_viz/core/client.py b/triton_viz/core/client.py index 5f22d86..d904858 100644 --- a/triton_viz/core/client.py +++ b/triton_viz/core/client.py @@ -41,8 +41,8 @@ def patch(self): with patch_calls(): for client in self.clients: for op in op_list: - before_callback, after_callback, op_callback = client.register_op_callback(op) - patch_op(op, before_callback, after_callback, op_callback) + before_callback, after_callback, op_overrider = client.register_op_callback(op) + patch_op(op, before_callback, after_callback, op_overrider) try: yield finally: diff --git a/triton_viz/core/patch.py b/triton_viz/core/patch.py index 4f9c6cd..9a8132f 100644 --- a/triton_viz/core/patch.py +++ b/triton_viz/core/patch.py @@ -31,17 +31,17 @@ class PatchOp: - def __init__(self, op, before_callback, after_callback, op_callback): + def __init__(self, op, before_callback, after_callback, op_overrider): self.op = op self.before_callback = before_callback self.after_callback = after_callback - self.op_callback = op_callback + self.op_overrider = op_overrider def __call__(self, *args, **kwargs): if self.before_callback: self.before_callback(*args, **kwargs) - if self.op_callback: - ret = self.op_callback(*args, **kwargs) + if self.op_overrider: + ret = self.op_overrider(*args, **kwargs) else: ret = self.op(*args, **kwargs) if self.after_callback: @@ -50,7 +50,7 @@ def __call__(self, *args, **kwargs): return ret -def patch_op(op_type: Type[Op], before_callback: Callable, after_callback: Callable, op_callback: Callable): +def patch_op(op_type: Type[Op], before_callback: Callable, after_callback: Callable, op_overrider: Callable): """ Register a callback to be called before and after an operator is executed. @@ -62,12 +62,12 @@ def patch_op(op_type: Type[Op], before_callback: Callable, after_callback: Calla # create a new function that calls the before_callback, the original op and the after_callback op_name = original_ops[op_type].__name__ current_op = getattr(interpreter_builder, op_name) - patched_op = PatchOp(current_op, before_callback, after_callback, op_callback) + patched_op = PatchOp(current_op, before_callback, after_callback, op_overrider) setattr(interpreter_builder, op_name, lambda *args, **kwargs: patched_op(*args, **kwargs)) elif op_type in [ReduceMax, ReduceMin, ReduceSum]: op_name = reduce_map[op_type].__name__ current_op = getattr(tl, op_name) - patched_op = PatchOp(current_op, before_callback, after_callback, op_callback) + patched_op = PatchOp(current_op, before_callback, after_callback, op_overrider) setattr(tl, op_name, lambda *args, **kwargs: patched_op(*args, **kwargs)) else: raise ValueError(f"Patching operator {op_type} not supported")