From 8a38218bda13c8a67250e062b212a7822ee961ae Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Thu, 30 May 2024 10:46:09 -0400 Subject: [PATCH] [INTERPRETER] Sync with triton main (#26) --- triton_viz/interpreter.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/triton_viz/interpreter.py b/triton_viz/interpreter.py index 8a92ea8..c40bb36 100644 --- a/triton_viz/interpreter.py +++ b/triton_viz/interpreter.py @@ -134,15 +134,15 @@ def _check_storage_contiguous(tensor): def _grid_executor_call(self, *args_dev, **kwargs): - args_hst = self._init_args_hst(args_dev) # Removes reserved keywords from kwargs kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS} if kwargs.pop("warmup", False): return + args_hst, kwargs_hst = self._init_args_hst(args_dev, kwargs) # Remaps core language functions to interpreted ones _patch_lang(self.fn) # Prepare call arguments - args = inspect.getcallargs(self.fn, *args_hst, **kwargs) + args = inspect.getcallargs(self.fn, *args_hst, **kwargs_hst) call_args = {} tensors = [] for name, arg in args.items(): @@ -180,7 +180,7 @@ def _grid_executor_call(self, *args_dev, **kwargs): record_builder.set_grid_idx(x, y, z) self.fn(**call_args) # Copy arguments back to propagate side-effects - self._restore_args_dev(args_dev, args_hst) + self._restore_args_dev(args_dev, args_hst, kwargs, kwargs_hst) _unpatch_lang()