diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 70a7d6c..d7babcc 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -30,6 +30,7 @@ jobs:
       - name: Pull models
         run: |
           ollama pull mistral
+          ollama pull bakllava
           OLLAMA_HOST=127.0.0.1:11435 ollama pull qwen2:0.5b
       - name: Set up MATLAB
         uses: matlab-actions/setup-matlab@v2
diff --git a/azureChat.m b/azureChat.m
index 27bdded..b4b8df5 100644
--- a/azureChat.m
+++ b/azureChat.m
@@ -191,7 +191,7 @@
             if isstring(messages) && isscalar(messages)
                 messagesStruct = {struct("role", "user", "content", messages)};
             else
-                messagesStruct = messages.Messages;
+                messagesStruct = this.encodeImages(messages.Messages);
             end
 
             if ~isempty(this.SystemPrompt)
@@ -251,6 +251,40 @@ function mustBeValidFunctionCall(this, functionCall)
             end
         end
+
+        function messageStruct = encodeImages(~, messageStruct)
+            for k=1:numel(messageStruct)
+                if isfield(messageStruct{k},"images")
+                    images = messageStruct{k}.images;
+                    detail = messageStruct{k}.image_detail;
+                    messageStruct{k} = rmfield(messageStruct{k},["images","image_detail"]);
+                    messageStruct{k}.content = ...
+                        {struct("type","text","text",messageStruct{k}.content)};
+                    for img = images(:).'
+                        if startsWith(img,("https://"|"http://"))
+                            s = struct( ...
+                                "type","image_url", ...
+                                "image_url",struct("url",img));
+                        else
+                            [~,~,ext] = fileparts(img);
+                            MIMEType = "data:image/" + erase(ext,".") + ";base64,";
+                            % Base64 encode the image and prepend the MIME type to form a data URL
+                            fid = fopen(img);
+                            im = fread(fid,'*uint8');
+                            fclose(fid);
+                            b64 = matlab.net.base64encode(im);
+                            s = struct( ...
+                                "type","image_url", ...
+                                "image_url",struct("url",MIMEType + b64));
+                        end
+
+                        s.image_url.detail = detail;
+
+                        messageStruct{k}.content{end+1} = s;
+                    end
+                end
+            end
+        end
     end
 end
diff --git a/doc/Azure.md b/doc/Azure.md
index d5af221..382c3bf 100644
--- a/doc/Azure.md
+++ b/doc/Azure.md
@@ -115,6 +115,19 @@ txt = generate(chat,"What is Model-Based Design and how is it related to Digital
 % Should stream the response token by token
 ```
 
+## Understanding the content of an image
+
+You can use gpt-4o, gpt-4o-mini, or gpt-4-turbo to experiment with image understanding.
+```matlab
+chat = azureChat("You are an AI assistant.",Deployment="gpt-4o");
+image_path = "peppers.png";
+messages = messageHistory;
+messages = addUserMessageWithImages(messages,"What is in the image?",image_path);
+[txt,response] = generate(chat,messages,MaxNumTokens=4096);
+txt
+% Should output a description of the image
+```
+
 ## Calling MATLAB functions with the API
 
 Optionally, `Tools=functions` can be used to provide function specifications to the API. The purpose of this is to enable models to generate function arguments which adhere to the provided specifications.
diff --git a/doc/Ollama.md b/doc/Ollama.md
index 110a869..ea44d84 100644
--- a/doc/Ollama.md
+++ b/doc/Ollama.md
@@ -96,6 +96,24 @@ txt = generate(chat,"What is Model-Based Design and how is it related to Digital
 % Should stream the response token by token
 ```
 
+## Understanding the content of an image
+
+You can use multimodal models like `llava` to experiment with image understanding.
+
+> [!TIP]
+> For many models available in Ollama, you can include images in the prompt even if the model does not support image inputs. In that case, the images are silently removed from the input, which can result in unexpected outputs.
+
+```matlab
+chat = ollamaChat("llava");
+image_path = "peppers.png";
+messages = messageHistory;
+messages = addUserMessageWithImages(messages,"What is in the image?",image_path);
+[txt,response] = generate(chat,messages,MaxNumTokens=4096);
+txt
+% Should output a description of the image
+```
+
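+If you do pass images to a model that cannot use them, the request still succeeds; here is a minimal sketch of that pitfall with the text-only `mistral` model (exact output will vary):
+
+```matlab
+chat = ollamaChat("mistral");
+messages = addUserMessageWithImages(messageHistory,"What is in the image?","peppers.png");
+txt = generate(chat,messages)
+% The image is silently dropped, so the reply may deny seeing an image or describe an imagined one.
+```
+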
 ## Establishing a connection to remote LLMs using Ollama
 
 To connect to a remote Ollama server, use the `Endpoint` name-value pair. Include the server name and port number. Ollama starts on 11434 by default.
diff --git a/doc/OpenAI.md b/doc/OpenAI.md
index 51783ae..61ac83e 100644
--- a/doc/OpenAI.md
+++ b/doc/OpenAI.md
@@ -250,14 +250,15 @@ You can extract the arguments and write the data to a table, for example.
 
 ## Understanding the content of an image
 
-You can use gpt-4-turbo to experiment with image understanding.
+You can use gpt-4o, gpt-4o-mini, or gpt-4-turbo to experiment with image understanding.
 ```matlab
-chat = openAIChat("You are an AI assistant.", ModelName="gpt-4-turbo");
+chat = openAIChat("You are an AI assistant.");
 image_path = "peppers.png";
 messages = messageHistory;
 messages = addUserMessageWithImages(messages,"What is in the image?",image_path);
 [txt,response] = generate(chat,messages,MaxNumTokens=4096);
-% Should output the description of the image
+txt
+% Should output a description of the image
 ```
 
 ## Obtaining embeddings
diff --git a/messageHistory.m b/messageHistory.m
index 74c5e63..e2e0e60 100644
--- a/messageHistory.m
+++ b/messageHistory.m
@@ -111,32 +111,9 @@
                 nvp.Detail string {mustBeMember(nvp.Detail,["low","high","auto"])} = "auto"
             end
 
-            newMessage = struct("role", "user", "content", []);
-            newMessage.content = {struct("type","text","text",string(content))};
-            for img = images(:).'
-                if startsWith(img,("https://"|"http://"))
-                    s = struct( ...
-                        "type","image_url", ...
-                        "image_url",struct("url",img));
-                else
-                    [~,~,ext] = fileparts(img);
-                    MIMEType = "data:image/" + erase(ext,".") + ";base64,";
-                    % Base64 encode the image using the given MIME type
-                    fid = fopen(img);
-                    im = fread(fid,'*uint8');
-                    fclose(fid);
-                    b64 = matlab.net.base64encode(im);
-                    s = struct( ...
-                        "type","image_url", ...
-                        "image_url",struct("url",MIMEType + b64));
-                end
-
-                s.image_url.detail = nvp.Detail;
-
-                newMessage.content{end+1} = s;
-                this.Messages{end+1} = newMessage;
-            end
-
+            newMessage = struct("role", "user", "content", string(content), ...
+                "images", images, "image_detail", nvp.Detail);
+            this.Messages{end+1} = newMessage;
         end
 
         function this = addToolMessage(this, id, name, content)
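With this change, `addUserMessageWithImages` no longer builds the OpenAI-specific content array itself; the history stores the raw file paths and the detail level, and each provider class encodes them into its own wire format at request time. A minimal sketch of the stored shape, assuming `peppers.png` is on the MATLAB path:

```matlab
messages = addUserMessageWithImages(messageHistory,"What is in the image?","peppers.png");
messages.Messages{end}
% ans = struct with fields:
%             role: "user"
%          content: "What is in the image?"
%           images: "peppers.png"
%     image_detail: "auto"
```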
+ "images", images, "image_detail", nvp.Detail); + this.Messages{end+1} = newMessage; end function this = addToolMessage(this, id, name, content) diff --git a/ollamaChat.m b/ollamaChat.m index cdb2e35..6d9e5a0 100644 --- a/ollamaChat.m +++ b/ollamaChat.m @@ -136,7 +136,7 @@ if isstring(messages) && isscalar(messages) messagesStruct = {struct("role", "user", "content", messages)}; else - messagesStruct = messages.Messages; + messagesStruct = this.encodeImages(messages.Messages); end if ~isempty(this.SystemPrompt) @@ -160,6 +160,28 @@ end end + methods (Access=private) + function messageStruct = encodeImages(~, messageStruct) + for k=1:numel(messageStruct) + if isfield(messageStruct{k},"images") + images = messageStruct{k}.images; + % detail = messageStruct{k}.image_detail; + messageStruct{k} = rmfield(messageStruct{k},["images","image_detail"]); + imgs = cell(size(images)); + for n = 1:numel(images) + img = images(n); + % Base64 encode the image + fid = fopen(img); + im = fread(fid,'*uint8'); + fclose(fid); + imgs{n} = matlab.net.base64encode(im); + end + messageStruct{k}.images = imgs; + end + end + end + end + methods(Static) function mdls = models %ollamaChat.models - return models available on Ollama server diff --git a/openAIChat.m b/openAIChat.m index 7d73c75..cbd2440 100644 --- a/openAIChat.m +++ b/openAIChat.m @@ -181,7 +181,7 @@ if isstring(messages) && isscalar(messages) messagesStruct = {struct("role", "user", "content", messages)}; else - messagesStruct = messages.Messages; + messagesStruct = this.encodeImages(messages.Messages); end llms.openai.validateMessageSupported(messagesStruct{end}, this.ModelName); @@ -230,6 +230,40 @@ function mustBeValidFunctionCall(this, functionCall) end end + + function messageStruct = encodeImages(~, messageStruct) + for k=1:numel(messageStruct) + if isfield(messageStruct{k},"images") + images = messageStruct{k}.images; + detail = messageStruct{k}.image_detail; + messageStruct{k} = rmfield(messageStruct{k},["images","image_detail"]); + messageStruct{k}.content = ... + {struct("type","text","text",messageStruct{k}.content)}; + for img = images(:).' + if startsWith(img,("https://"|"http://")) + s = struct( ... + "type","image_url", ... + "image_url",struct("url",img)); + else + [~,~,ext] = fileparts(img); + MIMEType = "data:image/" + erase(ext,".") + ";base64,"; + % Base64 encode the image using the given MIME type + fid = fopen(img); + im = fread(fid,'*uint8'); + fclose(fid); + b64 = matlab.net.base64encode(im); + s = struct( ... + "type","image_url", ... 
+ "image_url",struct("url",MIMEType + b64)); + end + + s.image_url.detail = detail; + + messageStruct{k}.content{end+1} = s; + end + end + end + end end end diff --git a/tests/tazureChat.m b/tests/tazureChat.m index 531936f..a16a5da 100644 --- a/tests/tazureChat.m +++ b/tests/tazureChat.m @@ -55,6 +55,26 @@ function generateMultipleResponses(testCase) testCase.verifySize(response.Body.Data.choices,[3,1]); end + function generateWithImage(testCase) + chat = azureChat(Deployment="gpt-4o"); + image_path = "peppers.png"; + emptyMessages = messageHistory; + messages = addUserMessageWithImages(emptyMessages,"What is in the image?",image_path); + + text = generate(chat,messages); + testCase.verifyThat(text,matlab.unittest.constraints.ContainsSubstring("pepper")); + end + + function generateWithMultipleImages(testCase) + import matlab.unittest.constraints.ContainsSubstring + chat = azureChat(Deployment="gpt-4o"); + image_path = "peppers.png"; + emptyMessages = messageHistory; + messages = addUserMessageWithImages(emptyMessages,"Compare these images.",[image_path,image_path]); + + text = generate(chat,messages); + testCase.verifyThat(text,ContainsSubstring("same") | ContainsSubstring("identical")); + end function doReturnErrors(testCase) testCase.assumeTrue(isenv("AZURE_OPENAI_API_KEY"),"end-to-end test requires environment variables AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_DEPLOYMENT."); @@ -65,6 +85,15 @@ function doReturnErrors(testCase) testCase.verifyError(@() generate(chat,wayTooLong), "llms:apiReturnedError"); end + function generateWithImageErrorsForGpt35(testCase) + chat = azureChat; + image_path = "peppers.png"; + emptyMessages = messageHistory; + messages = addUserMessageWithImages(emptyMessages,"What is in the image?",image_path); + + testCase.verifyError(@() generate(chat,messages), "llms:apiReturnedError"); + end + function seedFixesResult(testCase) testCase.assumeTrue(isenv("AZURE_OPENAI_API_KEY"),"end-to-end test requires environment variables AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_DEPLOYMENT."); chat = azureChat; diff --git a/tests/tollamaChat.m b/tests/tollamaChat.m index 95b9c3f..1040c4a 100644 --- a/tests/tollamaChat.m +++ b/tests/tollamaChat.m @@ -98,6 +98,16 @@ function seedFixesResult(testCase) testCase.verifyEqual(response1,response2); end + function generateWithImages(testCase) + chat = ollamaChat("bakllava"); + image_path = "peppers.png"; + emptyMessages = messageHistory; + messages = addUserMessageWithImages(emptyMessages,"What is in the image?",image_path); + + text = generate(chat,messages); + testCase.verifyThat(text,matlab.unittest.constraints.ContainsSubstring("pepper")); + end + function streamFunc(testCase) function seen = sf(str) persistent data; diff --git a/tests/topenAIChat.m b/tests/topenAIChat.m index 143335c..e06db55 100644 --- a/tests/topenAIChat.m +++ b/tests/topenAIChat.m @@ -173,7 +173,7 @@ function generateWithToolsAndStreamFunc(testCase) testCase.verifyThat(data,HasField("explanation")); end - function generateWithImages(testCase) + function generateWithImage(testCase) chat = openAIChat; image_path = "peppers.png"; emptyMessages = messageHistory; @@ -183,6 +183,17 @@ function generateWithImages(testCase) testCase.verifyThat(text,matlab.unittest.constraints.ContainsSubstring("pepper")); end + function generateWithMultipleImages(testCase) + import matlab.unittest.constraints.ContainsSubstring + chat = openAIChat; + image_path = "peppers.png"; + emptyMessages = messageHistory; + messages = 
         function invalidInputsGenerate(testCase, InvalidGenerateInput)
             f = openAIFunction("validfunction");
             chat = openAIChat(Tools=f, APIKey="this-is-not-a-real-key");
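To exercise the new image tests locally, something along these lines should work, assuming a running Ollama server with `bakllava` pulled (as in the CI workflow above) and the relevant OpenAI and Azure credentials in the environment:

```matlab
results = runtests(["tests/topenAIChat.m","tests/tollamaChat.m","tests/tazureChat.m"]);
table(results)
```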