Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix --n parameter and add parameter shorthands. Fixes #10 #11

Merged
merged 2 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 20 additions & 11 deletions src/cloai/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,22 +80,31 @@ async def image_generation( # noqa: PLR0913
prompt: The text to generate an image from.
output_base_name: The base name of the output file.
model: The model to use.
size: The size of the generated image. Defaults to None.
quality: The quality of the generated image. Defaults to "standard".
n: The number of images to generate. Defaults to 1.
size: The size of the generated image.
quality: The quality of the generated
image. Defaults to "standard".
n: The number of images to generate.

Returns:
bytes: The generated image as bytes.

Notes:
At present, the image generation API of dalle-3 only supports generating
one image at a time. Instead, we call the API once for each image we want
to generate.
"""
image_generation = openai_api.ImageGeneration()
urls = await image_generation.run(
prompt,
model=model,
size=size,
quality=quality,
n=n,
)

url_promises = [
image_generation.run(
prompt,
model=model,
size=size,
quality=quality,
n=1,
)
for _ in range(n)
]
urls = [url[0] for url in await asyncio.gather(*url_promises)]
for index, url in enumerate(urls):
if url is None:
logger.warning("Image %s failed to generate, skipping.", index)
Expand Down
12 changes: 8 additions & 4 deletions src/cloai/cli/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ async def run_command(args: argparse.Namespace) -> str | bytes | None:
model=args.model,
size=args.size,
quality=args.quality,
n=args.n,
n=args.number,
)
return None
if args.command == "tts":
Expand Down Expand Up @@ -189,6 +189,7 @@ def _add_image_generation_parser(
type=str,
)
image_generation_parser.add_argument(
"-m",
"--model",
help=(
"The model to use. Consult OpenAI's documentation for an up-to-date list"
Expand All @@ -198,21 +199,24 @@ def _add_image_generation_parser(
default="dall-e-3",
)
image_generation_parser.add_argument(
"-s",
"--size",
help="The size of the generated image.",
type=lambda x: x.lower(),
choices=["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"],
default="1024x1024",
)
image_generation_parser.add_argument(
"-q",
"--quality",
help="The quality of the generated image.",
type=lambda x: x.lower(),
choices=["standard", "hd"],
default="standard",
)
image_generation_parser.add_argument(
"--n",
"-n",
"--number",
help="The number of images to generate.",
type=_positive_int,
default=1,
Expand All @@ -239,7 +243,7 @@ def _arg_validation(args: argparse.Namespace) -> argparse.Namespace:
return args


def _positive_int(value: int) -> int:
def _positive_int(value: int | str) -> int:
"""Ensures the value is a positive integer.

Args:
Expand All @@ -252,7 +256,7 @@ def _positive_int(value: int) -> int:
exceptions.InvalidArgumentError: If the value is not an integer or not a
positive integer.
"""
if int(value) != value:
if int(value) != float(value):
msg = f"{value} is not an integer."
raise exceptions.InvalidArgumentError(msg)
if int(value) <= 0:
Expand Down
39 changes: 36 additions & 3 deletions tests/unit/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def test__add_image_generation_parser() -> None:
assert arguments[5].help == "The quality of the generated image."
assert arguments[5].default == "standard"

assert arguments[6].dest == "n"
assert arguments[6].dest == "number"
assert arguments[6].help == "The number of images to generate."
assert arguments[6].default == 1

Expand Down Expand Up @@ -203,7 +203,7 @@ async def test_run_command_with_dalle(mocker: pytest_mock.MockFixture) -> None:
"model": "dall-e-3",
"size": "1024x1024",
"quality": "standard",
"n": 1,
"number": 1,
}
args = argparse.Namespace(**arg_dict)
mock = mocker.patch("cloai.cli.commands.image_generation")
Expand All @@ -216,7 +216,7 @@ async def test_run_command_with_dalle(mocker: pytest_mock.MockFixture) -> None:
model=arg_dict["model"],
size=arg_dict["size"],
quality=arg_dict["quality"],
n=arg_dict["n"],
n=arg_dict["number"],
)


Expand Down Expand Up @@ -275,6 +275,39 @@ async def test_parse_args_with_command_no_other_arguments(
assert excinfo.value.code == expected_error_code


@pytest.mark.asyncio()
async def test_parse_args_from_cli_with_dalle_all_arguments(
mocker: pytest_mock.MockFixture,
) -> None:
"""Tests the parse_args function with the 'dalle' command and all arguments."""
command = mocker.patch("cloai.cli.commands.image_generation")
sys.argv = [
"cloai",
"dalle",
"test",
"test",
"--model",
"dall-e-3",
"--size",
"1024x1024",
"--quality",
"standard",
"-n",
"1",
]

await parser.parse_args()

command.assert_called_once_with(
prompt="test",
output_base_name="test",
model="dall-e-3",
size="1024x1024",
quality="standard",
n=1,
)


@pytest.mark.parametrize(
"size",
[
Expand Down