Skip to content

Commit

Permalink
Hotfix to sync with triton/main (#29)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jokeren authored Jul 9, 2024
1 parent 62ddb6f commit 589b115
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions triton_viz/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
_implicit_cvt,
RESERVED_KWS,
interpreter_builder,
InterpretedFunction,
)
from triton.runtime.interpreter import _patch_lang as triton_patch_lang
from triton.runtime import JITFunction
Expand Down Expand Up @@ -342,6 +343,8 @@ def wrapper(input, axis=None, keep_dims=False):
def patch():
old_grid_executor_call = GridExecutor.__call__
old_jit_function_call = JITFunction.__call__
# XXX(Keren): Temporarily disable rewriting of AST
old_rewrite_ast = InterpretedFunction._rewrite_ast
old_create_make_range = interpreter_builder.create_make_range
old_create_masked_load = interpreter_builder.create_masked_load
old_create_expand_dims = interpreter_builder.create_expand_dims
Expand All @@ -350,6 +353,7 @@ def patch():
old_create_masked_store = interpreter_builder.create_masked_store
GridExecutor.__call__ = _grid_executor_call
JITFunction.__call__ = _jit_function_call
InterpretedFunction._rewrite_ast = lambda self: self.fn
interpreter_builder.create_make_range = _create_make_range(
interpreter_builder.create_make_range
)
Expand All @@ -369,6 +373,7 @@ def patch():
finally:
GridExecutor.__call__ = old_grid_executor_call
JITFunction.__call__ = old_jit_function_call
InterpretedFunction._rewrite_ast = old_rewrite_ast
interpreter_builder.create_make_range = old_create_make_range
interpreter_builder.create_masked_load = old_create_masked_load
interpreter_builder.create_expand_dims = old_create_expand_dims
Expand Down

0 comments on commit 589b115

Please sign in to comment.