Skip to content

Commit

Permalink
Use iree_hal_target_device flag to compile
Browse files Browse the repository at this point in the history
  • Loading branch information
archana-ramalingam committed Jan 6, 2025
1 parent 69997e6 commit 4849c76
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions sharktank/sharktank/utils/export_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def __init__(
attention_kernel: str,
tensor_parallelism_size: int,
block_seq_stride: int,
iree_hal_target_device: Optional[str] = None,
iree_hal_target_device: str,
):
self.sharktank_dir = str(
Path(os.path.dirname(os.path.abspath(__file__))).parent.parent.parent
Expand Down Expand Up @@ -220,11 +220,13 @@ def compile_to_vmfb(
]
if self.tensor_parallelism_size > 1:
iree_hal_target_devices = [
f"--iree-hal-target-device=hip[{i}]"
f"--iree-hal-target-device={self.iree_hal_target_device}[{i}]"
for i in range(self.tensor_parallelism_size)
]
else:
iree_hal_target_devices = ["--iree-hal-target-device=hip"]
iree_hal_target_devices = [
f"--iree-hal-target-device={self.iree_hal_target_device}"
]
compile_args += iree_hal_target_devices
if hal_dump_path:
compile_args += [
Expand Down

0 comments on commit 4849c76

Please sign in to comment.