Skip to content

Commit

Permalink
Merge pull request #43 from matlab-deep-learning/size-check
Browse files Browse the repository at this point in the history
move size check into arguments block
  • Loading branch information
ccreutzi authored Jun 17, 2024
2 parents 9ea311b + a6c86c8 commit 38edd99
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 23 deletions.
29 changes: 8 additions & 21 deletions openAIImages.m
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,15 @@
mustBeLessThanOrEqual(nvp.NumImages,10)} = 1
nvp.Size (1,1) string {mustBeMember(nvp.Size, ["256x256", "512x512", ...
"1024x1024", "1792x1024", ...
"1024x1792"])} = "1024x1024"
"1024x1792"]), ...
mustBeValidSize(this,nvp.Size)} = "1024x1024"
nvp.Quality (1,1) string {mustBeMember(nvp.Quality,["standard", "hd"])}
nvp.Style (1,1) string {mustBeMember(nvp.Style,["vivid", "natural"])}
end

endpoint = "https://api.openai.com/v1/images/generations";

validatePromptSize(this.ModelName, prompt)
validateSizeNVP(this.ModelName, nvp.Size)

params = struct("prompt",prompt,...
"model",this.ModelName,...
Expand Down Expand Up @@ -180,9 +180,9 @@
nvp.MaskImagePath {mustBeValidFileType(nvp.MaskImagePath)}
nvp.NumImages (1,1) {mustBePositive, mustBeInteger,...
mustBeLessThanOrEqual(nvp.NumImages,10)} = 1
nvp.Size (1,1) string {mustBeMember(nvp.Size,["256x256", ...
"512x512", ...
"1024x1024"])} = "1024x1024"
nvp.Size (1,1) string {mustBeMember(nvp.Size,...
["256x256", "512x512","1024x1024"]), ...
mustBeValidSize(this,nvp.Size)} = "1024x1024"
end

% For now, this is only supported for "dall-e-2"
Expand Down Expand Up @@ -241,8 +241,9 @@
imagePath {mustBeValidFileType(imagePath)}
nvp.NumImages (1,1) {mustBePositive, mustBeInteger,...
mustBeLessThanOrEqual(nvp.NumImages,10)} = 1
nvp.Size (1,1) string {mustBeMember(nvp.Size,["256x256", ...
"512x512","1024x1024"])} = "1024x1024"
nvp.Size (1,1) string {mustBeMember(nvp.Size,...
["256x256", "512x512","1024x1024"]), ...
mustBeValidSize(this,nvp.Size)} = "1024x1024"
end

% For now, this is only supported for "dall-e-2"
Expand Down Expand Up @@ -308,20 +309,6 @@ function mustBeValidSize(this, imagesize)
end
end

function validateSizeNVP(model, size)
if ismember(size,["1792x1024", "1024x1792"]) && model=="dall-e-2"
error("llms:invalidOptionAndValueForModel", ...
llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionAndValueForModel", ...
"Size", size, model));
end

if ismember(size,["256x256", "512x512"]) && model=="dall-e-3"
error("llms:invalidOptionAndValueForModel", ...
llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionAndValueForModel", ...
"Size", size, model));
end
end

function validatePromptSize(model, prompt)
numChars = numel(char(prompt));
if model=="dall-e-3"
Expand Down
8 changes: 6 additions & 2 deletions tests/topenAIImages.m
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ function invalidOptionsGenerate(testCase)
mdl = openAIImages(ApiKey="this-is-not-a-real-key", Model="dall-e-2");
testCase.verifyError(@()generate(mdl, "cat", Quality="hd"), "llms:invalidOptionForModel")
testCase.verifyError(@()generate(mdl, "cat", Style="natural"), "llms:invalidOptionForModel")
testCase.verifyError(@()generate(mdl, "cat", Size="1024x1792"), "llms:invalidOptionAndValueForModel")
testCase.verifyError(@()generate(mdl, "cat", Size="1024x1792"), "MATLAB:validators:mustBeMember")
mdl = openAIImages(ApiKey="this-is-not-a-real-key", Model="dall-e-3");
testCase.verifyError(@()generate(mdl, "cat", Size="256x256"), "llms:invalidOptionAndValueForModel")
testCase.verifyError(@()generate(mdl, "cat", Size="256x256"), "MATLAB:validators:mustBeMember")
testCase.verifyError(@()generate(mdl, "cat", NumImages=4), "llms:invalidOptionAndValueForModel")
end

Expand Down Expand Up @@ -201,6 +201,10 @@ function testThatImageIsReturned(testCase)
"Input",{{ "prompt" "Size" "foo" }},...
"Error","MATLAB:validators:mustBeMember"),...
...
"InvalidSizeForModel",struct( ...
"Input",{{ "prompt" "Size" "1792x1024" }},...
"Error","MATLAB:validators:mustBeMember"),...
...
"InvalidQualityOption",struct( ...
"Input",{{ "prompt" "Quality" "foo" }},...
"Error","MATLAB:validators:mustBeMember"),...
Expand Down

0 comments on commit 38edd99

Please sign in to comment.