diff --git a/triton_viz/interpreter.py b/triton_viz/interpreter.py index f447bbf..a572e1e 100644 --- a/triton_viz/interpreter.py +++ b/triton_viz/interpreter.py @@ -343,8 +343,6 @@ def wrapper(input, axis=None, keep_dims=False): def patch(): old_grid_executor_call = GridExecutor.__call__ old_jit_function_call = JITFunction.__call__ - # XXX(Keren): Temporarily disable rewriting of AST - old_rewrite_ast = InterpretedFunction._rewrite_ast old_create_make_range = interpreter_builder.create_make_range old_create_masked_load = interpreter_builder.create_masked_load old_create_expand_dims = interpreter_builder.create_expand_dims @@ -373,7 +371,6 @@ def patch(): finally: GridExecutor.__call__ = old_grid_executor_call JITFunction.__call__ = old_jit_function_call - InterpretedFunction._rewrite_ast = old_rewrite_ast interpreter_builder.create_make_range = old_create_make_range interpreter_builder.create_masked_load = old_create_masked_load interpreter_builder.create_expand_dims = old_create_expand_dims