diff --git a/scripts/tagger/interrogator.py b/scripts/tagger/interrogator.py index 13b256a..19407ae 100644 --- a/scripts/tagger/interrogator.py +++ b/scripts/tagger/interrogator.py @@ -4,6 +4,7 @@ import io from hashlib import sha256 import json +from platform import system from typing import Tuple, List, Dict, Callable from pandas import read_csv, read_json from PIL import Image, UnidentifiedImageError @@ -438,9 +439,13 @@ def load(self) -> None: # TODO: remove old package when the environment changes? from launch import is_installed, run_pip if not is_installed('onnxruntime'): + if system() == "Darwin": + package_name = "onnxruntime-silicon" + else: + package_name = "onnxruntime-gpu" package = os.environ.get( 'ONNXRUNTIME_PACKAGE', - 'onnxruntime-gpu' + package_name ) run_pip(f'install {package}', 'onnxruntime')