Skip to content

Commit

Permalink
Bug fix for building triton kernel
Browse files Browse the repository at this point in the history
Summary:
This diff includes a few bug fixes:

1. a bug in build_triton_func

2. https://www.internalfb.com/diff/D59195933 changed the format of grid information for triton kernel. et_replay also need to change the triton kernel execution function to pass in grid information correctly

3. updated resnet test files to have process_group:init node.

4. updated GPT2 PT2 trace file to have the new format for grid information

Reviewed By: briancoutinho

Differential Revision: D59195828
  • Loading branch information
shengfukevin authored and facebook-github-bot committed Jul 12, 2024
1 parent 9391736 commit c9829a7
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 8 deletions.
4 changes: 2 additions & 2 deletions et_replay/lib/et_replay_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,15 +457,15 @@ def build_triton_func(n, resources_dir, async_compile, device):
with open(os.path.join(resources_dir, n.kernel_file), "r") as f:
code = f.read()

func = None
# TORCHINDUCTOR_UNIQUE_KERNEL_NAMES controls whether each triton
# kernel is given a unique name or not, if it is not, then the
# kernel name will be "triton_" for all triton kernels.
try:
func = async_compile.triton(n.name, code, device_str=device)
except Exception:
func = async_compile.triton("triton_", code, device_str=device)
finally:
func = None

return func, 0


Expand Down
7 changes: 1 addition & 6 deletions et_replay/tools/et_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -1128,13 +1128,8 @@ def run_op(self, node, iter):
outputs = []
if output_count == 0:
if node.kernel_backend == "triton":
# remove the last comma
grid_info = inputs[-2]
index = grid_info.rfind(",")
if index >= 0:
grid_info = grid_info[:index] + grid_info[index + 1 :]
exec(
f"func.run(*inputs[:-2], grid={grid_info}, stream={inputs[-1]})"
f"func.run(*inputs[:-2], grid={inputs[-2]}, stream={inputs[-1]})"
)
else:
func(*inputs)
Expand Down

0 comments on commit c9829a7

Please sign in to comment.