Ollama images #59

Merged Jul 29, 2024 (7 commits)
Changes from 4 commits

1 change: 1 addition & 0 deletions .github/workflows/ci.yml
@@ -30,6 +30,7 @@ jobs:
- name: Pull models
run: |
ollama pull mistral
ollama pull bakllava
OLLAMA_HOST=127.0.0.1:11435 ollama pull qwen2:0.5b
- name: Set up MATLAB
uses: matlab-actions/setup-matlab@v2
36 changes: 35 additions & 1 deletion azureChat.m
@@ -191,7 +191,7 @@
if isstring(messages) && isscalar(messages)
messagesStruct = {struct("role", "user", "content", messages)};
else
messagesStruct = messages.Messages;
messagesStruct = this.encodeImages(messages.Messages);
end

if ~isempty(this.SystemPrompt)
@@ -251,6 +251,40 @@ function mustBeValidFunctionCall(this, functionCall)
end

end

        function messageStruct = encodeImages(~, messageStruct)
            for k=1:numel(messageStruct)
                if isfield(messageStruct{k},"images")
                    images = messageStruct{k}.images;
                    detail = messageStruct{k}.image_detail;
                    messageStruct{k} = rmfield(messageStruct{k},["images","image_detail"]);
                    messageStruct{k}.content = ...
                        {struct("type","text","text",messageStruct{k}.content)};
                    for img = images(:).'
                        if startsWith(img,("https://"|"http://"))
                            s = struct( ...
                                "type","image_url", ...
                                "image_url",struct("url",img));
                        else
                            [~,~,ext] = fileparts(img);
                            MIMEType = "data:image/" + erase(ext,".") + ";base64,";
                            % Base64 encode the image using the given MIME type
                            fid = fopen(img);
                            im = fread(fid,'*uint8');
                            fclose(fid);
                            b64 = matlab.net.base64encode(im);
                            s = struct( ...
                                "type","image_url", ...
                                "image_url",struct("url",MIMEType + b64));
                        end

                        s.image_url.detail = detail;

                        messageStruct{k}.content{end+1} = s;
                    end
                end
            end
        end
end
end
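
Reading aid for the `encodeImages` hunk above (not part of the PR): a minimal stand-alone sketch of the content entry that the local-file branch appears to build, assuming the `image_url` layout visible in the diff. The byte vector is only a stand-in (the 8-byte PNG signature) for what `fread` would return, and `demoBytes`, `dataUrl`, and `entry` are illustrative names that do not appear in the repository.

```matlab
% Illustrative only: mirrors the local-file branch of encodeImages.
demoBytes = uint8([137 80 78 71 13 10 26 10]);  % PNG signature, stand-in for fread output
ext = ".png";                                    % what fileparts returns for "peppers.png"

% Build the data URL the same way the diff does: MIME prefix + base64 payload.
dataUrl = "data:image/" + erase(ext,".") + ";base64," + ...
    matlab.net.base64encode(demoBytes);

% One content entry; "auto" is the default Detail value in messageHistory.
entry = struct("type","image_url", ...
    "image_url",struct("url",dataUrl,"detail","auto"));

disp(entry.image_url.url)
```

Because the extension is used verbatim, a `.jpg` file would produce the prefix `data:image/jpg` rather than the registered `image/jpeg` type.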

13 changes: 13 additions & 0 deletions doc/Azure.md
@@ -115,6 +115,19 @@ txt = generate(chat,"What is Model-Based Design and how is it related to Digital
% Should stream the response token by token
```

## Understanding the content of an image

You can use gpt-4o, gpt-4o-mini, or gpt-4-turbo to experiment with image understanding.
```matlab
chat = azureChat("You are an AI assistant.",Deployment="gpt-4o");
image_path = "peppers.png";
messages = messageHistory;
messages = addUserMessageWithImages(messages,"What is in the image?",image_path);
[txt,response] = generate(chat,messages,MaxNumTokens=4096);
txt
% outputs a description of the image
```

## Calling MATLAB functions with the API

Optionally, `Tools=functions` can be used to provide function specifications to the API. The purpose of this is to enable models to generate function arguments which adhere to the provided specifications.
13 changes: 13 additions & 0 deletions doc/Ollama.md
@@ -96,6 +96,19 @@ txt = generate(chat,"What is Model-Based Design and how is it related to Digital
% Should stream the response token by token
```

## Understanding the content of an image

You can use multimodal models like `llava` to experiment with image understanding.
```matlab
chat = ollamaChat("llava");
image_path = "peppers.png";
messages = messageHistory;
messages = addUserMessageWithImages(messages,"What is in the image?",image_path);
[txt,response] = generate(chat,messages,MaxNumTokens=4096);
txt
% outputs a description of the image
```

## Establishing a connection to remote LLMs using Ollama

To connect to a remote Ollama server, use the `Endpoint` name-value pair. Include the server name and port number. Ollama starts on 11434 by default.
7 changes: 4 additions & 3 deletions doc/OpenAI.md
@@ -250,14 +250,15 @@ You can extract the arguments and write the data to a table, for example.

## Understanding the content of an image

You can use gpt-4-turbo to experiment with image understanding.
You can use gpt-4o, gpt-4o-mini, or gpt-4-turbo to experiment with image understanding.
```matlab
chat = openAIChat("You are an AI assistant.", ModelName="gpt-4-turbo");
chat = openAIChat("You are an AI assistant.");
image_path = "peppers.png";
messages = messageHistory;
messages = addUserMessageWithImages(messages,"What is in the image?",image_path);
[txt,response] = generate(chat,messages,MaxNumTokens=4096);
% Should output the description of the image
txt
% outputs a description of the image
```

## Obtaining embeddings
29 changes: 3 additions & 26 deletions messageHistory.m
@@ -111,32 +111,9 @@
nvp.Detail string {mustBeMember(nvp.Detail,["low","high","auto"])} = "auto"
end

newMessage = struct("role", "user", "content", []);
newMessage.content = {struct("type","text","text",string(content))};
for img = images(:).'
if startsWith(img,("https://"|"http://"))
s = struct( ...
"type","image_url", ...
"image_url",struct("url",img));
else
[~,~,ext] = fileparts(img);
MIMEType = "data:image/" + erase(ext,".") + ";base64,";
% Base64 encode the image using the given MIME type
fid = fopen(img);
im = fread(fid,'*uint8');
fclose(fid);
b64 = matlab.net.base64encode(im);
s = struct( ...
"type","image_url", ...
"image_url",struct("url",MIMEType + b64));
end

s.image_url.detail = nvp.Detail;

newMessage.content{end+1} = s;
this.Messages{end+1} = newMessage;
end

newMessage = struct("role", "user", "content", string(content), ...
"images", images, "image_detail", nvp.Detail);
this.Messages{end+1} = newMessage;
end

function this = addToolMessage(this, id, name, content)
24 changes: 23 additions & 1 deletion ollamaChat.m
@@ -136,7 +136,7 @@
if isstring(messages) && isscalar(messages)
messagesStruct = {struct("role", "user", "content", messages)};
else
messagesStruct = messages.Messages;
messagesStruct = this.encodeImages(messages.Messages);
end

if ~isempty(this.SystemPrompt)
@@ -160,6 +160,28 @@
end
end

    methods (Access=private)
        function messageStruct = encodeImages(~, messageStruct)
            for k=1:numel(messageStruct)
                if isfield(messageStruct{k},"images")
                    images = messageStruct{k}.images;
                    % detail = messageStruct{k}.image_detail;
                    messageStruct{k} = rmfield(messageStruct{k},["images","image_detail"]);
                    imgs = cell(size(images));
                    for n = 1:numel(images)
                        img = images(n);
                        % Base64 encode the image
                        fid = fopen(img);
                        im = fread(fid,'*uint8');
                        fclose(fid);
                        imgs{n} = matlab.net.base64encode(im);
                    end
                    messageStruct{k}.images = imgs;
                end
            end
        end
    end

methods(Static)
function mdls = models
%ollamaChat.models - return models available on Ollama server
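
Reading aid for the `ollamaChat` hunk above (not part of the PR): a minimal sketch of the message shape its `encodeImages` appears to produce, assuming the behavior visible in the diff. The `content` field stays a plain string and `images` becomes a cell array with one base64 string per image. The bytes are a stand-in (the PNG signature) for real file contents, and `msg` and `demoBytes` are illustrative names.

```matlab
% Illustrative only: the message shape after encodeImages in ollamaChat.
msg = struct("role","user","content","What is in the image?");

% Stand-in for the bytes of one image file (a real call would use fread).
demoBytes = {uint8([137 80 78 71 13 10 26 10])};

imgs = cell(size(demoBytes));
for n = 1:numel(demoBytes)
    imgs{n} = matlab.net.base64encode(demoBytes{n});  % raw base64, no data-URL prefix
end
msg.images = imgs;

disp(msg.images{1})
```

Unlike the `azureChat` and `openAIChat` variants, this hunk has no `http://`/`https://` branch, so every entry in `images` is passed to `fopen` as a file path, and the `image_detail` value is dropped.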
36 changes: 35 additions & 1 deletion openAIChat.m
@@ -181,7 +181,7 @@
if isstring(messages) && isscalar(messages)
messagesStruct = {struct("role", "user", "content", messages)};
else
messagesStruct = messages.Messages;
messagesStruct = this.encodeImages(messages.Messages);
end

llms.openai.validateMessageSupported(messagesStruct{end}, this.ModelName);
@@ -230,6 +230,40 @@ function mustBeValidFunctionCall(this, functionCall)
end

end

        function messageStruct = encodeImages(~, messageStruct)
            for k=1:numel(messageStruct)
                if isfield(messageStruct{k},"images")
                    images = messageStruct{k}.images;
                    detail = messageStruct{k}.image_detail;
                    messageStruct{k} = rmfield(messageStruct{k},["images","image_detail"]);
                    messageStruct{k}.content = ...
                        {struct("type","text","text",messageStruct{k}.content)};
                    for img = images(:).'
                        if startsWith(img,("https://"|"http://"))
                            s = struct( ...
                                "type","image_url", ...
                                "image_url",struct("url",img));
                        else
                            [~,~,ext] = fileparts(img);
                            MIMEType = "data:image/" + erase(ext,".") + ";base64,";
                            % Base64 encode the image using the given MIME type
                            fid = fopen(img);
                            im = fread(fid,'*uint8');
                            fclose(fid);
                            b64 = matlab.net.base64encode(im);
                            s = struct( ...
                                "type","image_url", ...
                                "image_url",struct("url",MIMEType + b64));
                        end

                        s.image_url.detail = detail;

                        messageStruct{k}.content{end+1} = s;
                    end
                end
            end
        end
end
end

29 changes: 29 additions & 0 deletions tests/tazureChat.m
@@ -55,6 +55,26 @@ function generateMultipleResponses(testCase)
testCase.verifySize(response.Body.Data.choices,[3,1]);
end

function generateWithImage(testCase)
chat = azureChat(Deployment="gpt-4o");
image_path = "peppers.png";
emptyMessages = messageHistory;
messages = addUserMessageWithImages(emptyMessages,"What is in the image?",image_path);

text = generate(chat,messages);
testCase.verifyThat(text,matlab.unittest.constraints.ContainsSubstring("pepper"));
end

function generateWithMultipleImages(testCase)
import matlab.unittest.constraints.ContainsSubstring
chat = azureChat(Deployment="gpt-4o");
image_path = "peppers.png";
emptyMessages = messageHistory;
messages = addUserMessageWithImages(emptyMessages,"Compare these images.",[image_path,image_path]);

text = generate(chat,messages);
testCase.verifyThat(text,ContainsSubstring("same") | ContainsSubstring("identical"));
end

function doReturnErrors(testCase)
testCase.assumeTrue(isenv("AZURE_OPENAI_API_KEY"),"end-to-end test requires environment variables AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_DEPLOYMENT.");
@@ -65,6 +85,15 @@ function doReturnErrors(testCase)
testCase.verifyError(@() generate(chat,wayTooLong), "llms:apiReturnedError");
end

function generateWithImageErrorsForGpt35(testCase)
chat = azureChat;
image_path = "peppers.png";
emptyMessages = messageHistory;
messages = addUserMessageWithImages(emptyMessages,"What is in the image?",image_path);

testCase.verifyError(@() generate(chat,messages), "llms:apiReturnedError");
end

function seedFixesResult(testCase)
testCase.assumeTrue(isenv("AZURE_OPENAI_API_KEY"),"end-to-end test requires environment variables AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_DEPLOYMENT.");
chat = azureChat;
10 changes: 10 additions & 0 deletions tests/tollamaChat.m
@@ -98,6 +98,16 @@ function seedFixesResult(testCase)
testCase.verifyEqual(response1,response2);
end

function generateWithImages(testCase)
chat = ollamaChat("bakllava");
image_path = "peppers.png";
emptyMessages = messageHistory;
messages = addUserMessageWithImages(emptyMessages,"What is in the image?",image_path);

text = generate(chat,messages);
testCase.verifyThat(text,matlab.unittest.constraints.ContainsSubstring("pepper"));
end

function streamFunc(testCase)
function seen = sf(str)
persistent data;
13 changes: 12 additions & 1 deletion tests/topenAIChat.m
@@ -173,7 +173,7 @@ function generateWithToolsAndStreamFunc(testCase)
testCase.verifyThat(data,HasField("explanation"));
end

function generateWithImages(testCase)
function generateWithImage(testCase)
chat = openAIChat;
image_path = "peppers.png";
emptyMessages = messageHistory;
@@ -183,6 +183,17 @@ function generateWithImages(testCase)
testCase.verifyThat(text,matlab.unittest.constraints.ContainsSubstring("pepper"));
end

function generateWithMultipleImages(testCase)
import matlab.unittest.constraints.ContainsSubstring
chat = openAIChat;
image_path = "peppers.png";
emptyMessages = messageHistory;
messages = addUserMessageWithImages(emptyMessages,"Compare these images.",[image_path,image_path]);

text = generate(chat,messages);
testCase.verifyThat(text,ContainsSubstring("same") | ContainsSubstring("identical"));
end

function invalidInputsGenerate(testCase, InvalidGenerateInput)
f = openAIFunction("validfunction");
chat = openAIChat(Tools=f, APIKey="this-is-not-a-real-key");