Skip to content

Commit

Permalink
Enable the use of self-hosted LLMs (#36)
Browse files Browse the repository at this point in the history
* do not set defaults on cli args

* allow for empty api key

* decouple key check and model ret check

* fix logic for params precedence

* less verbosity by default when everything's fine

* add missing curl deps

* init clients with api endpoint

* use api endpoint when listing avaialable models

* Add a dry-run option to exec

* enable setting role

* api key cannot be empty

* status subcommand fails with verification.sh test.

Signed-off-by: Tomoya Fujita <[email protected]>

---------

Signed-off-by: Tomoya Fujita <[email protected]>
Co-authored-by: Tomoya Fujita <[email protected]>
  • Loading branch information
artivis and fujitatomoya committed Aug 8, 2024
1 parent 2309dae commit ad04c11
Show file tree
Hide file tree
Showing 9 changed files with 65 additions and 25 deletions.
1 change: 1 addition & 0 deletions package.xml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

<author email="[email protected]">Tomoya Fujita</author>

<exec_depend>curl</exec_depend>
<exec_depend>python3-openai-pip</exec_depend>
<exec_depend>ros2cli</exec_depend>

Expand Down
4 changes: 2 additions & 2 deletions ros2ai/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
def add_global_arguments(parser):
# add global arguments
parser.add_argument(
'-m', '--model', metavar='<model>', type=str, default=constants.ROS_OPENAI_DEFAULT_MODEL,
'-m', '--model', metavar='<model>', type=str, default=None,
help=f'Set OpenAI API model (default %(default)s) or '
f'use {constants.ROS_OPENAI_MODEL_NAME_ENV_VAR} environment variable. (argument prevails)')
parser.add_argument(
'-u', '--url', metavar='<url>', type=str, default=constants.ROS_OPENAI_DEFAULT_ENDPOINT,
'-u', '--url', metavar='<url>', type=str, default=None,
help='Set OpenAI API endpoint URL (default %(default)s) or '
f'use {constants.ROS_OPENAI_ENDPOINT_ENV_VAR} environment variable. (argument prevails)')
parser.add_argument(
Expand Down
19 changes: 9 additions & 10 deletions ros2ai/api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,8 @@ def get_api_key() -> str:
:return: string of OpenAI API Key.
:raises: if OPENAI_API_KEY is not set.
"""
key_name = os.environ.get(constants.ROS_OPENAI_API_KEY_ENV_VAR)
if not key_name:
raise EnvironmentError(
f"'{constants.ROS_OPENAI_API_KEY_ENV_VAR}' environment variable is not set'"
)
else:
return key_name
key_name = os.environ.get(constants.ROS_OPENAI_API_KEY_ENV_VAR, "None")
return key_name

def get_ai_model() -> str:
"""
Expand Down Expand Up @@ -78,6 +73,9 @@ def get_temperature() -> float:
else:
return float(temperature)

def get_role_system(default_role_system: str = None) -> str:
return os.environ.get(constants.ROLE_SYSTEM_ENV_VAR, default_role_system)

class OpenAiConfig:
"""
Collect all OpenAI API related configuration from user setting as key-value pair.
Expand All @@ -90,12 +88,12 @@ def __init__(self, args):

# ai model is optional, command line argument prevails
self.config_pair['api_model'] = get_ai_model()
if args.model != constants.ROS_OPENAI_DEFAULT_MODEL:
if args.model and args.model != self.config_pair['api_model']:
self.config_pair['api_model'] = args.model

# api endpoint is optional, command line argument prevails
self.config_pair['api_endpoint'] = get_endpoint_url()
if args.url != constants.ROS_OPENAI_DEFAULT_MODEL:
if args.url and args.url != self.config_pair['api_endpoint']:
self.config_pair['api_endpoint'] = args.url

# api token is optional, only available via command line argument
Expand Down Expand Up @@ -127,7 +125,8 @@ def display_all(self):
def is_api_key_valid(self):
# Validate api key, model and endpoint to post the API
client = OpenAI(
api_key=self.get_value('api_key')
api_key=self.get_value('api_key'),
base_url=self.get_value('api_endpoint')
)
try:
completion = client.chat.completions.create(
Expand Down
1 change: 1 addition & 0 deletions ros2ai/api/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
ROLE_SYSTEM_EXEC_DEFAULT = \
'You are a Robot Operating System 2 (as known as ROS2) {} distribution command line executor, ' \
'provides executable command string only without any comments or code blocks.'
ROLE_SYSTEM_ENV_VAR = 'OPENAI_ROLE_SYSTEM'

# Temperature controls the consistency for behavior. (range 0.0 - 2.0)
# The lower the temperature is, the more deterministic behavior that OpenAI does.
Expand Down
3 changes: 2 additions & 1 deletion ros2ai/api/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ def __init__(self, args):
super().__init__(args)

self.client_ = OpenAI(
api_key=self.get_value('api_key')
api_key=self.get_value('api_key'),
base_url=self.get_value('api_endpoint')
)
self.completion_ = None
self.stream_ = True
Expand Down
28 changes: 23 additions & 5 deletions ros2ai/verb/exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
# limitations under the License.

from ros2ai.api import add_global_arguments
from ros2ai.api.utils import run_executable, truncate_before_substring
from ros2ai.api.config import get_role_system
from ros2ai.api.constants import ROLE_SYSTEM_EXEC_DEFAULT
from ros2ai.api.openai import ChatCompletionClient, ChatCompletionParameters
from ros2ai.api.utils import get_ros_distro
from ros2ai.api.utils import get_ros_distro, run_executable, truncate_before_substring
from ros2ai.verb import VerbExtension

import ros2ai.api.constants as constants

class ExecVerb(VerbExtension):
"""Execute a ROS 2 CLI based on user query to OpenAI API."""
Expand All @@ -36,6 +36,17 @@ def add_arguments(self, parser, cli_name):
'--debug',
action='store_true',
help='Prints detailed information and behavior of OpenAI (debug use only)')
parser.add_argument(
'--dry-run',
action='store_true',
help='Prints the command instead of executing it.')
parser.add_argument(
'-r',
'--role',
metavar='<role>',
type=str,
default=None,
help='Define the prompt\'s system role.')

def main(self, *, args):
request = ''
Expand All @@ -47,7 +58,10 @@ def main(self, *, args):
distro = get_ros_distro()
if distro is None:
distro = 'rolling' # fallback to rolling in default
system_role = constants.ROLE_SYSTEM_EXEC_DEFAULT.format(distro)
system_role = get_role_system(default_role_system=ROLE_SYSTEM_EXEC_DEFAULT)
if args.role and args.role != system_role:
system_role = args.role
system_role = system_role.format(distro)
user_request = [
{"role": "system", "content": f"{system_role}"},
{"role": "user", "content": f"{request}"}
Expand All @@ -58,6 +72,10 @@ def main(self, *, args):
client.call(completion_params)
if (args.debug is True):
client.print_all()
print(f"System role:\n{system_role}")
command_str = truncate_before_substring(
original = client.get_result(), substring = 'ros2')
run_executable(command = command_str)
if not args.dry_run:
run_executable(command = command_str)
else:
print(f"Command: '{command_str}'")
16 changes: 14 additions & 2 deletions ros2ai/verb/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
# limitations under the License.

from ros2ai.api import add_global_arguments
from ros2ai.api.config import get_role_system
from ros2ai.api.constants import ROLE_SYSTEM_QUERY_DEFAULT
from ros2ai.api.openai import ChatCompletionClient, ChatCompletionParameters
from ros2ai.api.utils import get_ros_distro
from ros2ai.verb import VerbExtension

import ros2ai.api.constants as constants

class QueryVerb(VerbExtension):
"""Query a single completion to OpenAI API."""
Expand All @@ -40,6 +41,13 @@ def add_arguments(self, parser, cli_name):
'--verbose',
action='store_true',
help='Prints detailed response information (only available with nostream option)')
parser.add_argument(
'-r',
'--role',
metavar='<role>',
type=str,
default=None,
help='Define the prompt\'s system role.')

def main(self, *, args):
sentence = ''
Expand All @@ -51,7 +59,10 @@ def main(self, *, args):
distro = get_ros_distro()
if distro is None:
distro = 'rolling' # fallback to rolling in default
system_role = constants.ROLE_SYSTEM_QUERY_DEFAULT.format(distro)
system_role = get_role_system(default_role_system=ROLE_SYSTEM_QUERY_DEFAULT)
if args.role and args.role != system_role:
system_role = args.role
system_role = system_role.format(distro)
user_messages = [
{"role": "system", "content": f"{system_role}"},
{"role": "user", "content": f"{sentence}"}
Expand All @@ -63,5 +74,6 @@ def main(self, *, args):
client.call(completion_params)
if (args.verbose is True and args.nostream is True):
client.print_all()
print(f"System role:\n{system_role}")
else:
client.print_result()
16 changes: 12 additions & 4 deletions ros2ai/verb/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,25 @@ def main(self, *, args):
# try to call OpenAI API with user configured setting
is_valid = openai_config.is_api_key_valid()

if is_valid:
if args.verbose:
print("[SUCCESS] Valid OpenAI API key.")
else:
print("[FAILURE] Invalid OpenAI API key.")
return 1

# try to list the all models via user configured api key
headers = {"Authorization": "Bearer " + openai_config.get_value('api_key')}
can_get_models, model_list = curl_get_request(
"https://api.openai.com/v1/models",
openai_config.get_value('api_endpoint') + "/models",
headers
)

if is_valid and can_get_models:
print("[SUCCESS] Valid OpenAI API key.")
if can_get_models:
if args.verbose:
print("[SUCCESS] Retrieved list of models.")
else:
print("[FAILURE] Invalid OpenAI API key.")
print("[FAILURE] Could not retrieved list of models.")
return 1

if (args.list is True):
Expand Down
2 changes: 1 addition & 1 deletion scripts/verification.sh
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ function verify_ros2ai() {
echo "[${FUNCNAME[0]}]: verifying ros2ai."
# execute all commands in the list
for command in "${command_list[@]}"; do
#echo "----- $command"
echo "----- $command"
eval $command
done
echo "----- all ros2ai commands return successfully!!! -----"
Expand Down

0 comments on commit ad04c11

Please sign in to comment.