From ad04c1119ef783041750760428414759e65aab75 Mon Sep 17 00:00:00 2001 From: Jeremie Deray Date: Thu, 8 Aug 2024 23:37:33 +0200 Subject: [PATCH] Enable the use of self-hosted LLMs (#36) * do not set defaults on cli args * allow for empty api key * decouple key check and model ret check * fix logic for params precedence * less verbosity by default when everything's fine * add missing curl deps * init clients with api endpoint * use api endpoint when listing avaialable models * Add a dry-run option to exec * enable setting role * api key cannot be empty * status subcommand fails with verification.sh test. Signed-off-by: Tomoya Fujita --------- Signed-off-by: Tomoya Fujita Co-authored-by: Tomoya Fujita --- package.xml | 1 + ros2ai/api/__init__.py | 4 ++-- ros2ai/api/config.py | 19 +++++++++---------- ros2ai/api/constants.py | 1 + ros2ai/api/openai.py | 3 ++- ros2ai/verb/exec.py | 28 +++++++++++++++++++++++----- ros2ai/verb/query.py | 16 ++++++++++++++-- ros2ai/verb/status.py | 16 ++++++++++++---- scripts/verification.sh | 2 +- 9 files changed, 65 insertions(+), 25 deletions(-) diff --git a/package.xml b/package.xml index b3b3808..16c54ca 100644 --- a/package.xml +++ b/package.xml @@ -13,6 +13,7 @@ Tomoya Fujita + curl python3-openai-pip ros2cli diff --git a/ros2ai/api/__init__.py b/ros2ai/api/__init__.py index e64ff12..6dbd6fe 100644 --- a/ros2ai/api/__init__.py +++ b/ros2ai/api/__init__.py @@ -22,11 +22,11 @@ def add_global_arguments(parser): # add global arguments parser.add_argument( - '-m', '--model', metavar='', type=str, default=constants.ROS_OPENAI_DEFAULT_MODEL, + '-m', '--model', metavar='', type=str, default=None, help=f'Set OpenAI API model (default %(default)s) or ' f'use {constants.ROS_OPENAI_MODEL_NAME_ENV_VAR} environment variable. (argument prevails)') parser.add_argument( - '-u', '--url', metavar='', type=str, default=constants.ROS_OPENAI_DEFAULT_ENDPOINT, + '-u', '--url', metavar='', type=str, default=None, help='Set OpenAI API endpoint URL (default %(default)s) or ' f'use {constants.ROS_OPENAI_ENDPOINT_ENV_VAR} environment variable. (argument prevails)') parser.add_argument( diff --git a/ros2ai/api/config.py b/ros2ai/api/config.py index 92beed3..e2be592 100644 --- a/ros2ai/api/config.py +++ b/ros2ai/api/config.py @@ -28,13 +28,8 @@ def get_api_key() -> str: :return: string of OpenAI API Key. :raises: if OPENAI_API_KEY is not set. """ - key_name = os.environ.get(constants.ROS_OPENAI_API_KEY_ENV_VAR) - if not key_name: - raise EnvironmentError( - f"'{constants.ROS_OPENAI_API_KEY_ENV_VAR}' environment variable is not set'" - ) - else: - return key_name + key_name = os.environ.get(constants.ROS_OPENAI_API_KEY_ENV_VAR, "None") + return key_name def get_ai_model() -> str: """ @@ -78,6 +73,9 @@ def get_temperature() -> float: else: return float(temperature) +def get_role_system(default_role_system: str = None) -> str: + return os.environ.get(constants.ROLE_SYSTEM_ENV_VAR, default_role_system) + class OpenAiConfig: """ Collect all OpenAI API related configuration from user setting as key-value pair. @@ -90,12 +88,12 @@ def __init__(self, args): # ai model is optional, command line argument prevails self.config_pair['api_model'] = get_ai_model() - if args.model != constants.ROS_OPENAI_DEFAULT_MODEL: + if args.model and args.model != self.config_pair['api_model']: self.config_pair['api_model'] = args.model # api endpoint is optional, command line argument prevails self.config_pair['api_endpoint'] = get_endpoint_url() - if args.url != constants.ROS_OPENAI_DEFAULT_MODEL: + if args.url and args.url != self.config_pair['api_endpoint']: self.config_pair['api_endpoint'] = args.url # api token is optional, only available via command line argument @@ -127,7 +125,8 @@ def display_all(self): def is_api_key_valid(self): # Validate api key, model and endpoint to post the API client = OpenAI( - api_key=self.get_value('api_key') + api_key=self.get_value('api_key'), + base_url=self.get_value('api_endpoint') ) try: completion = client.chat.completions.create( diff --git a/ros2ai/api/constants.py b/ros2ai/api/constants.py index cd7c593..9b3e606 100644 --- a/ros2ai/api/constants.py +++ b/ros2ai/api/constants.py @@ -32,6 +32,7 @@ ROLE_SYSTEM_EXEC_DEFAULT = \ 'You are a Robot Operating System 2 (as known as ROS2) {} distribution command line executor, ' \ 'provides executable command string only without any comments or code blocks.' +ROLE_SYSTEM_ENV_VAR = 'OPENAI_ROLE_SYSTEM' # Temperature controls the consistency for behavior. (range 0.0 - 2.0) # The lower the temperature is, the more deterministic behavior that OpenAI does. diff --git a/ros2ai/api/openai.py b/ros2ai/api/openai.py index a7db116..56c0f9b 100644 --- a/ros2ai/api/openai.py +++ b/ros2ai/api/openai.py @@ -26,7 +26,8 @@ def __init__(self, args): super().__init__(args) self.client_ = OpenAI( - api_key=self.get_value('api_key') + api_key=self.get_value('api_key'), + base_url=self.get_value('api_endpoint') ) self.completion_ = None self.stream_ = True diff --git a/ros2ai/verb/exec.py b/ros2ai/verb/exec.py index 0fac3b5..b2c4834 100644 --- a/ros2ai/verb/exec.py +++ b/ros2ai/verb/exec.py @@ -13,12 +13,12 @@ # limitations under the License. from ros2ai.api import add_global_arguments -from ros2ai.api.utils import run_executable, truncate_before_substring +from ros2ai.api.config import get_role_system +from ros2ai.api.constants import ROLE_SYSTEM_EXEC_DEFAULT from ros2ai.api.openai import ChatCompletionClient, ChatCompletionParameters -from ros2ai.api.utils import get_ros_distro +from ros2ai.api.utils import get_ros_distro, run_executable, truncate_before_substring from ros2ai.verb import VerbExtension -import ros2ai.api.constants as constants class ExecVerb(VerbExtension): """Execute a ROS 2 CLI based on user query to OpenAI API.""" @@ -36,6 +36,17 @@ def add_arguments(self, parser, cli_name): '--debug', action='store_true', help='Prints detailed information and behavior of OpenAI (debug use only)') + parser.add_argument( + '--dry-run', + action='store_true', + help='Prints the command instead of executing it.') + parser.add_argument( + '-r', + '--role', + metavar='', + type=str, + default=None, + help='Define the prompt\'s system role.') def main(self, *, args): request = '' @@ -47,7 +58,10 @@ def main(self, *, args): distro = get_ros_distro() if distro is None: distro = 'rolling' # fallback to rolling in default - system_role = constants.ROLE_SYSTEM_EXEC_DEFAULT.format(distro) + system_role = get_role_system(default_role_system=ROLE_SYSTEM_EXEC_DEFAULT) + if args.role and args.role != system_role: + system_role = args.role + system_role = system_role.format(distro) user_request = [ {"role": "system", "content": f"{system_role}"}, {"role": "user", "content": f"{request}"} @@ -58,6 +72,10 @@ def main(self, *, args): client.call(completion_params) if (args.debug is True): client.print_all() + print(f"System role:\n{system_role}") command_str = truncate_before_substring( original = client.get_result(), substring = 'ros2') - run_executable(command = command_str) + if not args.dry_run: + run_executable(command = command_str) + else: + print(f"Command: '{command_str}'") diff --git a/ros2ai/verb/query.py b/ros2ai/verb/query.py index 05336c3..1e27012 100644 --- a/ros2ai/verb/query.py +++ b/ros2ai/verb/query.py @@ -13,11 +13,12 @@ # limitations under the License. from ros2ai.api import add_global_arguments +from ros2ai.api.config import get_role_system +from ros2ai.api.constants import ROLE_SYSTEM_QUERY_DEFAULT from ros2ai.api.openai import ChatCompletionClient, ChatCompletionParameters from ros2ai.api.utils import get_ros_distro from ros2ai.verb import VerbExtension -import ros2ai.api.constants as constants class QueryVerb(VerbExtension): """Query a single completion to OpenAI API.""" @@ -40,6 +41,13 @@ def add_arguments(self, parser, cli_name): '--verbose', action='store_true', help='Prints detailed response information (only available with nostream option)') + parser.add_argument( + '-r', + '--role', + metavar='', + type=str, + default=None, + help='Define the prompt\'s system role.') def main(self, *, args): sentence = '' @@ -51,7 +59,10 @@ def main(self, *, args): distro = get_ros_distro() if distro is None: distro = 'rolling' # fallback to rolling in default - system_role = constants.ROLE_SYSTEM_QUERY_DEFAULT.format(distro) + system_role = get_role_system(default_role_system=ROLE_SYSTEM_QUERY_DEFAULT) + if args.role and args.role != system_role: + system_role = args.role + system_role = system_role.format(distro) user_messages = [ {"role": "system", "content": f"{system_role}"}, {"role": "user", "content": f"{sentence}"} @@ -63,5 +74,6 @@ def main(self, *, args): client.call(completion_params) if (args.verbose is True and args.nostream is True): client.print_all() + print(f"System role:\n{system_role}") else: client.print_result() diff --git a/ros2ai/verb/status.py b/ros2ai/verb/status.py index 681c599..a4b90ef 100644 --- a/ros2ai/verb/status.py +++ b/ros2ai/verb/status.py @@ -45,17 +45,25 @@ def main(self, *, args): # try to call OpenAI API with user configured setting is_valid = openai_config.is_api_key_valid() + if is_valid: + if args.verbose: + print("[SUCCESS] Valid OpenAI API key.") + else: + print("[FAILURE] Invalid OpenAI API key.") + return 1 + # try to list the all models via user configured api key headers = {"Authorization": "Bearer " + openai_config.get_value('api_key')} can_get_models, model_list = curl_get_request( - "https://api.openai.com/v1/models", + openai_config.get_value('api_endpoint') + "/models", headers ) - if is_valid and can_get_models: - print("[SUCCESS] Valid OpenAI API key.") + if can_get_models: + if args.verbose: + print("[SUCCESS] Retrieved list of models.") else: - print("[FAILURE] Invalid OpenAI API key.") + print("[FAILURE] Could not retrieved list of models.") return 1 if (args.list is True): diff --git a/scripts/verification.sh b/scripts/verification.sh index 9620ffe..67603e1 100755 --- a/scripts/verification.sh +++ b/scripts/verification.sh @@ -62,7 +62,7 @@ function verify_ros2ai() { echo "[${FUNCNAME[0]}]: verifying ros2ai." # execute all commands in the list for command in "${command_list[@]}"; do - #echo "----- $command" + echo "----- $command" eval $command done echo "----- all ros2ai commands return successfully!!! -----"