Skip to content

Commit

Permalink
enable setting role
Browse files Browse the repository at this point in the history
  • Loading branch information
artivis committed Aug 1, 2024
1 parent db87a7a commit ab90e2c
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 6 deletions.
3 changes: 3 additions & 0 deletions ros2ai/api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ def get_temperature() -> float:
else:
return float(temperature)

def get_role_system(default_role_system: str = None) -> str:
return os.environ.get(constants.ROLE_SYSTEM_ENV_VAR, default_role_system)

class OpenAiConfig:
"""
Collect all OpenAI API related configuration from user setting as key-value pair.
Expand Down
1 change: 1 addition & 0 deletions ros2ai/api/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
ROLE_SYSTEM_EXEC_DEFAULT = \
'You are a Robot Operating System 2 (as known as ROS2) {} distribution command line executor, ' \
'provides executable command string only without any comments or code blocks.'
ROLE_SYSTEM_ENV_VAR = 'OPENAI_ROLE_SYSTEM'

# Temperature controls the consistency for behavior. (range 0.0 - 2.0)
# The lower the temperature is, the more deterministic behavior that OpenAI does.
Expand Down
19 changes: 15 additions & 4 deletions ros2ai/verb/exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
# limitations under the License.

from ros2ai.api import add_global_arguments
from ros2ai.api.utils import run_executable, truncate_before_substring
from ros2ai.api.config import get_role_system
from ros2ai.api.constants import ROLE_SYSTEM_EXEC_DEFAULT
from ros2ai.api.openai import ChatCompletionClient, ChatCompletionParameters
from ros2ai.api.utils import get_ros_distro
from ros2ai.api.utils import get_ros_distro, run_executable, truncate_before_substring
from ros2ai.verb import VerbExtension

import ros2ai.api.constants as constants

class ExecVerb(VerbExtension):
"""Execute a ROS 2 CLI based on user query to OpenAI API."""
Expand All @@ -40,6 +40,13 @@ def add_arguments(self, parser, cli_name):
'--dry-run',
action='store_true',
help='Prints the command instead of executing it.')
parser.add_argument(
'-r',
'--role',
metavar='<role>',
type=str,
default=None,
help='Define the prompt\'s system role.')

def main(self, *, args):
request = ''
Expand All @@ -51,7 +58,10 @@ def main(self, *, args):
distro = get_ros_distro()
if distro is None:
distro = 'rolling' # fallback to rolling in default
system_role = constants.ROLE_SYSTEM_EXEC_DEFAULT.format(distro)
system_role = get_role_system(default_role_system=ROLE_SYSTEM_EXEC_DEFAULT)
if args.role and args.role != system_role:
system_role = args.role
system_role = system_role.format(distro)
user_request = [
{"role": "system", "content": f"{system_role}"},
{"role": "user", "content": f"{request}"}
Expand All @@ -62,6 +72,7 @@ def main(self, *, args):
client.call(completion_params)
if (args.debug is True):
client.print_all()
print(f"System role:\n{system_role}")
command_str = truncate_before_substring(
original = client.get_result(), substring = 'ros2')
if not args.dry_run:
Expand Down
16 changes: 14 additions & 2 deletions ros2ai/verb/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
# limitations under the License.

from ros2ai.api import add_global_arguments
from ros2ai.api.config import get_role_system
from ros2ai.api.constants import ROLE_SYSTEM_QUERY_DEFAULT
from ros2ai.api.openai import ChatCompletionClient, ChatCompletionParameters
from ros2ai.api.utils import get_ros_distro
from ros2ai.verb import VerbExtension

import ros2ai.api.constants as constants

class QueryVerb(VerbExtension):
"""Query a single completion to OpenAI API."""
Expand All @@ -40,6 +41,13 @@ def add_arguments(self, parser, cli_name):
'--verbose',
action='store_true',
help='Prints detailed response information (only available with nostream option)')
parser.add_argument(
'-r',
'--role',
metavar='<role>',
type=str,
default=None,
help='Define the prompt\'s system role.')

def main(self, *, args):
sentence = ''
Expand All @@ -51,7 +59,10 @@ def main(self, *, args):
distro = get_ros_distro()
if distro is None:
distro = 'rolling' # fallback to rolling in default
system_role = constants.ROLE_SYSTEM_QUERY_DEFAULT.format(distro)
system_role = get_role_system(default_role_system=ROLE_SYSTEM_QUERY_DEFAULT)
if args.role and args.role != system_role:
system_role = args.role
system_role = system_role.format(distro)
user_messages = [
{"role": "system", "content": f"{system_role}"},
{"role": "user", "content": f"{sentence}"}
Expand All @@ -63,5 +74,6 @@ def main(self, *, args):
client.call(completion_params)
if (args.verbose is True and args.nostream is True):
client.print_all()
print(f"System role:\n{system_role}")
else:
client.print_result()

0 comments on commit ab90e2c

Please sign in to comment.