From 3570809dba801848ad34b802217f11ec89f8e8c4 Mon Sep 17 00:00:00 2001 From: Reinder Vos de Wael Date: Thu, 8 Feb 2024 14:29:27 -0500 Subject: [PATCH] Add language argument to whisper (#36) * Add language argument to whisper * Version bump * Fix type error --- pyproject.toml | 2 +- src/cloai/cli/commands.py | 7 +++-- src/cloai/cli/parser.py | 8 +++++ src/cloai/core/config.py | 63 +++++++++++++++++++++++++++++++++++++++ src/cloai/openai_api.py | 25 +++++++++------- tests/unit/test_parser.py | 6 ++-- 6 files changed, 95 insertions(+), 16 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index edd0892..ae6c20f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "cloai" -version = "0.0.1a10" +version = "0.0.1a11" description = "A CLI for OpenAI's API" authors = ["Reinder Vos de Wael "] license = "LGPL-2.1" diff --git a/src/cloai/cli/commands.py b/src/cloai/cli/commands.py index ea2401b..c62a7a2 100644 --- a/src/cloai/cli/commands.py +++ b/src/cloai/cli/commands.py @@ -169,14 +169,15 @@ async def speech_to_text( model: str, *, clip: bool = False, + language: config.WhisperLanguages = config.WhisperLanguages.ENGLISH, ) -> str: """Transcribes audio files with OpenAI's TTS models. Args: filename: The file to transcribe. Can be any format that ffmpeg supports. model: The transcription model to use. - voice: The voice to use. clip: Whether to clip the file if it is too large, defaults to False. + language: The language used in the audio file. """ logger.debug("Transcribing audio.") with tempfile.TemporaryDirectory() as temp_dir: @@ -189,7 +190,9 @@ async def speech_to_text( files = [temp_file] stt = openai_api.SpeechToText() - transcription_promises = [stt.run(filename, model=model) for filename in files] + transcription_promises = [ + stt.run(filename, model=model, language=language) for filename in files + ] transcriptions = await asyncio.gather(*transcription_promises) return " ".join(transcriptions) diff --git a/src/cloai/cli/parser.py b/src/cloai/cli/parser.py index 4bf8f84..1b3eb1b 100644 --- a/src/cloai/cli/parser.py +++ b/src/cloai/cli/parser.py @@ -120,6 +120,7 @@ async def run_command(args: argparse.Namespace) -> str | bytes | None: filename=args.filename, model=args.model, clip=args.clip, + language=config.WhisperLanguages[args.language], ) msg = f"Unknown command {args.command}." raise exceptions.InvalidArgumentError(msg) @@ -235,6 +236,13 @@ def _add_stt_parser( choices=["whisper-1"], default="whisper-1", ) + stt_parser.add_argument( + "--language", + help="The language of the audio file.", + type=lambda x: x.upper(), + choices=[language.name for language in config.WhisperLanguages], + default="ENGLISH", + ) def _add_tts_parser( diff --git a/src/cloai/core/config.py b/src/cloai/core/config.py index db981f1..6f0b515 100644 --- a/src/cloai/core/config.py +++ b/src/cloai/core/config.py @@ -1,4 +1,5 @@ """Configuration for the cloai module.""" +import enum import functools import logging import pathlib @@ -17,6 +18,68 @@ def get_version() -> str: return "unknown" +class WhisperLanguages(str, enum.Enum): + """The languages for the whisper model.""" + + AFRIKAANS = "af" + ARABIC = "ar" + ARMENIAN = "hy" + AZERBAIJANI = "az" + BELARUSIAN = "be" + BOSNIAN = "bs" + BULGARIAN = "bg" + CATALAN = "ca" + CHINESE = "zh" + CROATIAN = "hr" + CZECH = "cs" + DANISH = "da" + DUTCH = "nl" + ENGLISH = "en" + ESTONIAN = "et" + FINNISH = "fi" + FRENCH = "fr" + GALICIAN = "gl" + GERMAN = "de" + GREEK = "el" + HEBREW = "he" + HINDI = "hi" + HUNGARIAN = "hu" + ICELANDIC = "is" + INDONESIAN = "id" + ITALIAN = "it" + JAPANESE = "ja" + KANNADA = "kn" + KAZAKH = "kk" + KOREAN = "ko" + LATVIAN = "lv" + LITHUANIAN = "lt" + MACEDONIAN = "mk" + MALAY = "ms" + MARATHI = "mr" + MAORI = "mi" + NEPALI = "ne" + NORWEGIAN = "no" + PERSIAN = "fa" + POLISH = "pl" + PORTUGUESE = "pt" + ROMANIAN = "ro" + RUSSIAN = "ru" + SERBIAN = "sr" + SLOVAK = "sk" + SLOVENIAN = "sl" + SPANISH = "es" + SWAHILI = "sw" + SWEDISH = "sv" + TAGALOG = "tl" + TAMIL = "ta" + THAI = "th" + TURKISH = "tr" + UKRAINIAN = "uk" + URDU = "ur" + VIETNAMESE = "vi" + WELSH = "cy" + + class Settings(pydantic_settings.BaseSettings): """Represents the settings for the cloai module.""" diff --git a/src/cloai/openai_api.py b/src/cloai/openai_api.py index 6c39d78..a4247ff 100644 --- a/src/cloai/openai_api.py +++ b/src/cloai/openai_api.py @@ -3,16 +3,13 @@ import abc import logging -from typing import TYPE_CHECKING, Any, Literal, TypedDict +import pathlib +from typing import Any, Literal, TypedDict -import aiofiles import openai from cloai.core import config, exceptions -if TYPE_CHECKING: - import pathlib - settings = config.get_settings() OPENAI_API_KEY = settings.OPENAI_API_KEY LOGGER_NAME = settings.LOGGER_NAME @@ -114,22 +111,28 @@ async def run( self, audio_file: pathlib.Path | str, model: str = "whisper-1", + language: config.WhisperLanguages | str = config.WhisperLanguages.ENGLISH, ) -> str: """Runs the Speech-To-Text model. Args: audio_file: The audio to convert to text. model: The name of the Speech-To-Text model to use. + language: The language of the audio. Can be both provided through the + config.WhisperLanguages enum, which guarantees support, or as a string. Returns: The model's response. """ - async with aiofiles.open(audio_file, "rb") as audio: - return await self.client.audio.transcriptions.create( - model=model, - file=audio, # type: ignore[arg-type] - response_format="text", - ) # type: ignore[return-value] # response_format overrides output type. + if isinstance(language, config.WhisperLanguages): + language = language.value + + return await self.client.audio.transcriptions.create( + model=model, + file=pathlib.Path(audio_file), + response_format="text", + language=language, + ) # type: ignore[return-value] # response_format overrides output type. class ImageGeneration(OpenAIBaseClass): diff --git a/tests/unit/test_parser.py b/tests/unit/test_parser.py index fd2aa8a..fcbd5d5 100644 --- a/tests/unit/test_parser.py +++ b/tests/unit/test_parser.py @@ -10,7 +10,7 @@ import pytest from cloai.cli import parser -from cloai.core import exceptions +from cloai.core import config, exceptions if TYPE_CHECKING: import pytest_mock @@ -99,7 +99,7 @@ def test__add_stt_parser() -> None: """Tests the _add_stt_parser function.""" subparsers = argparse.ArgumentParser().add_subparsers() parser._add_stt_parser(subparsers) - expected_n_arguments = 4 + expected_n_arguments = 5 stt_parser = subparsers.choices["whisper"] arguments = stt_parser._actions @@ -180,6 +180,7 @@ async def test_run_command_with_whisper(mocker: pytest_mock.MockFixture) -> None "filename": "test.wav", "clip": False, "model": "whisper-1", + "language": "ENGLISH", } args = argparse.Namespace(**arg_dict) mock = mocker.patch("cloai.cli.commands.speech_to_text") @@ -190,6 +191,7 @@ async def test_run_command_with_whisper(mocker: pytest_mock.MockFixture) -> None filename=arg_dict["filename"], clip=False, model=arg_dict["model"], + language=config.WhisperLanguages.ENGLISH, )