From f198419ef6a11bdb543e3b0fda04f8b0e983998d Mon Sep 17 00:00:00 2001 From: Jokeren Date: Wed, 27 Mar 2024 12:42:28 -0400 Subject: [PATCH] Update --- triton_viz/interpreter.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/triton_viz/interpreter.py b/triton_viz/interpreter.py index 0d76c41..12ac979 100644 --- a/triton_viz/interpreter.py +++ b/triton_viz/interpreter.py @@ -343,12 +343,20 @@ def patch(): old_create_dot = interpreter_builder.create_dot old_create_masked_store = interpreter_builder.create_masked_store GridExecutor.__call__ = _grid_executor_call - interpreter_builder.create_make_range = _create_make_range(interpreter_builder.create_make_range) - interpreter_builder.create_masked_load = _create_masked_load(interpreter_builder.create_masked_load) - interpreter_builder.create_expand_dims = _create_expand_dims(interpreter_builder.create_expand_dims) + interpreter_builder.create_make_range = _create_make_range( + interpreter_builder.create_make_range + ) + interpreter_builder.create_masked_load = _create_masked_load( + interpreter_builder.create_masked_load + ) + interpreter_builder.create_expand_dims = _create_expand_dims( + interpreter_builder.create_expand_dims + ) interpreter_builder.binary_op = _create_binary_op(interpreter_builder.binary_op) interpreter_builder.create_dot = _create_dot(interpreter_builder.create_dot) - interpreter_builder.create_masked_store = _create_masked_store(interpreter_builder.create_masked_store) + interpreter_builder.create_masked_store = _create_masked_store( + interpreter_builder.create_masked_store + ) try: yield finally: