diff --git a/triton_viz/interpreter.py b/triton_viz/interpreter.py index 8a92ea8..c40bb36 100644 --- a/triton_viz/interpreter.py +++ b/triton_viz/interpreter.py @@ -134,15 +134,15 @@ def _check_storage_contiguous(tensor): def _grid_executor_call(self, *args_dev, **kwargs): - args_hst = self._init_args_hst(args_dev) # Removes reserved keywords from kwargs kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS} if kwargs.pop("warmup", False): return + args_hst, kwargs_hst = self._init_args_hst(args_dev, kwargs) # Remaps core language functions to interpreted ones _patch_lang(self.fn) # Prepare call arguments - args = inspect.getcallargs(self.fn, *args_hst, **kwargs) + args = inspect.getcallargs(self.fn, *args_hst, **kwargs_hst) call_args = {} tensors = [] for name, arg in args.items(): @@ -180,7 +180,7 @@ def _grid_executor_call(self, *args_dev, **kwargs): record_builder.set_grid_idx(x, y, z) self.fn(**call_args) # Copy arguments back to propagate side-effects - self._restore_args_dev(args_dev, args_hst) + self._restore_args_dev(args_dev, args_hst, kwargs, kwargs_hst) _unpatch_lang()