-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #8 from matlab-deep-learning/AzureAPI
Adding support to Azure API
- Loading branch information
Showing
50 changed files
with
3,671 additions
and
1,397 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
function versions = apiVersions
%APIVERSIONS Supported Azure API versions.
%
%   VERSIONS = apiVersions returns a string row vector holding the Azure
%   API version identifiers listed below, newest first.

%   Copyright 2024 The MathWorks, Inc.

versions = [ ...
    "2024-05-01-preview", ...
    "2024-04-01-preview", ...
    "2024-03-01-preview", ...
    "2024-02-01", ...
    "2023-05-15"];
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
function [text, message, response] = callAzureChatAPI(endpoint, deploymentID, messages, functions, nvp)
% This function is undocumented and will change in a future release

%callAzureChatAPI Calls the OpenAI chat completions API on Azure.
%
%   [TEXT, MESSAGE, RESPONSE] = callAzureChatAPI(ENDPOINT, DEPLOYMENTID,
%   MESSAGES, FUNCTIONS, Name=Value) posts a chat completion request to
%   ENDPOINT/openai/deployments/DEPLOYMENTID/chat/completions and returns
%   the text of the first generation, the message struct it came from, and
%   the raw response returned by llms.internal.sendRequest.
%
%   MESSAGES and FUNCTIONS should be structs matching the json format
%   required by the OpenAI Chat Completions API.
%   Ref: https://platform.openai.com/docs/guides/gpt/chat-completions-api
%
%   More details on the parameters: https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/chatgpt
%
%   Name-value arguments:
%   APIVersion, APIKey, TimeOut, StreamFun, ToolChoice and Seed are read
%   unconditionally and must therefore always be supplied (use [] for
%   StreamFun, ToolChoice or Seed to leave them unset). Temperature, TopP,
%   NumCompletions, StopSequences, MaxNumTokens, PresencePenalty and
%   FrequencyPenalty are forwarded only when supplied. ResponseFormat is
%   accepted but not used by this function.
%
%   When the response status is not "OK", TEXT is "" and MESSAGE is an
%   empty-field struct; inspect RESPONSE for the error details.
%
%   Example
%
%   % Create messages struct
%   messages = {struct("role", "system",...
%       "content", "You are a helpful assistant");
%       struct("role", "user", ...
%       "content", "What is the edit distance between hi and hello?")};
%
%   % Create functions struct
%   functions = {struct("name", "editDistance", ...
%       "description", "Find edit distance between two strings or documents.", ...
%       "parameters", struct( ...
%       "type", "object", ...
%       "properties", struct(...
%           "str1", struct(...
%               "description", "Source string.", ...
%               "type", "string"),...
%           "str2", struct(...
%               "description", "Target string.", ...
%               "type", "string")),...
%       "required", ["str1", "str2"]))};
%
%   % Define your endpoint, deployment and API key
%   endpoint = "https://your-resource.openai.azure.com/";
%   deploymentID = "your-deployment-name";
%   apiKey = "your-api-key-here";
%
%   % Send a request
%   [text, message] = llms.internal.callAzureChatAPI(endpoint, deploymentID, ...
%       messages, functions, APIKey=apiKey, APIVersion="2024-02-01", ...
%       ToolChoice=[], Seed=[], TimeOut=10, StreamFun=[])

%   Copyright 2023-2024 The MathWorks, Inc.

arguments
    endpoint
    deploymentID
    messages
    functions
    nvp.ToolChoice
    nvp.APIVersion
    nvp.Temperature
    nvp.TopP
    nvp.NumCompletions
    nvp.StopSequences
    nvp.MaxNumTokens
    nvp.PresencePenalty
    nvp.FrequencyPenalty
    nvp.ResponseFormat
    nvp.Seed
    nvp.APIKey
    nvp.TimeOut
    nvp.StreamFun
end

% The path below is appended verbatim, so make sure the endpoint ends with
% a "/" -- otherwise host and path would run together.
endpoint = string(endpoint);
if ~endsWith(endpoint, "/")
    endpoint = endpoint + "/";
end

URL = endpoint + "openai/deployments/" + deploymentID + "/chat/completions?api-version=" + nvp.APIVersion;

parameters = buildParametersCall(messages, functions, nvp);

[response, streamedText] = llms.internal.sendRequest(parameters,nvp.APIKey, URL, nvp.TimeOut, nvp.StreamFun);

% If call errors, "choices" will not be part of response.Body.Data, instead
% we get response.Body.Data.error
if response.StatusCode=="OK"
    % Outputs the first generation
    if isempty(nvp.StreamFun)
        message = response.Body.Data.choices(1).message;
    else
        % When streaming, the message is rebuilt from the accumulated text.
        message = struct("role", "assistant", ...
            "content", streamedText);
    end
    % NOTE(review): "tool_choice" is the name of the request field set in
    % buildParametersCall; confirm whether the response message can carry
    % it, or whether a different field was intended here.
    if isfield(message, "tool_choice")
        text = "";
    else
        text = string(message.content);
    end
else
    text = "";
    message = struct();
end
end
|
||
function parameters = buildParametersCall(messages, functions, nvp)
%buildParametersCall Assemble the request body for the chat completions call.
%   Starts from MESSAGES, sets the stream flag from NVP.StreamFun, adds
%   FUNCTIONS (as "tools"), NVP.ToolChoice and NVP.Seed when they are
%   non-empty, and finally copies every option listed by
%   mapNVPToParameters that exists as a field of NVP.

parameters = struct();
parameters.messages = messages;

% Streaming is requested exactly when a stream callback was provided.
parameters.stream = ~isempty(nvp.StreamFun);

if ~isempty(functions)
    parameters.tools = functions;
end
if ~isempty(nvp.ToolChoice)
    parameters.tool_choice = nvp.ToolChoice;
end
if ~isempty(nvp.Seed)
    parameters.seed = nvp.Seed;
end

% Translate the remaining name-value options to their API field names.
apiNames = mapNVPToParameters;
optionNames = keys(apiNames);
for idx = 1:numel(optionNames)
    optionName = optionNames(idx);
    if ~isfield(nvp, optionName)
        continue
    end
    parameters.(apiNames(optionName)) = nvp.(optionName);
end
end
|
||
function dict = mapNVPToParameters()
%mapNVPToParameters Dictionary from name-value option names to API field names.
%   Keys are the MATLAB-side option names; values are the field names used
%   in the request body.

optionNames = ["Temperature", "TopP", "NumCompletions", "StopSequences", ...
    "MaxNumTokens", "PresencePenalty", "FrequencyPenalty"];
apiFields = ["temperature", "top_p", "n", "stop", ...
    "max_tokens", "presence_penalty", "frequency_penalty"];

dict = dictionary(optionNames, apiFields);
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
function [text, message, response] = callOllamaChatAPI(model, messages, nvp)
% This function is undocumented and will change in a future release

%callOllamaChatAPI Calls the Ollama® chat completions API.
%
%   [TEXT, MESSAGE, RESPONSE] = callOllamaChatAPI(MODEL, MESSAGES,
%   Name=Value) posts a chat request for MODEL to
%   http://localhost:11434/api/chat and returns the text of the reply, the
%   message struct it came from, and the raw response returned by
%   llms.internal.sendRequest.
%
%   MESSAGES should be structs matching the json format
%   required by the Ollama Chat Completions API.
%   Ref: https://github.com/ollama/ollama/blob/main/docs/api.md
%
%   More details on the parameters: https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
%
%   Name-value arguments:
%   StopSequences, Seed, TimeOut and StreamFun are read unconditionally
%   and must therefore always be supplied (use [] to leave StopSequences,
%   Seed or StreamFun unset). Temperature, TopP, TopK, TailFreeSamplingZ
%   and MaxNumTokens are forwarded only when supplied, non-empty and not
%   Inf. ResponseFormat is accepted but not used by this function.
%
%   When the response status is not "OK", TEXT is "" and MESSAGE is an
%   empty-field struct; inspect RESPONSE for the error details.
%
%   Example
%
%   model = "mistral";
%
%   % Create messages struct
%   messages = {struct("role", "system",...
%       "content", "You are a helpful assistant");
%       struct("role", "user", ...
%       "content", "What is the edit distance between hi and hello?")};
%
%   % Send a request
%   [text, message] = llms.internal.callOllamaChatAPI(model, messages, ...
%       StopSequences=[], Seed=[], TimeOut=120, StreamFun=[])

%   Copyright 2023-2024 The MathWorks, Inc.

arguments
    model
    messages
    nvp.Temperature
    nvp.TopP
    nvp.TopK
    nvp.TailFreeSamplingZ
    nvp.StopSequences
    nvp.MaxNumTokens
    nvp.ResponseFormat
    nvp.Seed
    nvp.TimeOut
    nvp.StreamFun
end

URL = "http://localhost:11434/api/chat";

% The JSON for StopSequences must have an array, and cannot say "stop": "foo".
% The easiest way to ensure that is to never pass in a scalar: a scalar
% stop sequence is simply listed twice.
if isscalar(nvp.StopSequences)
    nvp.StopSequences = [nvp.StopSequences, nvp.StopSequences];
end

parameters = buildParametersCall(model, messages, nvp);

% No API key is passed for this request.
[response, streamedText] = llms.internal.sendRequest(parameters,[],URL,nvp.TimeOut,nvp.StreamFun);

% Only read the reply from response.Body.Data.message when the status is
% "OK"; otherwise return empty outputs and leave the details in RESPONSE.
if response.StatusCode=="OK"
    if isempty(nvp.StreamFun)
        message = response.Body.Data.message;
    else
        % When streaming, the message is rebuilt from the accumulated text.
        message = struct("role", "assistant", ...
            "content", streamedText);
    end
    text = string(message.content);
else
    text = "";
    message = struct();
end
end
|
||
function parameters = buildParametersCall(model, messages, nvp)
%buildParametersCall Assemble the request body for the Ollama chat call.
%   Combines MODEL and MESSAGES with the stream flag derived from
%   NVP.StreamFun, and collects the seed and the sampling options listed
%   by mapNVPToParameters into the nested "options" struct.

parameters = struct();
parameters.model = model;
parameters.messages = messages;

% Streaming is requested exactly when a stream callback was provided.
parameters.stream = ~isempty(nvp.StreamFun);

options = struct;
if ~isempty(nvp.Seed)
    options.seed = nvp.Seed;
end

% Copy each mapped option that is present, non-empty and not Inf.
apiNames = mapNVPToParameters;
optionNames = keys(apiNames);
for idx = 1:numel(optionNames)
    optionName = optionNames(idx);
    if ~isfield(nvp, optionName)
        continue
    end
    value = nvp.(optionName);
    if isempty(value) || isequaln(value, Inf)
        continue
    end
    options.(apiNames(optionName)) = value;
end

parameters.options = options;
end
|
||
function dict = mapNVPToParameters()
%mapNVPToParameters Dictionary from name-value option names to Ollama option names.
%   Keys are the MATLAB-side option names; values are the field names used
%   inside the "options" struct of the request body.

optionNames = ["Temperature", "TopP", "TopK", "TailFreeSamplingZ", ...
    "StopSequences", "MaxNumTokens"];
apiFields = ["temperature", "top_p", "top_k", "tfs_z", ...
    "stop", "num_predict"];

dict = dictionary(optionNames, apiFields);
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,23 @@ | ||
function key = getApiKeyFromNvpOrEnv(nvp) | ||
function key = getApiKeyFromNvpOrEnv(nvp,envVarName) | ||
% This function is undocumented and will change in a future release | ||
|
||
%getApiKeyFromNvpOrEnv Retrieves an API key from a Name-Value Pair struct or environment variable. | ||
% | ||
% This function takes a struct nvp containing name-value pairs and checks | ||
% if it contains a field called "ApiKey". If the field is not found, | ||
% the function attempts to retrieve the API key from an environment | ||
% variable called "OPENAI_API_KEY". If both methods fail, the function | ||
% throws an error. | ||
% This function takes a struct nvp containing name-value pairs and checks if | ||
% it contains a field called "APIKey". If the field is not found, the | ||
% function attempts to retrieve the API key from an environment variable | ||
% whose name is given as the second argument. If both methods fail, the | ||
% function throws an error. | ||
|
||
% Copyright 2023 The MathWorks, Inc. | ||
% Copyright 2023-2024 The MathWorks, Inc. | ||
|
||
if isfield(nvp, "ApiKey") | ||
key = nvp.ApiKey; | ||
if isfield(nvp, "APIKey") | ||
key = nvp.APIKey; | ||
else | ||
if isenv("OPENAI_API_KEY") | ||
key = getenv("OPENAI_API_KEY"); | ||
if isenv(envVarName) | ||
key = getenv(envVarName); | ||
else | ||
error("llms:keyMustBeSpecified", llms.utils.errorMessageCatalog.getMessage("llms:keyMustBeSpecified")); | ||
error("llms:keyMustBeSpecified", llms.utils.errorMessageCatalog.getMessage("llms:keyMustBeSpecified", envVarName)); | ||
end | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
classdef (Abstract) gptPenalties
    % This class is undocumented and will change in a future release

    % Abstract class contributing the two penalty properties below. Each
    % property is validated by llms.utils.mustBeValidPenalty and defaults
    % to 0.

    % Copyright 2024 The MathWorks, Inc.
    properties
        %PRESENCEPENALTY Penalty for using a token in the response that has already been used.
        PresencePenalty {llms.utils.mustBeValidPenalty} = 0

        %FREQUENCYPENALTY Penalty for using a token that has already appeared frequently in the response so far.
        %   Per the OpenAI API reference for frequency_penalty, the penalty
        %   scales with how often the token occurs in the generated text,
        %   not with its frequency in the training data.
        FrequencyPenalty {llms.utils.mustBeValidPenalty} = 0
    end
end
Oops, something went wrong.