Skip to content

Commit 8cb052d

Browse files
committed
Add Outlines
1 parent 3cc399d commit 8cb052d

File tree

3 files changed

+111
-0
lines changed

3 files changed

+111
-0
lines changed

outlines/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Outlines is a Generative Model Programming Framework."""
2+
23
import outlines.generate
34
import outlines.grammars
45
import outlines.models
@@ -7,6 +8,7 @@
78
from outlines.base import vectorize
89
from outlines.caching import clear_cache, disable_cache, get_cache
910
from outlines.function import Function
11+
from outlines.outline import Outline
1012
from outlines.prompts import prompt
1113

1214
__all__ = [
@@ -17,4 +19,5 @@
1719
"prompt",
1820
"vectorize",
1921
"grammars",
22+
"Outline",
2023
]

outlines/outline.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import ast
2+
from dataclasses import dataclass
3+
4+
5+
@dataclass
6+
class Outline:
7+
"""
8+
Outline is a class that creates a callable object to generate responses
9+
based on a given model and a prompt template.
10+
11+
Args:
12+
model: The model to be used for generating responses.
13+
template (function): A function that takes arguments and returns a prompt string.
14+
output_type: The expected output type of the generated response.
15+
16+
Example:
17+
from outlines import models
18+
19+
model = models.transformers("gpt2")
20+
21+
def template(a: int) -> str:
22+
return f"What is 2 times {a}?"
23+
24+
fn = Outline(model, template, int)
25+
26+
result = fn(3)
27+
print(result) # Expected output: 6
28+
"""
29+
30+
def __init__(self, model, template, output_type):
31+
self.model = model
32+
self.template = template
33+
self.output_type = output_type
34+
35+
def __call__(self, *args):
36+
# Generate the prompt using the template function
37+
prompt = self.template(*args)
38+
39+
# Generate the response using the model
40+
response = self.model.generate(prompt)
41+
42+
# Process the response to match the expected output type
43+
try:
44+
parsed_response = ast.literal_eval(response.strip())
45+
if isinstance(parsed_response, self.output_type):
46+
return parsed_response
47+
else:
48+
raise ValueError(
49+
f"Response type {type(parsed_response)} does not match expected type {self.output_type}"
50+
)
51+
except (ValueError, SyntaxError):
52+
raise ValueError(f"Unable to parse response: {response.strip()}")

tests/test_outline.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from unittest.mock import MagicMock
2+
3+
import pytest
4+
5+
from outlines.outline import Outline
6+
7+
8+
def test_outline_int_output():
9+
# Mock the model
10+
model = MagicMock()
11+
model.generate.return_value = "6"
12+
13+
# Define the template function
14+
def template(a: int) -> str:
15+
return f"What is 2 times {a}?"
16+
17+
# Create an instance of Outline
18+
fn = Outline(model, template, int)
19+
20+
# Test the callable object
21+
result = fn(3)
22+
assert result == 6
23+
24+
25+
def test_outline_str_output():
26+
# Mock the model
27+
model = MagicMock()
28+
model.generate.return_value = "'Hello, world!'"
29+
30+
# Define the template function
31+
def template(a: int) -> str:
32+
return f"Say hello {a} times"
33+
34+
# Create an instance of Outline
35+
fn = Outline(model, template, str)
36+
37+
# Test the callable object
38+
result = fn(1)
39+
assert result == "Hello, world!"
40+
41+
42+
def test_outline_invalid_output():
43+
# Mock the model
44+
model = MagicMock()
45+
model.generate.return_value = "not a number"
46+
47+
# Define the template function
48+
def template(a: int) -> str:
49+
return f"What is 2 times {a}?"
50+
51+
# Create an instance of Outline
52+
fn = Outline(model, template, int)
53+
54+
# Test the callable object with invalid output
55+
with pytest.raises(ValueError):
56+
fn(3)

0 commit comments

Comments
 (0)