Skip to content

Commit

Permalink
Add language argument to whisper (#36)
Browse files Browse the repository at this point in the history
* Add language argument to whisper

* Version bump

* Fix type error
  • Loading branch information
ReinderVosDeWael authored Feb 8, 2024
1 parent 42ff90e commit 3570809
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 16 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "cloai"
version = "0.0.1a10"
version = "0.0.1a11"
description = "A CLI for OpenAI's API"
authors = ["Reinder Vos de Wael <[email protected]>"]
license = "LGPL-2.1"
Expand Down
7 changes: 5 additions & 2 deletions src/cloai/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,14 +169,15 @@ async def speech_to_text(
model: str,
*,
clip: bool = False,
language: config.WhisperLanguages = config.WhisperLanguages.ENGLISH,
) -> str:
"""Transcribes audio files with OpenAI's TTS models.
Args:
filename: The file to transcribe. Can be any format that ffmpeg supports.
model: The transcription model to use.
voice: The voice to use.
clip: Whether to clip the file if it is too large, defaults to False.
language: The language used in the audio file.
"""
logger.debug("Transcribing audio.")
with tempfile.TemporaryDirectory() as temp_dir:
Expand All @@ -189,7 +190,9 @@ async def speech_to_text(
files = [temp_file]

stt = openai_api.SpeechToText()
transcription_promises = [stt.run(filename, model=model) for filename in files]
transcription_promises = [
stt.run(filename, model=model, language=language) for filename in files
]
transcriptions = await asyncio.gather(*transcription_promises)

return " ".join(transcriptions)
Expand Down
8 changes: 8 additions & 0 deletions src/cloai/cli/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ async def run_command(args: argparse.Namespace) -> str | bytes | None:
filename=args.filename,
model=args.model,
clip=args.clip,
language=config.WhisperLanguages[args.language],
)
msg = f"Unknown command {args.command}."
raise exceptions.InvalidArgumentError(msg)
Expand Down Expand Up @@ -235,6 +236,13 @@ def _add_stt_parser(
choices=["whisper-1"],
default="whisper-1",
)
stt_parser.add_argument(
"--language",
help="The language of the audio file.",
type=lambda x: x.upper(),
choices=[language.name for language in config.WhisperLanguages],
default="ENGLISH",
)


def _add_tts_parser(
Expand Down
63 changes: 63 additions & 0 deletions src/cloai/core/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Configuration for the cloai module."""
import enum
import functools
import logging
import pathlib
Expand All @@ -17,6 +18,68 @@ def get_version() -> str:
return "unknown"


class WhisperLanguages(str, enum.Enum):
"""The languages for the whisper model."""

AFRIKAANS = "af"
ARABIC = "ar"
ARMENIAN = "hy"
AZERBAIJANI = "az"
BELARUSIAN = "be"
BOSNIAN = "bs"
BULGARIAN = "bg"
CATALAN = "ca"
CHINESE = "zh"
CROATIAN = "hr"
CZECH = "cs"
DANISH = "da"
DUTCH = "nl"
ENGLISH = "en"
ESTONIAN = "et"
FINNISH = "fi"
FRENCH = "fr"
GALICIAN = "gl"
GERMAN = "de"
GREEK = "el"
HEBREW = "he"
HINDI = "hi"
HUNGARIAN = "hu"
ICELANDIC = "is"
INDONESIAN = "id"
ITALIAN = "it"
JAPANESE = "ja"
KANNADA = "kn"
KAZAKH = "kk"
KOREAN = "ko"
LATVIAN = "lv"
LITHUANIAN = "lt"
MACEDONIAN = "mk"
MALAY = "ms"
MARATHI = "mr"
MAORI = "mi"
NEPALI = "ne"
NORWEGIAN = "no"
PERSIAN = "fa"
POLISH = "pl"
PORTUGUESE = "pt"
ROMANIAN = "ro"
RUSSIAN = "ru"
SERBIAN = "sr"
SLOVAK = "sk"
SLOVENIAN = "sl"
SPANISH = "es"
SWAHILI = "sw"
SWEDISH = "sv"
TAGALOG = "tl"
TAMIL = "ta"
THAI = "th"
TURKISH = "tr"
UKRAINIAN = "uk"
URDU = "ur"
VIETNAMESE = "vi"
WELSH = "cy"


class Settings(pydantic_settings.BaseSettings):
"""Represents the settings for the cloai module."""

Expand Down
25 changes: 14 additions & 11 deletions src/cloai/openai_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,13 @@

import abc
import logging
from typing import TYPE_CHECKING, Any, Literal, TypedDict
import pathlib
from typing import Any, Literal, TypedDict

import aiofiles
import openai

from cloai.core import config, exceptions

if TYPE_CHECKING:
import pathlib

settings = config.get_settings()
OPENAI_API_KEY = settings.OPENAI_API_KEY
LOGGER_NAME = settings.LOGGER_NAME
Expand Down Expand Up @@ -114,22 +111,28 @@ async def run(
self,
audio_file: pathlib.Path | str,
model: str = "whisper-1",
language: config.WhisperLanguages | str = config.WhisperLanguages.ENGLISH,
) -> str:
"""Runs the Speech-To-Text model.
Args:
audio_file: The audio to convert to text.
model: The name of the Speech-To-Text model to use.
language: The language of the audio. Can be both provided through the
config.WhisperLanguages enum, which guarantees support, or as a string.
Returns:
The model's response.
"""
async with aiofiles.open(audio_file, "rb") as audio:
return await self.client.audio.transcriptions.create(
model=model,
file=audio, # type: ignore[arg-type]
response_format="text",
) # type: ignore[return-value] # response_format overrides output type.
if isinstance(language, config.WhisperLanguages):
language = language.value

return await self.client.audio.transcriptions.create(
model=model,
file=pathlib.Path(audio_file),
response_format="text",
language=language,
) # type: ignore[return-value] # response_format overrides output type.


class ImageGeneration(OpenAIBaseClass):
Expand Down
6 changes: 4 additions & 2 deletions tests/unit/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pytest

from cloai.cli import parser
from cloai.core import exceptions
from cloai.core import config, exceptions

if TYPE_CHECKING:
import pytest_mock
Expand Down Expand Up @@ -99,7 +99,7 @@ def test__add_stt_parser() -> None:
"""Tests the _add_stt_parser function."""
subparsers = argparse.ArgumentParser().add_subparsers()
parser._add_stt_parser(subparsers)
expected_n_arguments = 4
expected_n_arguments = 5

stt_parser = subparsers.choices["whisper"]
arguments = stt_parser._actions
Expand Down Expand Up @@ -180,6 +180,7 @@ async def test_run_command_with_whisper(mocker: pytest_mock.MockFixture) -> None
"filename": "test.wav",
"clip": False,
"model": "whisper-1",
"language": "ENGLISH",
}
args = argparse.Namespace(**arg_dict)
mock = mocker.patch("cloai.cli.commands.speech_to_text")
Expand All @@ -190,6 +191,7 @@ async def test_run_command_with_whisper(mocker: pytest_mock.MockFixture) -> None
filename=arg_dict["filename"],
clip=False,
model=arg_dict["model"],
language=config.WhisperLanguages.ENGLISH,
)


Expand Down

0 comments on commit 3570809

Please sign in to comment.