Skip to content

Commit

Permalink
Properly accept char and cellstr in extractOpenAIEmbeddings
Browse files Browse the repository at this point in the history
Fixes #85
  • Loading branch information
ccreutzi committed Jan 23, 2025
1 parent 3f312b5 commit 11a54ea
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 8 deletions.
13 changes: 7 additions & 6 deletions extractOpenAIEmbeddings.m
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,18 @@
% Copyright 2023-2024 The MathWorks, Inc.

arguments
text (1,:) {mustBeNonzeroLengthText}
nvp.ModelName (1,1) {mustBeMember(nvp.ModelName,["text-embedding-ada-002", ...
"text-embedding-3-large", "text-embedding-3-small"])} = "text-embedding-ada-002"
nvp.TimeOut (1,1) {mustBeNumeric,mustBeReal,mustBePositive} = 10
nvp.Dimensions (1,1) {mustBeNumeric,mustBeInteger,mustBePositive}
nvp.APIKey {llms.utils.mustBeNonzeroLengthTextScalar}
text (1,:) {mustBeNonzeroLengthText}
nvp.ModelName (1,1) string {mustBeMember(nvp.ModelName,["text-embedding-ada-002", ...
"text-embedding-3-large", "text-embedding-3-small"])} = "text-embedding-ada-002"
nvp.TimeOut (1,1) {mustBeNumeric,mustBeReal,mustBePositive} = 10
nvp.Dimensions (1,1) {mustBeNumeric,mustBeInteger,mustBePositive}
nvp.APIKey {llms.utils.mustBeNonzeroLengthTextScalar}
end

END_POINT = "https://api.openai.com/v1/embeddings";

key = llms.internal.getApiKeyFromNvpOrEnv(nvp,"OPENAI_API_KEY");
text = convertCharsToStrings(text);

parameters = struct("input",text,"model",nvp.ModelName);

Expand Down
31 changes: 29 additions & 2 deletions tests/textractOpenAIEmbeddings.m
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

properties(TestParameter)
InvalidInput = iGetInvalidInput();
ValidInput = iGetValidInput();
ValidDimensionsModelCombinations = iGetValidDimensionsModelCombinations();
end

Expand Down Expand Up @@ -34,8 +35,9 @@ function validCombinationOfModelAndDimension(testCase, ValidDimensionsModelCombi
APIKey="not-real"));
end

function embedStringWithSuccessfulOpenAICall(testCase)
testCase.verifyWarningFree(@()extractOpenAIEmbeddings("bla"));
function embedTextWithSuccessfulOpenAICall(testCase,ValidInput)
result = testCase.verifyWarningFree(@()extractOpenAIEmbeddings(ValidInput.Input{:}));
testCase.verifySize(result, ValidInput.ExpectedSize);
end

function invalidCombinationOfModelAndDimension(testCase)
Expand All @@ -57,6 +59,31 @@ function testInvalidInputs(testCase, InvalidInput)
end
end

function validInput = iGetValidInput()
validInput = struct( ...
"ScalarString", struct( ...
"Input",{{ "blah" }}, ...
"ExpectedSize",[1,1536]), ...
"StringVector", struct( ...
"Input",{{ ["a", "b", "c"] }}, ...
"ExpectedSize",[3,1536]), ...
"CharVector", struct( ...
"Input", {{ 'foo' }}, ...
"ExpectedSize",[1,1536]), ...
"Cellstr", struct( ...
"Input",{{ {'cat', 'dog', 'mouse'} }}, ...
"ExpectedSize",[3,1536]), ...
"ModelAsString", struct( ...
"Input",{{ "foo","ModelName","text-embedding-3-small" }}, ...
"ExpectedSize",[1,1536]), ...
"ModelAsChar", struct( ...
"Input",{{ "foo","ModelName",'text-embedding-3-small' }}, ...
"ExpectedSize",[1,1536]), ...
"ModelAsCellstr", struct( ...
"Input",{{ "foo","ModelName",{'text-embedding-3-small'} }}, ...
"ExpectedSize",[1,1536]));
end

function invalidInput = iGetInvalidInput()
invalidInput = struct( ...
"InvalidEmptyText", struct( ...
Expand Down

0 comments on commit 11a54ea

Please sign in to comment.