Update audio transcription tool and usage
pseudotensor committed Sep 20, 2024
1 parent dd8ad17 commit c322109
Showing 3 changed files with 58 additions and 15 deletions.
5 changes: 3 additions & 2 deletions openai_server/agent_prompting.py
@@ -579,9 +579,10 @@ def get_audio_transcription_helper():
```sh
# filename: my_audio_transcription.sh
# execution: true
python {cwd}/openai_server/agent_tools/audio_transcription.py --file_path "./audio.wav"
python {cwd}/openai_server/agent_tools/audio_transcription.py --input "audio.wav"
```
* usage: python {cwd}/openai_server/agent_tools/audio_transcription.py [-h] --file_path FILE_PATH
* usage: python {cwd}/openai_server/agent_tools/audio_transcription.py [-h] --input "AUDIO_FILE_PATH"
* Can transcribe audio from mp3, mp4, mpeg, mpga, m4a, wav, webm, and more.
"""
else:
audio_transcription = ''
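For context on the prompt change above, here is a hedged sketch (not part of this commit) of how an agent-generated script would now invoke the tool. The flag names and environment variables come from audio_transcription.py below; the endpoint, API key, and filenames are placeholder assumptions.

```python
# Sketch only: environment variable names and flags match the tool below;
# the endpoint, key, and filenames are placeholders.
import os
import subprocess

os.environ["STT_OPENAI_BASE_URL"] = "https://api.openai.com/v1"  # or an Azure/self-hosted endpoint
os.environ["STT_OPENAI_API_KEY"] = "sk-..."                       # required for OpenAI or Azure endpoints
os.environ["STT_OPENAI_MODEL"] = "whisper-1"                      # used when --model is not passed

# The flag is now --input (was --file_path); --output is optional and
# auto-generated as transcription_<uuid>.txt when omitted.
subprocess.run(
    ["python", "openai_server/agent_tools/audio_transcription.py",
     "--input", "audio.wav"],
    check=True,
)
```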
66 changes: 54 additions & 12 deletions openai_server/agent_tools/audio_transcription.py
@@ -5,30 +5,72 @@
from openai import OpenAI


def check_valid_extension(file):
    """
    OpenAI only allows certain file types
    :param file:
    :return:
    """
    valid_extensions = ['mp3', 'mp4', 'mpeg', 'mpga', 'm4a', 'wav', 'webm']

    # Get the file extension (convert to lowercase for case-insensitive comparison)
    _, file_extension = os.path.splitext(file)
    file_extension = file_extension.lower().lstrip('.')

    if file_extension not in valid_extensions:
        raise ValueError(
            f"Invalid file extension. Expected one of {', '.join(valid_extensions)}, but got '{file_extension}'")

    return True


def main():
    parser = argparse.ArgumentParser(description="Get transcription of an audio file")
    parser.add_argument("--input", type=str, required=True, help="Path to the input audio file")
    # Model
    parser.add_argument("--model", type=str, required=False, help="Model name")
    parser.add_argument("--model", type=str, required=False,
                        help="Model name (For Azure deployment name must match actual model name, e.g. whisper-1)")
    # File name
    parser.add_argument("--output", type=str, default='', required=False, help="Path (ensure unique) to the audio file")
    parser.add_argument("--output", type=str, default='', required=False,
                        help="Path (ensure unique) to output text file")
    args = parser.parse_args()
    ##
    stt_url = os.getenv("STT_OPENAI_BASE_URL", None)
    assert stt_url is not None, "STT_OPENAI_BASE_URL environment variable is not set"
    stt_api_key = os.getenv('STT_OPENAI_API_KEY', 'EMPTY')

    stt_api_key = os.getenv('STT_OPENAI_API_KEY')
    if stt_url == "https://api.openai.com/v1" or 'openai.azure.com' in stt_url:
        assert stt_api_key, "STT_OPENAI_API_KEY environment variable is not set and is required if using OpenAI or Azure endpoints"

        if 'openai.azure.com' in stt_url:
            # https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line%2Cpython-new%2Cjavascript&pivots=programming-language-python
            from openai import AzureOpenAI
            client = AzureOpenAI(
                api_version="2024-02-01",
                api_key=stt_api_key,
                # like base_url, but Azure endpoint like https://PROJECT.openai.azure.com/
                azure_endpoint=stt_url,
            )
        else:
            from openai import OpenAI
            client = OpenAI(base_url=stt_url, api_key=stt_api_key)

        check_valid_extension(args.input)
    else:
        from openai import OpenAI
        stt_api_key = os.getenv('STT_OPENAI_API_KEY', 'EMPTY')
        client = OpenAI(base_url=stt_url, api_key=stt_api_key)

    if not args.model:
        stt_model = os.getenv('STT_OPENAI_MODEL')
        assert stt_model is not None, "STT_OPENAI_MODEL environment variable is not set"
        args.model = stt_model
        args.model = os.getenv('STT_OPENAI_MODEL', 'whisper-1')

    # Read the audio file
    audio_file = open(args.file_path, "rb")
    client = OpenAI(base_url=stt_url, api_key=stt_api_key)
    transcription = client.audio.transcriptions.create(
        model=args.model,
        file=audio_file
    )
    with open(args.input, "rb") as f:
        transcription = client.audio.transcriptions.create(
            model=args.model,
            file=f.read(),
            response_format="text",
        )
    # Save the transcription to a file
    if not args.output:
        args.output = f"transcription_{str(uuid.uuid4())[:6]}.txt"
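To summarize the new flow end to end, here is a condensed, hedged sketch (not the committed code): the client is chosen from STT_OPENAI_BASE_URL, the file extension is validated, and a plain-text transcription is requested. The helper name transcribe and the use of an open file handle rather than raw bytes are simplifications of my own.

```python
# Condensed sketch of the flow above; not the committed implementation.
import os
from openai import AzureOpenAI, OpenAI

VALID_EXTENSIONS = {'mp3', 'mp4', 'mpeg', 'mpga', 'm4a', 'wav', 'webm'}


def transcribe(path, model=None):
    stt_url = os.environ["STT_OPENAI_BASE_URL"]              # required, as in the tool
    stt_api_key = os.getenv("STT_OPENAI_API_KEY", "EMPTY")   # 'EMPTY' only makes sense for self-hosted endpoints

    if "openai.azure.com" in stt_url:
        # Azure: the deployment name must match the model name (e.g. whisper-1)
        client = AzureOpenAI(api_version="2024-02-01", api_key=stt_api_key, azure_endpoint=stt_url)
    else:
        client = OpenAI(base_url=stt_url, api_key=stt_api_key)

    # Same restriction as check_valid_extension() above
    if path.rsplit('.', 1)[-1].lower() not in VALID_EXTENSIONS:
        raise ValueError(f"Unsupported audio extension for {path}")

    model = model or os.getenv("STT_OPENAI_MODEL", "whisper-1")
    # The commit passes the raw bytes; a plain file handle is used here for simplicity.
    with open(path, "rb") as f:
        return client.audio.transcriptions.create(model=model, file=f, response_format="text")
```

Under these assumptions, transcribe("audio.wav") returns the transcript as a plain string, which the tool then writes to the --output path.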
2 changes: 1 addition & 1 deletion src/version.py
@@ -1 +1 @@
__version__ = "cd495c13b6eaa4ab75ce63001645ef2543a83deb"
__version__ = "dd8ad17cad368e3c2900f74b68cab7799cf66083"
