Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable the use of self-hosted LLMs #36

Merged
merged 12 commits into from
Aug 8, 2024
1 change: 1 addition & 0 deletions package.xml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

<author email="[email protected]">Tomoya Fujita</author>

<exec_depend>curl</exec_depend>
<exec_depend>python3-openai-pip</exec_depend>
<exec_depend>ros2cli</exec_depend>

Expand Down
4 changes: 2 additions & 2 deletions ros2ai/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
def add_global_arguments(parser):
# add global arguments
parser.add_argument(
'-m', '--model', metavar='<model>', type=str, default=constants.ROS_OPENAI_DEFAULT_MODEL,
'-m', '--model', metavar='<model>', type=str, default=None,
help=f'Set OpenAI API model (default %(default)s) or '
f'use {constants.ROS_OPENAI_MODEL_NAME_ENV_VAR} environment variable. (argument prevails)')
parser.add_argument(
'-u', '--url', metavar='<url>', type=str, default=constants.ROS_OPENAI_DEFAULT_ENDPOINT,
'-u', '--url', metavar='<url>', type=str, default=None,
help='Set OpenAI API endpoint URL (default %(default)s) or '
f'use {constants.ROS_OPENAI_ENDPOINT_ENV_VAR} environment variable. (argument prevails)')
parser.add_argument(
Expand Down
19 changes: 9 additions & 10 deletions ros2ai/api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,8 @@ def get_api_key() -> str:
:return: string of OpenAI API Key.
:raises: if OPENAI_API_KEY is not set.
"""
key_name = os.environ.get(constants.ROS_OPENAI_API_KEY_ENV_VAR)
if not key_name:
raise EnvironmentError(
f"'{constants.ROS_OPENAI_API_KEY_ENV_VAR}' environment variable is not set'"
)
else:
return key_name
key_name = os.environ.get(constants.ROS_OPENAI_API_KEY_ENV_VAR, "None")
return key_name
fujitatomoya marked this conversation as resolved.
Show resolved Hide resolved

def get_ai_model() -> str:
"""
Expand Down Expand Up @@ -78,6 +73,9 @@ def get_temperature() -> float:
else:
return float(temperature)

def get_role_system(default_role_system: str = None) -> str:
return os.environ.get(constants.ROLE_SYSTEM_ENV_VAR, default_role_system)

class OpenAiConfig:
"""
Collect all OpenAI API related configuration from user setting as key-value pair.
Expand All @@ -90,12 +88,12 @@ def __init__(self, args):

# ai model is optional, command line argument prevails
self.config_pair['api_model'] = get_ai_model()
if args.model != constants.ROS_OPENAI_DEFAULT_MODEL:
if args.model and args.model != self.config_pair['api_model']:
self.config_pair['api_model'] = args.model

# api endpoint is optional, command line argument prevails
self.config_pair['api_endpoint'] = get_endpoint_url()
if args.url != constants.ROS_OPENAI_DEFAULT_MODEL:
if args.url and args.url != self.config_pair['api_endpoint']:
self.config_pair['api_endpoint'] = args.url

# api token is optional, only available via command line argument
Expand Down Expand Up @@ -127,7 +125,8 @@ def display_all(self):
def is_api_key_valid(self):
# Validate api key, model and endpoint to post the API
client = OpenAI(
api_key=self.get_value('api_key')
api_key=self.get_value('api_key'),
base_url=self.get_value('api_endpoint')
fujitatomoya marked this conversation as resolved.
Show resolved Hide resolved
)
try:
completion = client.chat.completions.create(
Expand Down
1 change: 1 addition & 0 deletions ros2ai/api/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
ROLE_SYSTEM_EXEC_DEFAULT = \
'You are a Robot Operating System 2 (as known as ROS2) {} distribution command line executor, ' \
'provides executable command string only without any comments or code blocks.'
ROLE_SYSTEM_ENV_VAR = 'OPENAI_ROLE_SYSTEM'

# Temperature controls the consistency for behavior. (range 0.0 - 2.0)
# The lower the temperature is, the more deterministic behavior that OpenAI does.
Expand Down
3 changes: 2 additions & 1 deletion ros2ai/api/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ def __init__(self, args):
super().__init__(args)

self.client_ = OpenAI(
api_key=self.get_value('api_key')
api_key=self.get_value('api_key'),
base_url=self.get_value('api_endpoint')
)
self.completion_ = None
self.stream_ = True
Expand Down
28 changes: 23 additions & 5 deletions ros2ai/verb/exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
# limitations under the License.

from ros2ai.api import add_global_arguments
from ros2ai.api.utils import run_executable, truncate_before_substring
from ros2ai.api.config import get_role_system
from ros2ai.api.constants import ROLE_SYSTEM_EXEC_DEFAULT
from ros2ai.api.openai import ChatCompletionClient, ChatCompletionParameters
from ros2ai.api.utils import get_ros_distro
from ros2ai.api.utils import get_ros_distro, run_executable, truncate_before_substring
from ros2ai.verb import VerbExtension

import ros2ai.api.constants as constants

class ExecVerb(VerbExtension):
"""Execute a ROS 2 CLI based on user query to OpenAI API."""
Expand All @@ -36,6 +36,17 @@ def add_arguments(self, parser, cli_name):
'--debug',
action='store_true',
help='Prints detailed information and behavior of OpenAI (debug use only)')
parser.add_argument(
'--dry-run',
action='store_true',
help='Prints the command instead of executing it.')
parser.add_argument(
'-r',
'--role',
metavar='<role>',
type=str,
default=None,
help='Define the prompt\'s system role.')

def main(self, *, args):
request = ''
Expand All @@ -47,7 +58,10 @@ def main(self, *, args):
distro = get_ros_distro()
if distro is None:
distro = 'rolling' # fallback to rolling in default
system_role = constants.ROLE_SYSTEM_EXEC_DEFAULT.format(distro)
system_role = get_role_system(default_role_system=ROLE_SYSTEM_EXEC_DEFAULT)
if args.role and args.role != system_role:
system_role = args.role
system_role = system_role.format(distro)
user_request = [
{"role": "system", "content": f"{system_role}"},
{"role": "user", "content": f"{request}"}
Expand All @@ -58,6 +72,10 @@ def main(self, *, args):
client.call(completion_params)
if (args.debug is True):
client.print_all()
print(f"System role:\n{system_role}")
command_str = truncate_before_substring(
original = client.get_result(), substring = 'ros2')
run_executable(command = command_str)
if not args.dry_run:
run_executable(command = command_str)
else:
print(f"Command: '{command_str}'")
16 changes: 14 additions & 2 deletions ros2ai/verb/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
# limitations under the License.

from ros2ai.api import add_global_arguments
from ros2ai.api.config import get_role_system
from ros2ai.api.constants import ROLE_SYSTEM_QUERY_DEFAULT
from ros2ai.api.openai import ChatCompletionClient, ChatCompletionParameters
from ros2ai.api.utils import get_ros_distro
from ros2ai.verb import VerbExtension

import ros2ai.api.constants as constants

class QueryVerb(VerbExtension):
"""Query a single completion to OpenAI API."""
Expand All @@ -40,6 +41,13 @@ def add_arguments(self, parser, cli_name):
'--verbose',
action='store_true',
help='Prints detailed response information (only available with nostream option)')
parser.add_argument(
'-r',
'--role',
metavar='<role>',
type=str,
default=None,
help='Define the prompt\'s system role.')

def main(self, *, args):
sentence = ''
Expand All @@ -51,7 +59,10 @@ def main(self, *, args):
distro = get_ros_distro()
if distro is None:
distro = 'rolling' # fallback to rolling in default
system_role = constants.ROLE_SYSTEM_QUERY_DEFAULT.format(distro)
system_role = get_role_system(default_role_system=ROLE_SYSTEM_QUERY_DEFAULT)
if args.role and args.role != system_role:
system_role = args.role
system_role = system_role.format(distro)
user_messages = [
{"role": "system", "content": f"{system_role}"},
{"role": "user", "content": f"{sentence}"}
Expand All @@ -63,5 +74,6 @@ def main(self, *, args):
client.call(completion_params)
if (args.verbose is True and args.nostream is True):
client.print_all()
print(f"System role:\n{system_role}")
else:
client.print_result()
16 changes: 12 additions & 4 deletions ros2ai/verb/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,25 @@ def main(self, *, args):
# try to call OpenAI API with user configured setting
is_valid = openai_config.is_api_key_valid()

if is_valid:
if args.verbose:
print("[SUCCESS] Valid OpenAI API key.")
else:
print("[FAILURE] Invalid OpenAI API key.")
return 1

# try to list the all models via user configured api key
headers = {"Authorization": "Bearer " + openai_config.get_value('api_key')}
can_get_models, model_list = curl_get_request(
"https://api.openai.com/v1/models",
openai_config.get_value('api_endpoint') + "/models",
headers
)

if is_valid and can_get_models:
print("[SUCCESS] Valid OpenAI API key.")
if can_get_models:
if args.verbose:
print("[SUCCESS] Retrieved list of models.")
else:
print("[FAILURE] Invalid OpenAI API key.")
print("[FAILURE] Could not retrieved list of models.")
return 1

if (args.list is True):
Expand Down
2 changes: 1 addition & 1 deletion scripts/verification.sh
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ function verify_ros2ai() {
echo "[${FUNCNAME[0]}]: verifying ros2ai."
# execute all commands in the list
for command in "${command_list[@]}"; do
#echo "----- $command"
echo "----- $command"
eval $command
done
echo "----- all ros2ai commands return successfully!!! -----"
Expand Down
Loading