Merge pull request #19 from OoriData/18-prompt-overriding
18 prompt overriding
uogbuji authored Aug 7, 2024
2 parents 225afca + c06b77f commit 66c5f1b
Showing 10 changed files with 221 additions and 126 deletions.
14 changes: 14 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,20 @@ Notable changes to Format based on [Keep a Changelog](https://keepachangelog.co
-->

## [0.4.2] - 20240807

### Added

- notes on how to override prompting

### Changed

- processing for function-calling system prompts

### Fixed

- server startup 😬

## [0.4.1] - 20240806

### Added
54 changes: 54 additions & 0 deletions README.md
@@ -389,6 +389,60 @@ async def query_sq_root(tmm):
asyncio.run(query_sq_root(toolio_mm))
```

# Tweaking prompts

Part of the process of getting an LLM to stick to a schema, or to call tools, is giving it a system prompt to that effect. Toolio has built-in prompt language for this purpose. We believe strongly in the design principle of separating natural language (e.g. prompts) from code, so the latter is packaged in the `resource/language.toml` file, using [Word Loom](https://github.com/OoriData/OgbujiPT/wiki/Word-Loom:-A-format-for-managing-language-for-AI-LLMs-(including-prompts)) conventions.

You can of course override the built-in prompting.
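
As a quick way to see what you'd be overriding, the loaded Word Loom table is exposed as `toolio.LANG` (see `pylib/__init__.py`), and `pylib/prompt_helper.py` pulls entries such as `one_tool_prompt_leadin` from it. A minimal sketch for inspecting the packaged prompt language, assuming `LANG` behaves like a plain mapping:

```py
# Inspect the built-in prompt language loaded from resource/language.toml
# (the key names below are the ones referenced in pylib/prompt_helper.py)
from toolio import LANG

for key in LANG:  # list every packaged prompt entry
    print(key)

# Peek at the lead-in used when exactly one tool is registered
print(LANG['one_tool_prompt_leadin'])
```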

## Overriding the tool-calling system prompt from the command line

```sh
echo 'What is the square root of 256?' > /tmp/llmprompt.txt
echo '{"tools": [{"type": "function","function": {"name": "square_root","description": "Get the square root of the given number","parameters": {"type": "object", "properties": {"square": {"type": "number", "description": "Number from which to find the square root"}},"required": ["square"]},"pyfunc": "math|sqrt"}}], "tool_choice": "auto"}' > /tmp/toolspec.json
toolio_request --apibase="http://localhost:8000" --prompt-file=/tmp/llmprompt.txt --tools-file=/tmp/toolspec.json --sysprompt="You are a helpful assistant with access to a tool that you may invoke if needed to answer the user's request. Please use the tool as applicable, even if you think you already know the answer. Give your final answer in Shakespearean English. The tool is:
Tool"
```

## Overriding the tool-calling system prompt from the Python API

To override the system prompt from code, just set it in the initial chat message, using the `system` role.

```py
import asyncio
from math import sqrt
from toolio.llm_helper import model_manager, extract_content

SQUARE_ROOT_METADATA = {'name': 'square_root', 'description': 'Get the square root of the given number',
'parameters': {'type': 'object', 'properties': {
'square': {'type': 'number',
'description': 'Number from which to find the square root'}},
'required': ['square']}}
toolio_mm = model_manager('mlx-community/Hermes-2-Theta-Llama-3-8B-4bit',
tool_reg=[(sqrt, SQUARE_ROOT_METADATA)], trace=True)

# System prompt will be used to direct the LLM's tool-calling
SYSPROMPT = '''You are a tutor from Elizabethan England, with access to a tool that you may invoke if needed to answer the user's request. Please use the tool as applicable, even if you think you already know the answer. Remember to give your final answer in Elizabethan English. The tool is:
Tool
'''

async def query_sq_root(tmm):
msgs = [
{'role': 'system', 'content': SYSPROMPT},
{'role': 'user', 'content': 'What is the square root of 256?'}
]
async for chunk in extract_content(tmm.complete_with_tools(msgs)):
print(chunk, end='')

asyncio.run(query_sq_root(toolio_mm))
```

In which case you can expect a response such as:

```
Good sir or madam, the square root of 256 is indeed 16. Mayhap thou wouldst like to know more of this wondrous number? I am at thy service.
```

# More examples

See the `demo` directory.
2 changes: 1 addition & 1 deletion pylib/__about__.py
@@ -3,4 +3,4 @@
# SPDX-License-Identifier: Apache-2.0
# ogbujipt.about

__version__ = '0.4.1'
__version__ = '0.4.2'
10 changes: 10 additions & 0 deletions pylib/__init__.py
@@ -4,6 +4,7 @@
# # toolio

from pathlib import Path # noqa: E402
from enum import Flag, auto

from ogbujipt import word_loom
from toolio import __about__
@@ -27,3 +28,12 @@ def obj_file_path_parent(obj):
HERE = obj_file_path_parent(lambda: 0)
with open(HERE / Path('resource/language.toml'), mode='rb') as fp:
LANG = word_loom.load(fp)


class model_flag(Flag):
NO_SYSTEM_ROLE = auto() # e.g. Gemma blows up if you use a system message role
USER_ASSISTANT_ALT = auto() # Model requires alternation of message roles user/assistant only
TOOL_RESPONSE = auto() # Model expects responses from tools via OpenAI API style messages


DEFAULT_FLAGS = model_flag(0)
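
These flags get combined per model type via `FLAGS_LOOKUP` in `pylib/llm_helper.py` (its contents aren't shown in this diff). A minimal sketch of how a combination behaves — the Gemma-style entry below is purely illustrative:

```py
# Sketch: combining model_flag members; the Gemma-style example is an assumption
# for illustration, not the actual FLAGS_LOOKUP contents
from toolio import model_flag, DEFAULT_FLAGS

gemma_like = model_flag.NO_SYSTEM_ROLE | model_flag.USER_ASSISTANT_ALT

assert model_flag.NO_SYSTEM_ROLE in gemma_like
assert model_flag.TOOL_RESPONSE not in gemma_like
assert not DEFAULT_FLAGS  # the empty flag set is falsy
```
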
21 changes: 12 additions & 9 deletions pylib/cli/server.py
@@ -30,7 +30,9 @@
from llm_structured_output.util.output import info, warning, debug

from toolio.schema_helper import Model
from toolio.llm_helper import enrich_chat_for_tools, DEFAULT_FLAGS, FLAGS_LOOKUP
from toolio.http_schematics import V1Function
from toolio.prompt_helper import enrich_chat_for_tools, process_tool_sysmsg
from toolio.llm_helper import DEFAULT_FLAGS, FLAGS_LOOKUP
from toolio.http_schematics import V1ChatCompletionsRequest, V1ResponseFormatType
from toolio.responder import (ToolCallStreamingResponder, ToolCallResponder,
ChatCompletionResponder, ChatCompletionStreamingResponder)
@@ -60,8 +62,6 @@ async def lifespan(app: FastAPI):
info(f'Model loaded in {(tdone - tstart)/1000000000.0:.3f}s. Type: {app.state.model.model.model_type}')
app.state.model_flags = FLAGS_LOOKUP.get(app.state.model.model.model_type, DEFAULT_FLAGS)
# Look into exposing control over methods & headers as well
app.add_middleware(CORSMiddleware, allow_origins=app_params['cors_origins'], allow_credentials=True,
allow_methods=["*"], allow_headers=["*"])
yield
# Shutdown code here, if any

@@ -129,7 +129,7 @@ async def post_v1_chat_completions_impl(req_data: V1ChatCompletionsRequest):

model_name = app.state.params['model']
model_type = app.state.model.model.model_type
schema = None
schema = None # Steering for the LLM output (JSON schema)
if functions:
if req_data.sysmsg_leadin: # Caller provided sysmsg leadin via protocol
leadin = req_data.sysmsg_leadin
@@ -138,14 +138,15 @@ async def post_v1_chat_completions_impl(req_data: V1ChatCompletionsRequest):
del messages[0]
else: # Use default leadin
leadin = None
functions = [ (t.dictify() if isinstance(t, V1Function) else t) for t in functions ]
schema, tool_sysmsg = process_tool_sysmsg(functions, leadin=leadin)
if req_data.stream:
responder = ToolCallStreamingResponder(model_name, model_type, functions, app.state.model, leadin)
responder = ToolCallStreamingResponder(model_name, model_type, functions, schema, tool_sysmsg)
else:
responder = ToolCallResponder(model_name, model_type, functions, req_data.sysmsg_leadin)
responder = ToolCallResponder(model_name, model_type, schema, tool_sysmsg)
if not (req_data.tool_options and req_data.tool_options.no_prompt_steering):
enrich_chat_for_tools(messages, responder.tool_prompt, app.state.model_flags)
enrich_chat_for_tools(messages, tool_sysmsg, app.state.model_flags)
# import pprint; pprint.pprint(messages)
schema = responder.schema # Assemble a JSON schema to steer the LLM output
else:
if req_data.stream:
responder = ChatCompletionStreamingResponder(model_name, model_type)
@@ -211,7 +212,9 @@ async def post_v1_chat_completions_impl(req_data: V1ChatCompletionsRequest):
help='Origin to be permitted for CORS https://fastapi.tiangolo.com/tutorial/cors/')
def main(host, port, model, default_schema, default_schema_file, llmtemp, workers, cors_origin):
app_params.update(model=model, default_schema=default_schema, default_schema_fpath=default_schema_file,
llmtemp=llmtemp, cors_origins=list(cors_origin))
llmtemp=llmtemp)
app.add_middleware(CORSMiddleware, allow_origins=list(cors_origin), allow_credentials=True,
allow_methods=["*"], allow_headers=["*"])
workers = workers or None
# logger.info(f'Host has {NUM_CPUS} CPU cores')
uvicorn.run('toolio.cli.server:app', host=host, port=port, reload=False, workers=workers)
10 changes: 6 additions & 4 deletions pylib/client.py
@@ -92,13 +92,15 @@ def __init__(self, base_url=None, default_schema=None, flags=DEFAULT_FLAGS, tool
else:
self.register_tool(toolspec)

async def __call__(self, messages, req='chat/completions', schema=None, toolset=None, tool_choice=TOOL_CHOICE_AUTO,
apikey=None, max_trips=3, trip_timeout=90.0, **kwargs):
async def __call__(self, messages, req='chat/completions', schema=None, toolset=None, sysprompt=None,
tool_choice=TOOL_CHOICE_AUTO, apikey=None, max_trips=3, trip_timeout=90.0, **kwargs):
'''
Invoke the LLM with a completion request
Args:
messages (str) - Prompt in the form of list of messages to send ot the LLM for completion
            messages (list) - Prompt in the form of a list of messages to send to the LLM for completion.
                If you have a system prompt, and you are setting up to call tools, it will be updated with
                the tool spec
trip_timeout (float) - timeout (in seconds) per LLM API request trip; defaults to 90s
@@ -139,7 +141,7 @@ async def __call__(self, messages, req='chat/completions', schema=None, toolset=
elif toolset or self._tool_schema_stanzs:
req_data['tool_choice'] = tool_choice
if req_tools and tool_choice == TOOL_CHOICE_NONE:
warnings.warn('Tools were provided, but tool_choise was set so they\'ll never be used')
warnings.warn('Tools were provided, but tool_choice was set to `none`, so they\'ll never be used')
# if tool_options: req_data['tool_options'] = tool_options
# for t in tools_list:
# self.register_tool(t['function']['name'], t['function'].get('pyfunc'))
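
For context on the new `sysprompt` keyword, a hedged usage sketch — the client class defined in `pylib/client.py` isn't named in this hunk, so `llm_client` below is a hypothetical stand-in for an already-constructed instance:

```py
# Hypothetical sketch; `llm_client` stands in for an instance of the client class
# defined in pylib/client.py (its constructor is not shown in this hunk)
async def ask(llm_client):
    msgs = [{'role': 'user', 'content': 'What is the square root of 256?'}]
    # sysprompt is the keyword added in this change; presumably it supplies the
    # system prompt that gets updated with the tool spec before the request
    return await llm_client(msgs, sysprompt='Answer in Elizabethan English.', trip_timeout=90.0)
```
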
62 changes: 7 additions & 55 deletions pylib/llm_helper.py
@@ -5,7 +5,6 @@

import sys
import json
from enum import Flag, auto
import importlib
import warnings

@@ -16,21 +15,15 @@

# from mlx_lm.models import olmo # Will say:: To run olmo install ai2-olmo: pip install ai2-olmo

from toolio import model_flag, DEFAULT_FLAGS
from toolio.util import check_callable
from toolio.schema_helper import Model
from toolio.http_schematics import V1ChatMessage, V1Function
from toolio.http_schematics import V1Function
from toolio.prompt_helper import set_tool_response, process_tool_sysmsg
from toolio.responder import (ToolCallStreamingResponder, ToolCallResponder,
ChatCompletionResponder, ChatCompletionStreamingResponder)


class model_flag(Flag):
NO_SYSTEM_ROLE = auto() # e.g. Gemma blows up if you use a system message role
USER_ASSISTANT_ALT = auto() # Model requires alternation of message roles user/assistant only
TOOL_RESPONSE = auto() # Model expects responses from tools via OpenAI API style messages


DEFAULT_FLAGS = model_flag(0)

TOOL_CHOICE_AUTO = 'auto'
TOOL_CHOICE_NONE = 'none'

@@ -281,12 +274,12 @@ async def _execute_tool_calls(self, response, req_tools):
return tool_responses

async def _completion_trip(self, messages, stream, req_tool_spec, max_tokens=128, temperature=0.1):
schema = None
req_tool_spec = [ (t.dictify() if isinstance(t, V1Function) else t) for t in req_tool_spec ]
schema, tool_sysmsg = process_tool_sysmsg(req_tool_spec, leadin=self.sysmsg_leadin)
if stream:
responder = ToolCallStreamingResponder(self.model, self.model_path, req_tool_spec, self.sysmsg_leadin)
responder = ToolCallStreamingResponder(self.model, self.model_path, req_tool_spec, schema, tool_sysmsg)
else:
responder = ToolCallResponder(self.model_path, self.model_type, req_tool_spec, self.sysmsg_leadin)
schema = responder.schema
responder = ToolCallResponder(self.model_path, self.model_type, schema, tool_sysmsg)
# Turn off prompt caching until we figure out https://github.com/OoriData/Toolio/issues/12
cache_prompt=False
async for resp in self._do_completion(messages, schema, responder, cache_prompt=cache_prompt,
@@ -325,44 +318,3 @@ async def extract_content(resp_stream):
content = chunk['choices'][0]['message'].get('content')
if content is not None:
yield content


def enrich_chat_for_tools(msgs, tool_prompt, model_flags):
'''
msgs - chat messages to augment
model_flags - flags indicating the expectations of the hosted LLM
'''
# Add prompting (system prompt, if permitted) instructing the LLM to use tools
if model_flag.NO_SYSTEM_ROLE in model_flags: # LLM supports system messages
msgs.insert(0, V1ChatMessage(role='system', content=tool_prompt))
elif model_flag.USER_ASSISTANT_ALT in model_flags: # LLM insists that user and assistant messages must alternate
msgs[0].content = msgs[0].content=tool_prompt + '\n\n' + msgs[0].content
else:
msgs.insert(0, V1ChatMessage(role='user', content=tool_prompt))


def set_tool_response(msgs, tool_call_id, tool_name, tool_result, model_flags=DEFAULT_FLAGS):
'''
msgs - chat messages to augment
tool_response - response generatded by selected tool
model_flags - flags indicating the expectations of the hosted LLM
'''
# XXX: model_flags = None ⇒ assistant-style tool response. Is this the default we want?
if model_flag.TOOL_RESPONSE in model_flags:
msgs.append({
'tool_call_id': tool_call_id,
'role': 'tool',
'name': tool_name,
'content': tool_result,
})
else:
# FIXME: Separate out natural language
tool_response_text = f'Result of the call to {tool_name}: {tool_result}'
if model_flag.USER_ASSISTANT_ALT in model_flags:
# If there is already an assistant msg from tool-calling, merge it
if msgs[-1]['role'] == 'assistant':
msgs[-1]['content'] += '\n\n' + tool_response_text
else:
msgs.append({'role': 'assistant', 'content': tool_response_text})
else:
msgs.append({'role': 'assistant', 'content': tool_response_text})
107 changes: 107 additions & 0 deletions pylib/prompt_helper.py
@@ -0,0 +1,107 @@
# SPDX-FileCopyrightText: 2024-present Oori Data <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
# toolio.prompt_helper

import json

from toolio import LANG, model_flag, DEFAULT_FLAGS
from toolio.http_schematics import V1ChatMessage


def enrich_chat_for_tools(msgs, tool_prompt, model_flags):
'''
    msgs - chat messages to augment
    tool_prompt - tool-calling system prompt text to add to the chat
    model_flags - flags indicating the expectations of the hosted LLM
'''
# Add prompting (system prompt, if permitted) instructing the LLM to use tools
    if model_flag.NO_SYSTEM_ROLE not in model_flags:  # LLM supports system messages
msgs.insert(0, V1ChatMessage(role='system', content=tool_prompt))
elif model_flag.USER_ASSISTANT_ALT in model_flags: # LLM insists that user and assistant messages must alternate
        msgs[0].content = tool_prompt + '\n\n' + msgs[0].content
else:
msgs.insert(0, V1ChatMessage(role='user', content=tool_prompt))


def set_tool_response(msgs, tool_call_id, tool_name, tool_result, model_flags=DEFAULT_FLAGS):
'''
    msgs - chat messages to augment
    tool_call_id - ID of the tool call being responded to
    tool_name - name of the tool that was invoked
    tool_result - response generated by the selected tool
    model_flags - flags indicating the expectations of the hosted LLM
'''
# XXX: model_flags = None ⇒ assistant-style tool response. Is this the default we want?
if model_flag.TOOL_RESPONSE in model_flags:
msgs.append({
'tool_call_id': tool_call_id,
'role': 'tool',
'name': tool_name,
'content': tool_result,
})
else:
# FIXME: Separate out natural language
tool_response_text = f'Result of the call to {tool_name}: {tool_result}'
if model_flag.USER_ASSISTANT_ALT in model_flags:
# If there is already an assistant msg from tool-calling, merge it
if msgs[-1]['role'] == 'assistant':
msgs[-1]['content'] += '\n\n' + tool_response_text
else:
msgs.append({'role': 'assistant', 'content': tool_response_text})
else:
msgs.append({'role': 'assistant', 'content': tool_response_text})
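
To make the two branches concrete, a minimal illustration of the message shapes `set_tool_response` appends — the call ID and result values are made up for the example:

```py
# Illustrative only: message shapes produced by set_tool_response
from toolio import model_flag, DEFAULT_FLAGS
from toolio.prompt_helper import set_tool_response

# Model that expects OpenAI-style tool-response messages
msgs = [{'role': 'user', 'content': 'What is the square root of 256?'}]
set_tool_response(msgs, 'call_001', 'square_root', '16.0', model_flags=model_flag.TOOL_RESPONSE)
# msgs[-1] == {'tool_call_id': 'call_001', 'role': 'tool', 'name': 'square_root', 'content': '16.0'}

# Default flags fall back to an assistant-voiced summary
msgs = [{'role': 'user', 'content': 'What is the square root of 256?'}]
set_tool_response(msgs, 'call_001', 'square_root', '16.0', model_flags=DEFAULT_FLAGS)
# msgs[-1] == {'role': 'assistant', 'content': 'Result of the call to square_root: 16.0'}
```
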


def single_tool_prompt(tool, tool_schema, leadin=None):
leadin = leadin or LANG['one_tool_prompt_leadin']
return f'''
{leadin} {tool["name"]}: {tool["description"]}
{LANG["one_tool_prompt_schemalabel"]}: {json.dumps(tool_schema)}
{LANG["one_tool_prompt_tail"]}
'''


def multiple_tool_prompt(tools, tool_schemas, separator='\n', leadin=None):
leadin = leadin or LANG['multi_tool_prompt_leadin']
toollist = separator.join(
[f'\nTool {tool["name"]}: {tool["description"]}\nInvocation schema: {json.dumps(tool_schema)}\n'
for tool, tool_schema in zip(tools, tool_schemas) ])
return f'''
{leadin}
{toollist}
{LANG["multi_tool_prompt_tail"]}
'''


def select_tool_prompt(self, tools, tool_schemas, separator='\n', leadin=None):
leadin = leadin or LANG['multi_tool_prompt_leadin']
toollist = separator.join(
[f'\n{LANG["select_tool_prompt_toollabel"]} {tool["name"]}: {tool["description"]}\n'
f'{LANG["select_tool_prompt_schemalabel"]}: {json.dumps(tool_schema)}\n'
for tool, tool_schema in zip(tools, tool_schemas) ])
return f'''
{leadin}
{toollist}
{LANG["select_tool_prompt_tail"]}
'''


def process_tool_sysmsg(tools, leadin=None):
# print(f'{tools=} | {leadin=}')
function_schemas = [
{
'type': 'object',
'properties': {
'name': {'type': 'const', 'const': fn['name']},
'arguments': fn['parameters'],
},
'required': ['name', 'arguments'],
}
for fn in tools
]
if len(function_schemas) == 1:
schema = function_schemas[0]
tool_sysmsg = single_tool_prompt(tools[0], function_schemas[0], leadin=leadin)
else:
schema = {'type': 'array', 'items': {'anyOf': function_schemas}}
tool_sysmsg = multiple_tool_prompt(tools, function_schemas, leadin=leadin)
# print(f'{tool_sysmsg=}')
return schema, tool_sysmsg
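
Tying this back to the README's `square_root` example: for a single tool, `process_tool_sysmsg` returns a schema that constrains the LLM to exactly that function call, plus the generated system prompt. A sketch (the prompt text itself depends on the packaged Word Loom entries):

```py
from toolio.prompt_helper import process_tool_sysmsg

SQUARE_ROOT_METADATA = {'name': 'square_root', 'description': 'Get the square root of the given number',
                        'parameters': {'type': 'object', 'properties': {
                            'square': {'type': 'number',
                                       'description': 'Number from which to find the square root'}},
                            'required': ['square']}}

schema, tool_sysmsg = process_tool_sysmsg([SQUARE_ROOT_METADATA])
# schema == {'type': 'object',
#            'properties': {'name': {'type': 'const', 'const': 'square_root'},
#                           'arguments': SQUARE_ROOT_METADATA['parameters']},
#            'required': ['name', 'arguments']}
# tool_sysmsg is assembled from the one_tool_prompt_* entries in resource/language.toml
```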