diff --git a/sharktank/sharktank/utils/export_artifacts.py b/sharktank/sharktank/utils/export_artifacts.py
index 62f936459..bb6659b8a 100644
--- a/sharktank/sharktank/utils/export_artifacts.py
+++ b/sharktank/sharktank/utils/export_artifacts.py
@@ -204,6 +204,12 @@ def compile_to_vmfb(
             f"--iree-hal-target-backends={self.iree_hal_target_backends}",
             f"-o={vmfb_path}",
         ]
+        if self.tensor_parallelism_size > 1:
+            iree_hal_target_devices = [
+                f"--iree-hal-target-device=hip[{i}]"
+                for i in range(self.tensor_parallelism_size)
+            ]
+            compile_args += iree_hal_target_devices
         if hal_dump_path:
             compile_args += [
                 f"--iree-hal-dump-executable-files-to={hal_dump_path}/files"
@@ -243,8 +249,16 @@ def iree_benchmark_vmfb(
             "--hip_allow_inline_execution=true",
             "--device_allocator=caching",
             f"--module={vmfb_name}",
-            f"--parameters=model={irpa_path}",
         ]
+        if self.tensor_parallelism_size > 1:
+            base_irpa_path, _ = os.path.splitext(irpa_path)
+            params = [
+                f"--parameters=model={base_irpa_path}.rank{i}.irpa"
+                for i in range(self.tensor_parallelism_size)
+            ]
+        else:
+            params = [f"--parameters=model={irpa_path}"]
+        benchmark_args += params
         benchmark_args += args
         cmd = subprocess.list2cmdline(benchmark_args)
         logging.getLogger().info(f"Launching run command:\n" f"cd {cwd} && {cmd}")
diff --git a/sharktank/tests/models/llama/benchmark_amdgpu_test.py b/sharktank/tests/models/llama/benchmark_amdgpu_test.py
index be28f7402..389b6c7cb 100644
--- a/sharktank/tests/models/llama/benchmark_amdgpu_test.py
+++ b/sharktank/tests/models/llama/benchmark_amdgpu_test.py
@@ -472,11 +472,7 @@ def testBenchmark70B_f16_Decomposed(self):
             cwd=self.repo_root,
         )
 
-    @pytest.mark.xfail(
-        reason="'tm_tensor.attention' op query and mask batch dimension mismatch",
-        strict=True,
-        raises=IreeCompileException,
-    )
+    @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
     def testBenchmark70B_f16_Decodeposed(self):
         output_file_name = self.dir_path_70b / "f16_torch"
         output_mlir = self.llama70b_f16_decodeposed_artifacts.create_file(
@@ -714,7 +710,9 @@ def setUp(self):
         ]
 
     @pytest.mark.xfail(
-        reason="Export with sharding failing", strict=True, raises=ExportMlirException
+        reason="error: 'util.global' op references a promised device that was not declared",
+        strict=True,
+        raises=IreeCompileException,
     )
     def testBenchmark405B_f16_Decomposed(self):
         output_file_name = self.dir_path_405b / "f16_decomposed"
@@ -764,9 +762,7 @@ def testBenchmark405B_f16_Decomposed(self):
             cwd=self.repo_root,
         )
 
-    @pytest.mark.xfail(
-        reason="Test not yet implemented", strict=True, raises=ExportMlirException
-    )
+    @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
     def testBenchmark405B_f16_Decodeposed(self):
         output_file_name = self.dir_path_405b / "f16_torch"
         output_mlir = self.llama405b_f16_decodeposed_artifacts.create_file(