Skip to content

Commit

Permalink
Merge pull request #40 from childmindresearch/embeddings
Browse files Browse the repository at this point in the history
Added embedding model
  • Loading branch information
maya-roberts authored Feb 16, 2024
2 parents bf45dcc + 319f69d commit 58d18ac
Show file tree
Hide file tree
Showing 11 changed files with 622 additions and 334 deletions.
708 changes: 375 additions & 333 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ pypdf = ">=3.17.3,<5.0.0"
aiohttp = "^3.9.3"
aiofiles = "^23.2.1"
instructor = "^0.5.2"
aiocsv = "^1.2.5"

[tool.poetry.group.dev.dependencies]
pytest = ">=7.4.3,<9.0.0"
Expand Down
34 changes: 34 additions & 0 deletions src/cloai/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import tempfile
from typing import Literal

import aiofiles
import ffmpeg
import yaml

Expand Down Expand Up @@ -267,3 +268,36 @@ async def image_generation( # noqa: PLR0913
for index, url in enumerate(urls_not_none)
],
)


async def get_embedding(
    text_file: pathlib.Path,
    output_file: pathlib.Path,
    model: Literal[
        "text-embedding-3-small",
        "text-embedding-3-large",
    ] = "text-embedding-3-large",
    *,
    keep_new_lines: bool = False,
) -> None:
    """Get the embedding of a text file using OpenAI's Embedding models.

    The file is read in full, embedded, and the resulting embedding is
    written to the output file.

    Args:
        text_file: The text file to embed.
        output_file: The name of the CSV output file.
        model: The name of the Embedding model to use,
            defaults to text-embedding-3-large.
        keep_new_lines: Whether to keep or remove line breaks,
            defaults to False.
    """
    # Decode explicitly as UTF-8 so the text is read the same way on every
    # platform instead of depending on the locale's preferred encoding.
    async with aiofiles.open(text_file, mode="r", encoding="utf-8") as file:
        text = await file.read()

    # Deliberately not named ``get_embedding``: that would shadow this
    # function inside its own body.
    embedder = openai_api.Embedding()
    embedding = await embedder.run(
        text=text,
        model=model,
        keep_new_lines=keep_new_lines,
    )
    await utils.save_csv(output_file, embedding)
54 changes: 53 additions & 1 deletion src/cloai/cli/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def create_parser() -> argparse.ArgumentParser:
_add_stt_parser(subparsers)
_add_tts_parser(subparsers)
_add_image_generation_parser(subparsers)
_add_embedding_parser(subparsers)
return parser


Expand Down Expand Up @@ -120,6 +121,14 @@ async def run_command(args: argparse.Namespace) -> str | bytes | None:
clip=args.clip,
language=config.WhisperLanguages[args.language],
)
if args.command == "embedding":
await commands.get_embedding(
text_file=args.text_file,
output_file=args.output_file,
model=args.model,
keep_new_lines=args.keep_new_lines,
)
return None
msg = f"Unknown command {args.command}."
raise exceptions.InvalidArgumentError(msg)

Expand All @@ -144,7 +153,7 @@ def _add_chat_completion_parser(

user_group = chat_parser.add_argument_group(
"User Prompts",
"""The prompts povided by the user. One must be provided and these
"""The prompts provided by the user. One must be provided and these
arguments are mutually exclusive.""",
)
user_group_exclusive = user_group.add_mutually_exclusive_group(required=True)
Expand Down Expand Up @@ -341,6 +350,49 @@ def _add_image_generation_parser(
)


def _add_embedding_parser(
    subparsers: argparse._SubParsersAction,
) -> None:
    """Add the argument parser for the "embedding" command.

    The command takes two positional paths (the input text file and the CSV
    output file), an optional model choice, and a --keep-new-lines /
    --no-keep-new-lines flag.

    Args:
        subparsers: The subparsers object to add the "embedding" command to.
    """
    embedding_parser = subparsers.add_parser(
        "embedding",
        description="Generates embedding with OpenAI's Text Embedding models.",
        help="Generates embedding with OpenAI's Text Embedding models.",
        **PARSER_DEFAULTS,  # type: ignore[arg-type]
    )
    embedding_parser.add_argument(
        "text_file",
        help="The text file to generate an embedding from.",
        type=pathlib.Path,
    )
    embedding_parser.add_argument(
        "output_file",
        help="The name of the CSV output file.",
        type=pathlib.Path,
    )
    embedding_parser.add_argument(
        "-m",
        "--model",
        help="The model to use.",
        choices=["text-embedding-3-small", "text-embedding-3-large"],
        default="text-embedding-3-large",
    )
    embedding_parser.add_argument(
        "--keep-new-lines",
        dest="keep_new_lines",
        help="Keeps line breaks in text if specified.",
        # BooleanOptionalAction also generates the --no-keep-new-lines form.
        action=argparse.BooleanOptionalAction,
        default=False,
    )


def _arg_validation(args: argparse.Namespace) -> argparse.Namespace:
"""Validate the parsed arguments.
Expand Down
15 changes: 15 additions & 0 deletions src/cloai/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import uuid
from collections.abc import Generator

import aiocsv
import aiofiles
import aiohttp
import docx
Expand Down Expand Up @@ -71,6 +72,20 @@ async def save_file(filename: str | pathlib.Path, content: bytes) -> None:
await file.write(content)


async def save_csv(filename: str | pathlib.Path, content: list[float]) -> None:
    """Saves content to a csv file asynchronously.

    The content is written as a single row.

    Args:
    ----
        filename: The name of the file to save the content to.
        content: The content to save to the file.

    """
    # newline="" turns off newline translation: the csv writer emits its own
    # row terminator, and translating it again would corrupt row endings on
    # platforms whose line separator is not "\n". The encoding is pinned so
    # the output does not depend on the locale.
    async with aiofiles.open(filename, "w", encoding="utf-8", newline="") as file:
        writer = aiocsv.AsyncWriter(file)
        await writer.writerow(content)


async def download_file(filename: str | pathlib.Path, url: str) -> None:
"""Downloads a file from a URL.
Expand Down
33 changes: 33 additions & 0 deletions src/cloai/openai_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,3 +225,36 @@ async def run( # noqa: PLR0913
)

return [data.url for data in response.data]


class Embedding(OpenAIBaseClass):
    """A class for running the Embedding models."""

    async def run(
        self,
        text: str,
        model: Literal[
            "text-embedding-3-small",
            "text-embedding-3-large",
        ] = "text-embedding-3-large",
        *,
        keep_new_lines: bool = False,
    ) -> list[float]:
        """Runs the Embedding model.

        Args:
            text: The string to embed.
            model: The name of the Embedding model to use.
            keep_new_lines: Whether to keep or remove line breaks,
                defaults to False. When removing, each line break is
                replaced by a single space.

        Returns:
            The embedding (list of numbers).
        """
        if not keep_new_lines:
            # Handle "\r\n" first so a Windows line ending becomes one space,
            # then any lone "\r" or "\n"; replacing only "\n" would leave
            # stray carriage returns in the text sent to the model.
            text = text.replace("\r\n", " ").replace("\r", " ").replace("\n", " ")
        response = await self.client.embeddings.create(
            input=text,
            model=model,
        )
        # One string is sent as input, and only the first data entry is read.
        return response.data[0].embedding
21 changes: 21 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Test configurations."""
import dataclasses
import os
from unittest import mock

Expand All @@ -18,6 +19,22 @@ def pytest_configure() -> None:
os.environ["OPENAI_API_KEY"] = "API_KEY"


def _default_embedding() -> list[float]:
    """Builds a fresh default vector so instances never share one list."""
    return [1.0, 2.0, 3.0]


@dataclasses.dataclass
class EmbeddingData:
    """Mock of one embedding entry, exposing an ``embedding`` attribute."""

    # Defaults to a new [1.0, 2.0, 3.0] list for every instance.
    embedding: list[float] = dataclasses.field(default_factory=_default_embedding)


def _default_embedding_entries() -> list[EmbeddingData]:
    """Builds a new one-item list holding a default EmbeddingData."""
    return [EmbeddingData()]


@dataclasses.dataclass
class EmbeddingResponse:
    """Mock of an embedding response, exposing a ``data`` list of entries."""

    # Defaults to a new list with a single default EmbeddingData per instance.
    data: list[EmbeddingData] = dataclasses.field(
        default_factory=_default_embedding_entries,
    )


@pytest.fixture()
def mock_openai(mocker: pytest_mock.MockFixture) -> mock.MagicMock:
"""Mocks the OpenAI client."""
Expand All @@ -33,11 +50,15 @@ def mock_openai(mocker: pytest_mock.MockFixture) -> mock.MagicMock:
create=mocker.AsyncMock(),
),
)
mock_embedding = mocker.MagicMock(
create=mocker.AsyncMock(return_value=EmbeddingResponse()),
)
mock_client = mocker.AsyncMock(
spec=openai_api.openai.AsyncOpenAI,
audio=mock_audio,
images=mock_images,
chat=mock_chat,
embeddings=mock_embedding,
)
return mocker.patch(
"cloai.openai_api.openai.AsyncOpenAI",
Expand Down
27 changes: 27 additions & 0 deletions tests/unit/test_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from unittest import mock

import pytest
from pytest_mock import plugin

from cloai.cli import commands
from cloai.core import exceptions
Expand Down Expand Up @@ -145,3 +146,29 @@ async def test_chat_completion_run_method(mock_openai: mock.MagicMock) -> None:
0
].message.content
)


# commands test to make sure input/output files are handled correctly
@pytest.mark.asyncio()
async def test_get_embedding(
    mocker: plugin.MockerFixture,
    tmp_path: pathlib.Path,
    mock_openai: mock.AsyncMock,
) -> None:
    """Checks get_embedding sends the file text and saves the embedding."""
    input_text = "test text"
    source_path = tmp_path / "test_text.txt"
    source_path.write_text(input_text)
    destination_path = tmp_path / "test_output.csv"
    # The mocked embeddings endpoint; its canned response supplies the
    # vector that should be handed to save_csv.
    create_mock = mock_openai.return_value.embeddings.create
    mocked_vector = create_mock.return_value.data[0].embedding
    save_csv_spy = mocker.patch("cloai.core.utils.save_csv")

    await commands.get_embedding(source_path, destination_path)

    sent_kwargs = create_mock.call_args[1]
    assert sent_kwargs["input"] == input_text
    save_csv_spy.assert_called_once_with(destination_path, mocked_vector)
15 changes: 15 additions & 0 deletions tests/unit/test_openai_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,18 @@ async def test_chat_completion(mock_openai: mock.AsyncMock) -> None:
assert chat_completion.client is not None
assert mock_openai.call_count == 1
assert mock_openai.return_value.chat.completions.create.call_count == 1


@pytest.mark.asyncio()
async def test_embedding(mock_openai: mock.AsyncMock) -> None:
    """Checks Embedding.run builds one client and calls the endpoint once."""
    embedder = openai_api.Embedding()

    await embedder.run("", model="text-embedding-3-large")

    create_mock = mock_openai.return_value.embeddings.create
    assert embedder.client is not None
    assert mock_openai.call_count == 1
    assert create_mock.call_count == 1
31 changes: 31 additions & 0 deletions tests/unit/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,37 @@ def test__add_tts_parser() -> None:
assert arguments[4].default == "onyx"


def test__add_embedding_parser() -> None:
    """Tests the _add_embedding_parser function."""
    subparsers = argparse.ArgumentParser().add_subparsers()
    parser._add_embedding_parser(subparsers)
    expected_n_arguments = 5

    # Check registration before indexing: if the lookup ran first, a missing
    # subcommand would raise KeyError and this assertion could never fail.
    assert "embedding" in subparsers.choices

    embedding_parser = subparsers.choices["embedding"]
    arguments = embedding_parser._actions

    assert (
        embedding_parser.description
        == "Generates embedding with OpenAI's Text Embedding models."
    )

    assert len(arguments) == expected_n_arguments

    # Index 0 is expected to be the help action, followed by the four
    # arguments the embedding parser adds, in the order they are added.
    assert arguments[0].dest == "help"

    assert arguments[1].dest == "text_file"
    assert arguments[1].type == pathlib.Path

    assert arguments[2].dest == "output_file"
    assert arguments[2].type == pathlib.Path

    assert arguments[3].dest == "model"
    assert arguments[3].default == "text-embedding-3-large"

    assert arguments[4].dest == "keep_new_lines"
    assert arguments[4].default is False


@pytest.mark.asyncio()
async def test_run_command_without_arguments() -> None:
"""Tests the run_command function with no arguments."""
Expand Down
17 changes: 17 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import pathlib
import tempfile

import aiocsv
import aiofiles
import aioresponses
import pytest
import pytest_mock
Expand Down Expand Up @@ -40,3 +42,18 @@ async def test_download_file(tmp_path: pathlib.Path) -> None:
await utils.download_file(test_file_path, test_url)

assert test_file_path.read_bytes() == test_file_contents


@pytest.mark.asyncio()
async def test_save_csv(tmp_path: pathlib.Path) -> None:
    """Tests that the content is saved to a csv file asynchronously."""
    test_filename = tmp_path / "test.csv"
    test_content = [1.0, 2.0, 3.0, 4.0, 5.0]

    await utils.save_csv(test_filename, test_content)

    # Collect every row instead of keeping only the last one, so an empty
    # file or unexpected extra rows fail the assertion below rather than
    # raising NameError or passing unnoticed. newline="" leaves line-ending
    # handling to the csv reader.
    async with aiofiles.open(test_filename, "r", newline="") as file:
        rows = [row async for row in aiocsv.AsyncReader(file)]

    assert rows == [["1.0", "2.0", "3.0", "4.0", "5.0"]]

0 comments on commit 58d18ac

Please sign in to comment.