Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Toolio integration #1

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
15 changes: 7 additions & 8 deletions src/generative_redfoot/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
import json
import re

from .object_pdl_model import PDLModel, PDLProgram, ParseDispatcher, PDFRead
from .utils import truncate_long_text
from .object_pdl_model import PDLModel, PDLProgram, ParseDispatcher, PDFRead, PDLRepeat, PDLText, PDLRead
from .extensions.wordloom import WorldLoomRead
from .extensions.toolio import ToolioCompletion
from pyarrow.lib import Mapping
from transformers import PreTrainedTokenizer
from typing import Tuple, Dict, List

def truncate_long_text(text, max_length=200):
    """Return *text* unchanged if it fits in *max_length* characters, else clip it and append '..'."""
    if len(text) <= max_length:
        return text
    return text[:max_length] + '..'

@click.command()
@click.option('-t', '--temperature', default=1, type=float)
@click.option('-rp', '--repetition-penalty', default=0, type=float,
Expand All @@ -26,7 +26,7 @@ def main(temperature, repetition_penalty, top_k, max_tokens, min_p, verbose, var
from mlx_lm.utils import load, generate
from mlx_lm.sample_utils import make_sampler, make_logits_processors
import mlx.nn as nn
from mlx_lm.models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache
from mlx_lm.models.cache import load_prompt_cache, make_prompt_cache

start_marker = '<s>'
end_marker = '</s>'
Expand Down Expand Up @@ -159,9 +159,8 @@ def dispatch_check(item: Mapping, program: PDLProgram):
return MLXAPSModel(item, program)

dispatcher = ParseDispatcher()
dispatcher.DISPATCH_RESOLUTION_ORDER[-1] = MLXModelEvaluation
dispatcher.DISPATCH_RESOLUTION_ORDER.append(MLXAPSModel)
dispatcher.DISPATCH_RESOLUTION_ORDER.append(PDFRead)
dispatcher.DISPATCH_RESOLUTION_ORDER = [PDLRead, WorldLoomRead, ToolioCompletion, PDLRepeat, PDLText,
MLXModelEvaluation, MLXAPSModel, PDFRead]
with open(pdl_file, "r") as file:
program = PDLProgram(yaml.safe_load(file), dispatcher=dispatcher, initial_context=dict(variables))
program.execute(verbose=verbose)
Expand Down
Empty file.
79 changes: 79 additions & 0 deletions src/generative_redfoot/extensions/toolio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from ..object_pdl_model import PDLObject, PDLStructuredBlock, PDLProgram
from ..utils import truncate_messages
import json
from typing import Mapping, Dict
import asyncio

"""
from toolio.llm_helper import local_model_runner
toolio_mm = local_model_runner('..')

async def [..]([..]):
prompt = [..]
done = False
msgs = [{'role': 'user', 'content': prompt}]
while not done:
rt = await tmm.complete(msgs, json_schema=[..], max_tokens=512)
obj = json.loads(rt)
# print('DEBUG return object:', obj)


"""

class ToolioCompletion(PDLObject, PDLStructuredBlock):
    """
    PDL block for structured LLM response generation via Toolio + MLX.

    Dispatched on the presence of a "structured_output" key in a PDL block
    (see dispatch_check).  The JSON schema constraining the completion is read
    from ``schema_file`` at execution time, and the completion itself runs
    asynchronously through ``toolio.llm_helper.local_model_runner``.
    """

    def __init__(self, pdl_block: Mapping, program: PDLProgram):
        self.program = program
        self.model = pdl_block["structured_output"]
        self.insert_schema = pdl_block["insert_schema"]
        self.schema_file = pdl_block["schema_file"]
        self.max_tokens = pdl_block.get("max_tokens", 512)
        self.temperature = pdl_block.get("temperature", .1)
        # An optional nested "input" block is parsed eagerly here and executed
        # ahead of the completion in execute().
        self.input = self.program.dispatcher.handle(pdl_block["input"], self.program) if "input" in pdl_block else None
        self._get_common_attributes(pdl_block)

    def __repr__(self):
        return f"ToolioCompletion(according to '{self.schema_file}' and up to {self.max_tokens:,} tokens)"

    def execute(self, context: Dict, verbose: bool = False):
        """Execute any nested input block, then run the structured completion.

        :param context: Mutable evaluation context; chat messages live under the "_" key.
        :param verbose: When True, print a summary of the completion request.
        """
        source_phrase = ""
        if self.input:
            source_phrase = f" from {self.input}"
        if verbose:
            # BUG FIX: truncate_messages expects the message *list*, not the whole
            # context mapping (iterating the mapping yields bare keys, which have
            # no .items() for truncate_messages to unpack).
            print(f"Running Toolio completion according to '{self.schema_file}', using "
                  f"{truncate_messages(context.get('_', []))}"
                  f"{source_phrase} (max of {self.max_tokens:,} tokens)")
        if self.input:
            self.input.execute(context, verbose=verbose)
        asyncio.run(self.toolio_completion(context, verbose))

    async def toolio_completion(self, context: Dict, verbose: bool = False) -> None:
        """Run the schema-constrained completion and contribute the result to *context*.

        The result is delivered via _handle_execution_contribution rather than
        returned; the previous ``-> str`` annotation was inaccurate (the method
        always returned None).
        """
        # Imported lazily so toolio is only required when such a block executes.
        from toolio.llm_helper import local_model_runner
        toolio_mm = local_model_runner(self.model)
        msgs = context["_"]

        with open(self.schema_file, mode='r') as schema_file:
            schema = schema_file.read()
        result = await toolio_mm.complete(msgs,
                                          json_schema=schema,
                                          max_tokens=self.max_tokens,
                                          temperature=self.temperature,
                                          insert_schema=self.insert_schema)
        self._handle_execution_contribution(result, context)

    @staticmethod
    def dispatch_check(item: Mapping, program: PDLProgram):
        # Dispatch hook: claim any block carrying a "structured_output" key.
        if "structured_output" in item:
            return ToolioCompletion(item, program)
63 changes: 63 additions & 0 deletions src/generative_redfoot/extensions/wordloom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import yaml
from ..object_pdl_model import PDLObject, PDLStructuredBlock, PDL3, PDLProgram
from typing import Mapping, Dict, Any

class WorldLoomRead(PDLObject, PDLStructuredBlock):
    """
    PDL block for reading sections for a prompt from a Word Loom (TOML / YAML) file using ogbujipt.word_loom

    Example:
        >>> p = PDLProgram(yaml.safe_load(PDL3))
        >>> p.cache
        'prompt_cache.safetensors'
        >>> p.text[0]
        Wordloom('question answer' from file.loom [outputs to context as user])

    """

    def __init__(self, pdl_block: Mapping, program: PDLObject):
        self.program = program
        # Path of the loom file and the whitespace-separated entry names to expand.
        self.loom_file = pdl_block["read_from_wordloom"]
        self.language_items = pdl_block["items"]
        self._get_common_attributes(pdl_block)

    def __repr__(self):
        return f"Wordloom('{self.language_items}' from {self.loom_file} [{self.descriptive_text()}])"

    def execute(self, context: Dict, verbose: bool = False):
        """Load the loom file, expand the requested entries, and contribute the joined text to *context*."""
        # Imported lazily so ogbujipt is only required when such a block executes.
        from ogbujipt import word_loom
        with open(self.loom_file, mode='rb') as fp:
            loom = word_loom.load(fp)
        # .split() rather than .split(' '): repeated whitespace between item
        # names would otherwise produce empty names and raise KeyError below.
        items = self.language_items.split()
        if verbose:
            print(f"Expanding {items} from {self.loom_file}")
        content = '\n'.join([WorldLoomRead.get_loom_entry(loom[name], context) for name in items])
        self._handle_execution_contribution(content, context)

    @staticmethod
    def get_loom_entry(loom_entry: Any, context: Mapping) -> str:
        """
        Processes a language_item by formatting it with context-specific marker substitutions
        if markers are present in the language_item. If no markers are available, the original
        language_item is returned as is.

        :param loom_entry: A wordloom `language_item` object that contains potential markers to be
            substituted and formatted with values from the context.
        :param context: A dictionary-like object (`Mapping`) that holds marker-to-value
            mappings to be used for substitutions in the given loom_entry.
        :return: Returns a formatted string if markers are found and substitutions can be
            applied; otherwise, returns the unprocessed `language_item` as is.
        """
        if loom_entry.markers:
            # Dict comprehension replaces the manual build-up loop.
            marker_kwargs = {marker: context[marker] for marker in loom_entry.markers}
            return loom_entry.format(**marker_kwargs)
        else:
            return loom_entry

    @staticmethod
    def dispatch_check(item: Mapping, program: PDLObject):
        # Dispatch hook: claim any block carrying a "read_from_wordloom" key.
        if "read_from_wordloom" in item:
            return WorldLoomRead(item, program)
118 changes: 35 additions & 83 deletions src/generative_redfoot/object_pdl_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def _handle_execution_contribution(self, content: Union[List, str], context: Dic
if content:
msg = {"role": self.role, "content": content}
if "result" in self.contribute:
pprint(content)
print(content)
if "context" in self.contribute:
context.setdefault('_', []).append(msg)

Expand Down Expand Up @@ -188,7 +188,7 @@ class PDLText(TextCollator, PDLStructuredBlock):

>>> p.execute(verbose=True)
Executing: program
bar
'bar'
>>> p.evaluation_environment
{'_': [{'role': 'system', 'content': 'foo'}, {'role': 'user', 'content': 'bar'}]}

Expand Down Expand Up @@ -217,25 +217,38 @@ def __len__(self):
def execute(self, context: Dict, verbose: bool = False):
""""""
content = ''
for item in self.content:
if isinstance(item, str):
content += item
else:
result = item.execute(context, verbose=verbose)
if result is not None:
content += result
merged_context = []
previous_item = None
for idx, item in enumerate(context.get("_", [])):
if idx > 0 and item["role"] == previous_item["role"]:
previous_item["content"] += item["content"]
else:
merged_context.append(item)
previous_item = item
context["_"] = merged_context

if isinstance(self.content, str):
self.merge_content(context, self.content)
if "result" in self.contribute:
pprint(self.content)
else:
for item in self.content:
if isinstance(item, str):
self.merge_content(context, item)
if "result" in self.contribute:
pprint(content)
else:
result = item.execute(context, verbose=verbose)
if result is not None:
self.merge_content(context, result)
merged_context = []
previous_item = None
for idx, item in enumerate(context.get("_", [])):
if idx > 0 and item["role"] == previous_item["role"]:
previous_item["content"] += item["content"]
else:
merged_context.append(item)
previous_item = item
context["_"] = merged_context
self._handle_execution_contribution(content, context)

def merge_content(self, context, item):
messages = context.setdefault('_', [])
if messages and [m for m in messages if m['role'] == "user"]:
messages[-1]["content"] += item
else:
messages.append({"role": self.role, "content": item})

@staticmethod
def dispatch_check(item: Mapping, program: PDLObject):
if "text" in item:
Expand Down Expand Up @@ -279,66 +292,6 @@ def dispatch_check(item: Mapping, program: PDLObject):
if "read" in item:
return PDLRead(item, program)

class WorldLoomRead(PDLObject, PDLStructuredBlock):
    """
    PDL block for reading sections for a prompt from a Worldloom (TOML / YAML) file using ogbujipt.word_loom

    Example:
        >>> p = PDLProgram(yaml.safe_load(PDL3))
        >>> p.cache
        'prompt_cache.safetensors'
        >>> p.text[0]
        Wordloom('question answer' from file.loom [outputs to context as user])

    """

    def __init__(self, pdl_block: Mapping, program: PDLObject):
        self.program = program
        self.loom_file = pdl_block["read_from_wordloom"]
        self.language_items = pdl_block["items"]
        self._get_common_attributes(pdl_block)

    def __repr__(self):
        return f"Wordloom('{self.language_items}' from {self.loom_file} [{self.descriptive_text()}])"

    def execute(self, context: Dict, verbose: bool = False):
        """Load the loom file, expand the requested entries, and contribute the joined text."""
        # word_loom is imported lazily so the dependency is only needed when a
        # wordloom block actually executes.
        from ogbujipt import word_loom
        with open(self.loom_file, mode='rb') as fp:
            loom = word_loom.load(fp)
        items = self.language_items.split(' ')
        if verbose:
            print(f"Expanding {items} from {self.loom_file}")
        content = '\n'.join([WorldLoomRead.get_loom_entry(loom[name], context) for name in items])
        self._handle_execution_contribution(content, context)

    @staticmethod
    def get_loom_entry(loom_entry: "word_loom.language_item", context: Mapping) -> "Union[str, word_loom.language_item]":
        """
        Processes a language_item by formatting it with context-specific marker substitutions
        if markers are present in the language_item. If no markers are available, the original
        language_item is returned as is.

        BUG FIX: the annotations are quoted (forward references) because
        ``word_loom`` is imported only inside execute(); the previous unquoted
        ``word_loom.language_item`` would raise NameError when the class body is
        executed (unless a ``from __future__ import annotations`` exists at the
        top of the module — not visible here).  The old return annotation
        ``Union[str, word_loom]`` also named the module rather than a type.

        :param loom_entry: A wordloom `language_item` object that contains potential markers to be
            substituted and formatted with values from the context.
        :param context: A dictionary-like object (`Mapping`) that holds marker-to-value
            mappings to be used for substitutions in the given loom_entry.
        :return: Returns a formatted string if markers are found and substitutions can be
            applied; otherwise, returns the unprocessed `language_item` as is.
        """
        if loom_entry.markers:
            marker_kwargs = {}
            for marker in loom_entry.markers:
                marker_kwargs[marker] = context[marker]
            return loom_entry.format(**marker_kwargs)
        else:
            return loom_entry

    @staticmethod
    def dispatch_check(item: Mapping, program: PDLObject):
        # Dispatch hook: claim any block carrying a "read_from_wordloom" key.
        if "read_from_wordloom" in item:
            return WorldLoomRead(item, program)

class PDLRepeat(PDLObject, PDLStructuredBlock):
def __init__(self, content: Mapping, program):
self.program = program
Expand Down Expand Up @@ -445,7 +398,7 @@ def dispatch_check(item: Mapping, program: PDLObject):
return PDFRead(item, program)

class ParseDispatcher:
DISPATCH_RESOLUTION_ORDER = [PDLRead, WorldLoomRead, PDLRepeat, PDLText, PDLModel]
DISPATCH_RESOLUTION_ORDER = [PDLRead, PDLRepeat, PDLText, PDLModel]

def handle(self, item: Mapping, program: PDLObject) -> PDLObject:
if isinstance(item, str):
Expand Down Expand Up @@ -501,9 +454,8 @@ class PDLProgram(PDLObject, PDLStructuredBlock):
>>> program.text[0].parameters
{'temperature': 0.6, 'min_p': 0.03, 'max_tokens': 600}

>>> program.execute(verbose=True)
Executing: program
.. model response ..
>>> program.execute()
'.. model response ..'

>>> program.evaluation_environment
{'_': [{'role': 'assistant', 'content': '.. model response ..'}]}
Expand Down
7 changes: 7 additions & 0 deletions src/generative_redfoot/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from pprint import pprint

def truncate_long_text(text, max_length=200):
    """Return *text* unchanged if it fits in *max_length* characters, else clip it and append '..'."""
    return (text[:max_length] + '..') if len(text) > max_length else text

def truncate_messages(messages):
    """Return a display-friendly copy of chat *messages* with long values truncated.

    Each message keeps its original shape (one dict per message); every value
    except the "role" field is shortened via truncate_long_text.

    BUG FIX: the previous comprehension built one single-key dict per
    (key, value) pair — ``[{k: ...} for i in messages for k, v in i.items()]``
    — shattering each {"role", "content"} message into two separate dicts.

    :param messages: Iterable of message dicts (e.g. {"role": ..., "content": ...}).
    :return: List of dicts, one per input message.
    """
    # NOTE(review): assumes non-"role" values are strings (truncate_long_text
    # calls len()/slicing on them) — confirm against callers.
    return [
        {k: (v if k == "role" else truncate_long_text(v)) for k, v in message.items()}
        for message in messages
    ]
Loading