From 4849c769ebec34ae59812c3032fe5aa7022c78c1 Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Mon, 6 Jan 2025 19:07:17 +0000 Subject: [PATCH] Use iree_hal_target_device flag to compile --- sharktank/sharktank/utils/export_artifacts.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sharktank/sharktank/utils/export_artifacts.py b/sharktank/sharktank/utils/export_artifacts.py index 226db6d73..4045e90a5 100644 --- a/sharktank/sharktank/utils/export_artifacts.py +++ b/sharktank/sharktank/utils/export_artifacts.py @@ -92,7 +92,7 @@ def __init__( attention_kernel: str, tensor_parallelism_size: int, block_seq_stride: int, - iree_hal_target_device: Optional[str] = None, + iree_hal_target_device: str, ): self.sharktank_dir = str( Path(os.path.dirname(os.path.abspath(__file__))).parent.parent.parent @@ -220,11 +220,13 @@ def compile_to_vmfb( ] if self.tensor_parallelism_size > 1: iree_hal_target_devices = [ - f"--iree-hal-target-device=hip[{i}]" + f"--iree-hal-target-device={self.iree_hal_target_device}[{i}]" for i in range(self.tensor_parallelism_size) ] else: - iree_hal_target_devices = ["--iree-hal-target-device=hip"] + iree_hal_target_devices = [ + f"--iree-hal-target-device={self.iree_hal_target_device}" + ] compile_args += iree_hal_target_devices if hal_dump_path: compile_args += [