Add sharding support for compile and benchmark functions
Signed-off-by: aviator19941 <[email protected]>
aviator19941 committed Nov 1, 2024
1 parent bb3ccb2 commit 82eee2c
Showing 2 changed files with 20 additions and 10 deletions.
16 changes: 15 additions & 1 deletion sharktank/sharktank/utils/export_artifacts.py
@@ -204,6 +204,12 @@ def compile_to_vmfb(
             f"--iree-hal-target-backends={self.iree_hal_target_backends}",
             f"-o={vmfb_path}",
         ]
+        if self.tensor_parallelism_size > 1:
+            iree_hal_target_devices = [
+                f"--iree-hal-target-device=hip[{i}]"
+                for i in range(self.tensor_parallelism_size)
+            ]
+            compile_args += iree_hal_target_devices
         if hal_dump_path:
             compile_args += [
                 f"--iree-hal-dump-executable-files-to={hal_dump_path}/files"
@@ -243,8 +249,16 @@ def iree_benchmark_vmfb(
"--hip_allow_inline_execution=true",
"--device_allocator=caching",
f"--module={vmfb_name}",
f"--parameters=model={irpa_path}",
]
+        if self.tensor_parallelism_size > 1:
+            base_irpa_path, _ = os.path.splitext(irpa_path)
+            params = [
+                f"--parameters=model={base_irpa_path}.rank{i}.irpa"
+                for i in range(self.tensor_parallelism_size)
+            ]
+        else:
+            params = [f"--parameters=model={irpa_path}"]
+        benchmark_args += params
         benchmark_args += args
         cmd = subprocess.list2cmdline(benchmark_args)
         logging.getLogger().info(f"Launching run command:\n" f"cd {cwd} && {cmd}")
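
For reference, a minimal sketch (not part of this commit) of how these flags expand when tensor_parallelism_size is 2; the paths below are hypothetical:

import os

# Hypothetical inputs, for illustration only.
tensor_parallelism_size = 2
irpa_path = "/tmp/llama70b_f16_tp2.irpa"

# Sharded compile: one --iree-hal-target-device=hip[i] flag per rank.
compile_args = [
    f"--iree-hal-target-device=hip[{i}]" for i in range(tensor_parallelism_size)
]
print(compile_args)
# ['--iree-hal-target-device=hip[0]', '--iree-hal-target-device=hip[1]']

# Sharded benchmark: one --parameters=model=<base>.rank<i>.irpa flag per rank,
# derived from the unsharded irpa path by dropping its extension.
base_irpa_path, _ = os.path.splitext(irpa_path)
params = [
    f"--parameters=model={base_irpa_path}.rank{i}.irpa"
    for i in range(tensor_parallelism_size)
]
print(params)
# ['--parameters=model=/tmp/llama70b_f16_tp2.rank0.irpa',
#  '--parameters=model=/tmp/llama70b_f16_tp2.rank1.irpa']
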
14 changes: 5 additions & 9 deletions sharktank/tests/models/llama/benchmark_amdgpu_test.py
@@ -472,11 +472,7 @@ def testBenchmark70B_f16_Decomposed(self):
             cwd=self.repo_root,
         )

-    @pytest.mark.xfail(
-        reason="'tm_tensor.attention' op query and mask batch dimension mismatch",
-        strict=True,
-        raises=IreeCompileException,
-    )
+    @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
     def testBenchmark70B_f16_Decodeposed(self):
         output_file_name = self.dir_path_70b / "f16_torch"
         output_mlir = self.llama70b_f16_decodeposed_artifacts.create_file(
@@ -714,7 +710,9 @@ def setUp(self):
         ]

     @pytest.mark.xfail(
-        reason="Export with sharding failing", strict=True, raises=ExportMlirException
+        reason="error: 'util.global' op references a promised device that was not declared",
+        strict=True,
+        raises=IreeCompileException,
     )
     def testBenchmark405B_f16_Decomposed(self):
         output_file_name = self.dir_path_405b / "f16_decomposed"
@@ -764,9 +762,7 @@ def testBenchmark405B_f16_Decomposed(self):
             cwd=self.repo_root,
         )

-    @pytest.mark.xfail(
-        reason="Test not yet implemented", strict=True, raises=ExportMlirException
-    )
+    @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
     def testBenchmark405B_f16_Decodeposed(self):
         output_file_name = self.dir_path_405b / "f16_torch"
         output_mlir = self.llama405b_f16_decodeposed_artifacts.create_file(
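
As a side note, a minimal sketch (not from this commit) of how a strict xfail marker with raises= behaves in pytest: the test must fail with the named exception type, and an unexpected pass is reported as a failure. The exception class below is a hypothetical stand-in:

import pytest

class HypotheticalCompileError(Exception):
    """Stand-in for IreeCompileException, for illustration only."""

@pytest.mark.xfail(reason="Compile Error", strict=True, raises=HypotheticalCompileError)
def test_known_compile_failure():
    # Raising the expected type records an XFAIL. With strict=True an unexpected
    # pass is reported as a failure; raising a different exception type also fails.
    raise HypotheticalCompileError("expected failure for now")
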