Merge pull request #59 from matlab-deep-learning/ollama-images
Ollama images
ccreutzi authored Jul 29, 2024
2 parents c607360 + 7aaadb4 commit c3bec04
Showing 11 changed files with 183 additions and 33 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
@@ -30,6 +30,7 @@ jobs:
- name: Pull models
run: |
ollama pull mistral
ollama pull bakllava
OLLAMA_HOST=127.0.0.1:11435 ollama pull qwen2:0.5b
- name: Set up MATLAB
uses: matlab-actions/setup-matlab@v2
36 changes: 35 additions & 1 deletion azureChat.m
@@ -191,7 +191,7 @@
if isstring(messages) && isscalar(messages)
messagesStruct = {struct("role", "user", "content", messages)};
else
messagesStruct = messages.Messages;
messagesStruct = this.encodeImages(messages.Messages);
end

if ~isempty(this.SystemPrompt)
@@ -251,6 +251,40 @@ function mustBeValidFunctionCall(this, functionCall)
end

end

function messageStruct = encodeImages(~, messageStruct)
for k=1:numel(messageStruct)
if isfield(messageStruct{k},"images")
images = messageStruct{k}.images;
detail = messageStruct{k}.image_detail;
messageStruct{k} = rmfield(messageStruct{k},["images","image_detail"]);
messageStruct{k}.content = ...
{struct("type","text","text",messageStruct{k}.content)};
for img = images(:).'
if startsWith(img,("https://"|"http://"))
s = struct( ...
"type","image_url", ...
"image_url",struct("url",img));
else
[~,~,ext] = fileparts(img);
MIMEType = "data:image/" + erase(ext,".") + ";base64,";
% Base64 encode the image using the given MIME type
fid = fopen(img);
im = fread(fid,'*uint8');
fclose(fid);
b64 = matlab.net.base64encode(im);
s = struct( ...
"type","image_url", ...
"image_url",struct("url",MIMEType + b64));
end

s.image_url.detail = detail;

messageStruct{k}.content{end+1} = s;
end
end
end
end
end
end

13 changes: 13 additions & 0 deletions doc/Azure.md
@@ -115,6 +115,19 @@ txt = generate(chat,"What is Model-Based Design and how is it related to Digital
% Should stream the response token by token
```

## Understanding the content of an image

You can use gpt-4o, gpt-4o-mini, or gpt-4-turbo to experiment with image understanding.
```matlab
chat = azureChat("You are an AI assistant.",Deployment="gpt-4o");
image_path = "peppers.png";
messages = messageHistory;
messages = addUserMessageWithImages(messages,"What is in the image?",image_path);
[txt,response] = generate(chat,messages,MaxNumTokens=4096);
txt
% outputs a description of the image
```

## Calling MATLAB functions with the API

Optionally, `Tools=functions` can be used to provide function specifications to the API. This enables the model to generate function arguments that adhere to the provided specifications.
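For example, here is a minimal, hedged sketch; the `sayHello` function specification and the deployment name are placeholders, not part of this change.
```matlab
% Sketch: provide a function specification to the chat via Tools.
f = openAIFunction("sayHello");
chat = azureChat("You are an AI assistant.",Deployment="gpt-4o",Tools=f);
% The model may answer directly or request a tool call; inspect `response`
% for any generated function arguments.
[txt,response] = generate(chat,"Say hello to Alice.");
```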
18 changes: 18 additions & 0 deletions doc/Ollama.md
@@ -96,6 +96,24 @@ txt = generate(chat,"What is Model-Based Design and how is it related to Digital
% Should stream the response token by token
```

## Understanding the content of an image

You can use multimodal models like `llava` to experiment with image understanding.

> [!TIP]
> Many models available for Ollama accept prompts that include images even though the model does not support image inputs. In that case, the images are silently removed from the input, which can result in unexpected outputs.

```matlab
chat = ollamaChat("llava");
image_path = "peppers.png";
messages = messageHistory;
messages = addUserMessageWithImages(messages,"What is in the image?",image_path);
[txt,response] = generate(chat,messages,MaxNumTokens=4096);
txt
% outputs a description of the image
```

## Establishing a connection to remote LLMs using Ollama

To connect to a remote Ollama server, use the `Endpoint` name-value pair. Include the server name and port number. Ollama listens on port 11434 by default.
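For example, a hedged sketch (`myOllamaServer` is a placeholder host name):
```matlab
chat = ollamaChat("mistral",Endpoint="myOllamaServer:11434");
txt = generate(chat,"What is Model-Based Design?");
```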
7 changes: 4 additions & 3 deletions doc/OpenAI.md
@@ -250,14 +250,15 @@ You can extract the arguments and write the data to a table, for example.

## Understanding the content of an image

You can use gpt-4-turbo to experiment with image understanding.
You can use gpt-4o, gpt-4o-mini, or gpt-4-turbo to experiment with image understanding.
```matlab
chat = openAIChat("You are an AI assistant.", ModelName="gpt-4-turbo");
chat = openAIChat("You are an AI assistant.");
image_path = "peppers.png";
messages = messageHistory;
messages = addUserMessageWithImages(messages,"What is in the image?",image_path);
[txt,response] = generate(chat,messages,MaxNumTokens=4096);
% Should output the description of the image
txt
% outputs a description of the image
```

## Obtaining embeddings
29 changes: 3 additions & 26 deletions messageHistory.m
@@ -111,32 +111,9 @@
nvp.Detail string {mustBeMember(nvp.Detail,["low","high","auto"])} = "auto"
end

newMessage = struct("role", "user", "content", []);
newMessage.content = {struct("type","text","text",string(content))};
for img = images(:).'
if startsWith(img,("https://"|"http://"))
s = struct( ...
"type","image_url", ...
"image_url",struct("url",img));
else
[~,~,ext] = fileparts(img);
MIMEType = "data:image/" + erase(ext,".") + ";base64,";
% Base64 encode the image using the given MIME type
fid = fopen(img);
im = fread(fid,'*uint8');
fclose(fid);
b64 = matlab.net.base64encode(im);
s = struct( ...
"type","image_url", ...
"image_url",struct("url",MIMEType + b64));
end

s.image_url.detail = nvp.Detail;

newMessage.content{end+1} = s;
this.Messages{end+1} = newMessage;
end

newMessage = struct("role", "user", "content", string(content), ...
"images", images, "image_detail", nvp.Detail);
this.Messages{end+1} = newMessage;
end

function this = addToolMessage(this, id, name, content)
24 changes: 23 additions & 1 deletion ollamaChat.m
@@ -136,7 +136,7 @@
if isstring(messages) && isscalar(messages)
messagesStruct = {struct("role", "user", "content", messages)};
else
messagesStruct = messages.Messages;
messagesStruct = this.encodeImages(messages.Messages);
end

if ~isempty(this.SystemPrompt)
@@ -160,6 +160,28 @@
end
end

methods (Access=private)
function messageStruct = encodeImages(~, messageStruct)
for k=1:numel(messageStruct)
if isfield(messageStruct{k},"images")
images = messageStruct{k}.images;
% detail = messageStruct{k}.image_detail;
messageStruct{k} = rmfield(messageStruct{k},["images","image_detail"]);
imgs = cell(size(images));
for n = 1:numel(images)
img = images(n);
% Base64 encode the image
fid = fopen(img);
im = fread(fid,'*uint8');
fclose(fid);
imgs{n} = matlab.net.base64encode(im);
end
messageStruct{k}.images = imgs;
end
end
end
end

methods(Static)
function mdls = models
%ollamaChat.models - return models available on Ollama server
36 changes: 35 additions & 1 deletion openAIChat.m
@@ -181,7 +181,7 @@
if isstring(messages) && isscalar(messages)
messagesStruct = {struct("role", "user", "content", messages)};
else
messagesStruct = messages.Messages;
messagesStruct = this.encodeImages(messages.Messages);
end

llms.openai.validateMessageSupported(messagesStruct{end}, this.ModelName);
@@ -230,6 +230,40 @@ function mustBeValidFunctionCall(this, functionCall)
end

end

function messageStruct = encodeImages(~, messageStruct)
for k=1:numel(messageStruct)
if isfield(messageStruct{k},"images")
images = messageStruct{k}.images;
detail = messageStruct{k}.image_detail;
messageStruct{k} = rmfield(messageStruct{k},["images","image_detail"]);
messageStruct{k}.content = ...
{struct("type","text","text",messageStruct{k}.content)};
for img = images(:).'
if startsWith(img,("https://"|"http://"))
s = struct( ...
"type","image_url", ...
"image_url",struct("url",img));
else
[~,~,ext] = fileparts(img);
MIMEType = "data:image/" + erase(ext,".") + ";base64,";
% Base64 encode the image using the given MIME type
fid = fopen(img);
im = fread(fid,'*uint8');
fclose(fid);
b64 = matlab.net.base64encode(im);
s = struct( ...
"type","image_url", ...
"image_url",struct("url",MIMEType + b64));
end

s.image_url.detail = detail;

messageStruct{k}.content{end+1} = s;
end
end
end
end
end
end

29 changes: 29 additions & 0 deletions tests/tazureChat.m
@@ -55,6 +55,26 @@ function generateMultipleResponses(testCase)
testCase.verifySize(response.Body.Data.choices,[3,1]);
end

function generateWithImage(testCase)
chat = azureChat(Deployment="gpt-4o");
image_path = "peppers.png";
emptyMessages = messageHistory;
messages = addUserMessageWithImages(emptyMessages,"What is in the image?",image_path);

text = generate(chat,messages);
testCase.verifyThat(text,matlab.unittest.constraints.ContainsSubstring("pepper"));
end

function generateWithMultipleImages(testCase)
import matlab.unittest.constraints.ContainsSubstring
chat = azureChat(Deployment="gpt-4o");
image_path = "peppers.png";
emptyMessages = messageHistory;
messages = addUserMessageWithImages(emptyMessages,"Compare these images.",[image_path,image_path]);

text = generate(chat,messages);
testCase.verifyThat(text,ContainsSubstring("same") | ContainsSubstring("identical"));
end

function doReturnErrors(testCase)
testCase.assumeTrue(isenv("AZURE_OPENAI_API_KEY"),"end-to-end test requires environment variables AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_DEPLOYMENT.");
@@ -65,6 +85,15 @@ function doReturnErrors(testCase)
testCase.verifyError(@() generate(chat,wayTooLong), "llms:apiReturnedError");
end

function generateWithImageErrorsForGpt35(testCase)
chat = azureChat;
image_path = "peppers.png";
emptyMessages = messageHistory;
messages = addUserMessageWithImages(emptyMessages,"What is in the image?",image_path);

testCase.verifyError(@() generate(chat,messages), "llms:apiReturnedError");
end

function seedFixesResult(testCase)
testCase.assumeTrue(isenv("AZURE_OPENAI_API_KEY"),"end-to-end test requires environment variables AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_DEPLOYMENT.");
chat = azureChat;
10 changes: 10 additions & 0 deletions tests/tollamaChat.m
@@ -98,6 +98,16 @@ function seedFixesResult(testCase)
testCase.verifyEqual(response1,response2);
end

function generateWithImages(testCase)
chat = ollamaChat("bakllava");
image_path = "peppers.png";
emptyMessages = messageHistory;
messages = addUserMessageWithImages(emptyMessages,"What is in the image?",image_path);

text = generate(chat,messages);
testCase.verifyThat(text,matlab.unittest.constraints.ContainsSubstring("pepper"));
end

function streamFunc(testCase)
function seen = sf(str)
persistent data;
13 changes: 12 additions & 1 deletion tests/topenAIChat.m
@@ -173,7 +173,7 @@ function generateWithToolsAndStreamFunc(testCase)
testCase.verifyThat(data,HasField("explanation"));
end

function generateWithImages(testCase)
function generateWithImage(testCase)
chat = openAIChat;
image_path = "peppers.png";
emptyMessages = messageHistory;
@@ -183,6 +183,17 @@ function generateWithImages(testCase)
testCase.verifyThat(text,matlab.unittest.constraints.ContainsSubstring("pepper"));
end

function generateWithMultipleImages(testCase)
import matlab.unittest.constraints.ContainsSubstring
chat = openAIChat;
image_path = "peppers.png";
emptyMessages = messageHistory;
messages = addUserMessageWithImages(emptyMessages,"Compare these images.",[image_path,image_path]);

text = generate(chat,messages);
testCase.verifyThat(text,ContainsSubstring("same") | ContainsSubstring("identical"));
end

function invalidInputsGenerate(testCase, InvalidGenerateInput)
f = openAIFunction("validfunction");
chat = openAIChat(Tools=f, APIKey="this-is-not-a-real-key");
