Structured ollama
* Activate structured output for ollamaChat
* Update documentation
ccreutzi committed Dec 13, 2024
1 parent 3b1db9d commit 44a3679
Showing 10 changed files with 225 additions and 160 deletions.
9 changes: 9 additions & 0 deletions +llms/+internal/callOllamaChatAPI.m
@@ -84,6 +84,15 @@
parameters.stream = ~isempty(nvp.StreamFun);

options = struct;

if strcmp(nvp.ResponseFormat,"json")
parameters.format = struct('type','json_object');
elseif isstruct(nvp.ResponseFormat)
parameters.format = llms.internal.jsonSchemaFromPrototype(nvp.ResponseFormat);
elseif startsWith(string(nvp.ResponseFormat), asManyOfPattern(whitespacePattern)+"{")
parameters.format = llms.internal.verbatimJSON(nvp.ResponseFormat);
end

if ~isempty(nvp.Seed)
options.seed = nvp.Seed;
end
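
For illustration, here is a hedged sketch of which branch of the new `ResponseFormat` handling above each kind of value takes; the values below are examples, not part of the change.

```matlab
% Illustrative ResponseFormat values and the branch each one takes (sketch only):
nvp.ResponseFormat = "json";                                 % JSON mode: format = struct('type','json_object')
nvp.ResponseFormat = struct("name","Rudolph","age",5);       % struct prototype: schema built by llms.internal.jsonSchemaFromPrototype
nvp.ResponseFormat = '  {"type":"object","properties":{}}';  % JSON Schema string: passed through by llms.internal.verbatimJSON
```
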
2 changes: 1 addition & 1 deletion +llms/+internal/useSameFieldTypes.m
@@ -21,7 +21,7 @@
case "struct"
prototype = prototype(1);
if isscalar(data)
if isequal(fieldnames(data),fieldnames(prototype))
if isequal(sort(fieldnames(data)),sort(fieldnames(prototype)))
for field_c = fieldnames(data).'
field = field_c{1};
data.(field) = alignTypes(data.(field),prototype.(field));
1 change: 1 addition & 0 deletions +llms/+utils/errorMessageCatalog.m
@@ -66,4 +66,5 @@
catalog("llms:stream:responseStreamer:InvalidInput") = "Input does not have the expected json format, got ""{1}"".";
catalog("llms:unsupportedDatatypeInPrototype") = "Invalid data type ''{1}'' in prototype. Prototype must be a struct, composed of numerical, string, logical, categorical, or struct.";
catalog("llms:incorrectResponseFormat") = "Invalid response format. Response format must be ""text"", ""json"", a struct, or a string with a JSON Schema definition.";
catalog("llms:OllamaStructuredOutputNeeds05") = "Structured output is not supported for Ollama version {1}. Use version 0.5.0 or newer.";
end
9 changes: 9 additions & 0 deletions +llms/+utils/requestsStructuredOutput.m
@@ -0,0 +1,9 @@
function tf = requestsStructuredOutput(format)
% This function is undocumented and will change in a future release

% Simple function to check if requested format triggers structured output

% Copyright 2024 The MathWorks, Inc.
tf = isstruct(format) || startsWith(format,asManyOfPattern(whitespacePattern)+"{");
end
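
A brief usage sketch of this helper; the inputs are illustrative, not part of the change:

```matlab
% Illustrative calls to llms.utils.requestsStructuredOutput (sketch only):
llms.utils.requestsStructuredOutput(struct("name","Rudolph"))   % true  (struct prototype)
llms.utils.requestsStructuredOutput('  {"type":"object"}')      % true  (string starting with a JSON object)
llms.utils.requestsStructuredOutput("json")                     % false (JSON mode, not structured output)
```
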

25 changes: 21 additions & 4 deletions doc/functions/ollamaChat.md
@@ -139,23 +139,40 @@ If the server does not respond within the timeout, then the function throws an error.

### `ResponseFormat` — Response format

`"text"` (default) | `"json"`
`"text"` (default) | `"json"` | string scalar | structure array


After construction, this property is read\-only.


Format of generated output.
Format of the `generatedOutput` output argument of the `generate` function. You can request unformatted output, JSON mode, or structured output.


If you set the response format to `"text"`, then the generated output is a string.
#### Unformatted Output


If you set the response format to `"json"`, then the generated output is a string containing JSON encoded data.
If you set the response format to `"text"`, then the generated output is an unformatted string.


#### JSON Mode


If you set the response format to `"json"`, then the generated output is a formatted string containing JSON encoded data.


To configure the format of the generated JSON output, describe the required format in natural language and provide that description to the model, either in the system prompt or as a user message. The prompt or message describing the format must contain the word `"json"` or `"JSON"`.
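
For example, a minimal sketch of JSON mode (the model name `"mistral"` is an assumption; use any model available on your local Ollama server):

```matlab
% Sketch: JSON mode with a natural-language description of the desired fields.
% "mistral" is an illustrative model name.
chat = ollamaChat("mistral", ResponseFormat="json");
txt = generate(chat, ...
    "Return a JSON object with fields ""city"" and ""country"" describing Paris.");
data = jsondecode(txt);   % txt is a string containing JSON-encoded data
```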

#### Structured Output


This option is only supported for Ollama version 0.5.0 and later.


To ensure that the model follows the required format, use structured output. To do this, set `ResponseFormat` to one of the following:

- A string scalar containing a valid JSON Schema.
- A structure array containing an example that adheres to the required format, for example: `ResponseFormat=struct("Name","Rudolph","NoseColor",[255 0 0])`
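
For example, a minimal sketch of structured output from a struct prototype (the model name is an assumption; requires Ollama 0.5.0 or newer):

```matlab
% Sketch: structured output from a struct prototype (requires Ollama >= 0.5.0).
% "mistral" is an illustrative model name.
chat = ollamaChat("mistral", ...
    ResponseFormat=struct("commonName","dog","scientificName","Canis familiaris"));
animal = generate(chat, "Which animal produces honey?");
% animal is returned as a struct with the same fields as the prototype:
%   animal.commonName, animal.scientificName
```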

# Other Properties
### `SystemPrompt` — System prompt

33 changes: 28 additions & 5 deletions ollamaChat.m
@@ -173,10 +173,9 @@
% value is CHAT.StopSequences.
% Example: ["The end.", "And that's all she wrote."]
%
%
% ResponseFormat - The format of response the model returns.
% The default value is CHAT.ResponseFormat.
% "text" (default) | "json"
% ResponseFormat - The format of the response the call returns.
% Default value is CHAT.ResponseFormat.
% "text" | "json" | struct | string with JSON Schema
%
% StreamFun - Function to callback when streaming the
% result. The default value is CHAT.StreamFun.
@@ -193,7 +192,7 @@
nvp.MinP {llms.utils.mustBeValidProbability} = this.MinP
nvp.TopK (1,1) {mustBeReal,mustBePositive} = this.TopK
nvp.StopSequences {llms.utils.mustBeValidStop} = this.StopSequences
nvp.ResponseFormat (1,1) string {mustBeMember(nvp.ResponseFormat,["text","json"])} = this.ResponseFormat
nvp.ResponseFormat {llms.utils.mustBeResponseFormat} = this.ResponseFormat
nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = this.TimeOut
nvp.TailFreeSamplingZ (1,1) {mustBeReal} = this.TailFreeSamplingZ
nvp.StreamFun (1,1) {mustBeA(nvp.StreamFun,'function_handle')}
@@ -234,9 +233,16 @@
end

if isfield(response.Body.Data,"error")
[versionStr, versionList] = serverVersion(nvp.Endpoint);
if llms.utils.requestsStructuredOutput(nvp.ResponseFormat) && ...
~versionIsAtLeast(versionList, [0,5,0])
error("llms:OllamaStructuredOutputNeeds05",llms.utils.errorMessageCatalog.getMessage("llms:OllamaStructuredOutputNeeds05", versionStr));
end
err = response.Body.Data.error;
error("llms:apiReturnedError",llms.utils.errorMessageCatalog.getMessage("llms:apiReturnedError",err));
end

text = llms.internal.reformatOutput(text,nvp.ResponseFormat);
end
end

@@ -310,3 +316,20 @@ function mustBeIntegerOrEmpty(value)
mustBeInteger(value)
end
end

function [versionStr, versionList] = serverVersion(endpoint)
URL = endpoint + "/api/version";
if ~startsWith(URL,"http")
URL = "http://" + URL;
end
versionStr = webread(URL).version;
versionList = split(versionStr,'.');
versionList = str2double(versionList);
end

function tf = versionIsAtLeast(version,minVersion)
tf = version(1) > minVersion(1) || ...
(version(1) == minVersion(1) && (...
version(2) > minVersion(2) || ...
(version(2) == minVersion(2) && version(3) >= minVersion(3))));
end
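
As a sketch of the three-way comparison in `versionIsAtLeast` above (a local function, shown here with illustrative inputs):

```matlab
% Illustrative behavior of the local versionIsAtLeast helper (sketch only):
versionIsAtLeast([0,5,1], [0,5,0])   % true  (newer patch)
versionIsAtLeast([0,4,9], [0,5,0])   % false (older minor)
versionIsAtLeast([1,0,0], [0,5,0])   % true  (newer major)
```
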
149 changes: 1 addition & 148 deletions tests/hopenAIChat.m
@@ -1,4 +1,4 @@
classdef (Abstract) hopenAIChat < matlab.unittest.TestCase
classdef (Abstract) hopenAIChat < hstructuredOutput
% Tests for OpenAI-based chats (openAIChat, azureChat)

% Copyright 2023-2024 The MathWorks, Inc.
@@ -17,8 +17,6 @@
constructor
defaultModel
visionModel
structuredModel
noStructuredOutputModel
end

methods(Test)
@@ -195,66 +193,6 @@ function generateOverridesProperties(testCase)
testCase.verifyThat(text, EndsWithSubstring("3, "));
end

function generateWithStructuredOutput(testCase)
import matlab.unittest.constraints.IsEqualTo
import matlab.unittest.constraints.StartsWithSubstring
res = generate(testCase.structuredModel,"Which animal produces honey?",...
ResponseFormat = struct(commonName = "dog", scientificName = "Canis familiaris"));
testCase.assertClass(res,"struct");
testCase.verifySize(fieldnames(res),[2,1]);
testCase.verifyThat(res.commonName, IsEqualTo("Honeybee") | IsEqualTo("Honey bee") | IsEqualTo("Honey Bee"));
testCase.verifyThat(res.scientificName, StartsWithSubstring("Apis"));
end

function generateListWithStructuredOutput(testCase)
prototype = struct("plantName",{"appletree","pear"}, ...
"fruit",{"apple","pear"}, ...
"edible",[true,true], ...
"ignore", missing);
res = generate(testCase.structuredModel,"What is harvested in August?", ResponseFormat = prototype);
testCase.verifyCompatibleStructs(res, prototype);
end

function generateWithNestedStructs(testCase)
stepsPrototype = struct("explanation",{"a","b"},"assumptions",{"a","b"});
prototype = struct("steps",stepsPrototype,"final_answer","a");
res = generate(testCase.structuredModel,"What is the positive root of x^2-2*x+1?", ...
ResponseFormat=prototype);
testCase.verifyCompatibleStructs(res,prototype);
end

function incompleteJSONResponse(testCase)
country = ["USA";"UK"];
capital = ["Washington, D.C.";"London"];
population = [345716792;69203012];
prototype = struct("country",country,"capital",capital,"population",population);

testCase.verifyError(@() generate(testCase.structuredModel, ...
"What are the five largest countries whose English names" + ...
" start with the letter A?", ...
ResponseFormat = prototype, MaxNumTokens=3), "llms:apiReturnedIncompleteJSON");
end

function generateWithExplicitSchema(testCase)
import matlab.unittest.constraints.IsSameSetAs
schema = iGetSchema();

genUser = generate(testCase.structuredModel,"Create a sample user",ResponseFormat=schema);
genAddress = generate(testCase.structuredModel,"Create a sample address",ResponseFormat=schema);

testCase.verifyClass(genUser,"string");
genUserDecoded = jsondecode(genUser);
testCase.verifyClass(genUserDecoded.item,"struct");
testCase.verifyThat(fieldnames(genUserDecoded.item),...
IsSameSetAs({'name','age'}));

testCase.verifyClass(genAddress,"string");
genAddressDecoded = jsondecode(genAddress);
testCase.verifyClass(genAddressDecoded.item,"struct");
testCase.verifyThat(fieldnames(genAddressDecoded.item),...
IsSameSetAs({'number','street','city'}));
end

function invalidInputsGenerate(testCase, InvalidGenerateInput)
f = openAIFunction("validfunction");
chat = testCase.constructor(Tools=f, APIKey="this-is-not-a-real-key");
@@ -321,89 +259,4 @@ function keyNotFound(testCase)
testCase.verifyError(testCase.constructor, "llms:keyMustBeSpecified");
end
end

methods
function verifyCompatibleStructs(testCase,data,prototype)
import matlab.unittest.constraints.IsSameSetAs
testCase.assertClass(data,"struct");
if ~isscalar(data)
arrayfun(@(d) testCase.verifyCompatibleStructs(d,prototype), data);
return
end
testCase.assertClass(prototype,"struct");
if ~isscalar(prototype)
prototype = prototype(1);
end
testCase.assertThat(fieldnames(data),IsSameSetAs(fieldnames(prototype)));
for name = fieldnames(data).'
field = name{1};
testCase.verifyClass(data.(field),class(prototype.(field)));
if isstruct(data.(field))
testCase.verifyCompatibleStructs(data.(field),prototype.(field));
end
end
end
end
end

function str = iGetSchema()
% an example from https://platform.openai.com/docs/guides/structured-outputs/supported-schemas
str = string(join({
'{'
' "type": "object",'
' "properties": {'
' "item": {'
' "anyOf": ['
' {'
' "type": "object",'
' "description": "The user object to insert into the database",'
' "properties": {'
' "name": {'
' "type": "string",'
' "description": "The name of the user"'
' },'
' "age": {'
' "type": "number",'
' "description": "The age of the user"'
' }'
' },'
' "additionalProperties": false,'
' "required": ['
' "name",'
' "age"'
' ]'
' },'
' {'
' "type": "object",'
' "description": "The address object to insert into the database",'
' "properties": {'
' "number": {'
' "type": "string",'
' "description": "The number of the address. Eg. for 123 main st, this would be 123"'
' },'
' "street": {'
' "type": "string",'
' "description": "The street name. Eg. for 123 main st, this would be main st"'
' },'
' "city": {'
' "type": "string",'
' "description": "The city of the address"'
' }'
' },'
' "additionalProperties": false,'
' "required": ['
' "number",'
' "street",'
' "city"'
' ]'
' }'
' ]'
' }'
' },'
' "additionalProperties": false,'
' "required": ['
' "item"'
' ]'
'}'
}, newline));
end