Skip to content

Commit

Permalink
feat: add support for Azure OpenAI (#87)
Browse files Browse the repository at this point in the history
* feat: add support for Azure OpenAI

* code review changes
  • Loading branch information
quitrk authored Jun 13, 2024
1 parent 97785c0 commit b8d8141
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 28 deletions.
20 changes: 17 additions & 3 deletions credentials.yaml.sample
Original file line number Diff line number Diff line change
@@ -1,4 +1,18 @@
customer_credentials:
test-customer_id:
api_key: sample-api-key
model_name: gpt-3.5-turbo
testCustomerId:
credentialsMap:
AZURE_OPENAI:
customerId: testCustomerId
enabled: true
metadata:
deploymentName: gpt-4o
endpoint: https://myinstance.openai.azure.com/
secret: test_secret
type: AZURE_OPENAI
OPENAI:
customerId: testCustomerId
enabled: false
metadata:
model: gpt-3
secret: test_secret
type: OPENAI
18 changes: 17 additions & 1 deletion skynet/auth/openai.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from enum import Enum

import aiofiles
import yaml

Expand All @@ -10,6 +12,11 @@
credentials = dict()


class CredentialsType(Enum):
    """Kinds of customer credentials that inference can be forwarded to.

    The values mirror the ``type`` field stored in each credentials.yaml
    entry (see `get_credentials`, which injects OPENAI as the default
    type for legacy flat-format entries).
    """

    OPENAI = 'OPENAI'
    AZURE_OPENAI = 'AZURE_OPENAI'


async def open_yaml(file_path):
try:
async with aiofiles.open(file_path, mode='r') as file:
Expand Down Expand Up @@ -40,4 +47,13 @@ async def setup_credentials():


def get_credentials(customer_id):
    """Return the credentials dict for *customer_id*, or ``{}`` when none exist.

    Two layouts are supported:

    * New format: the entry has a ``credentialsMap`` of per-provider
      credential dicts; the first entry flagged ``enabled`` is returned
      (``{}`` when none is enabled).
    * Legacy flat format: the entry itself is returned, with a default
      ``type`` of OPENAI injected for backwards compatibility.
    """
    customer_credentials = credentials.get(customer_id, {}) or {}
    multiple_credentials = customer_credentials.get('credentialsMap')

    if multiple_credentials:
        # NOTE(review): dict ordering decides which enabled entry wins when
        # several are enabled — presumably at most one is; confirm with config.
        # .get('enabled') instead of ['enabled'] so entries missing the flag
        # are treated as disabled rather than raising KeyError.
        enabled = [val for val in multiple_credentials.values() if val.get('enabled')]
        return enabled[0] if enabled else {}

    # backwards compatibility
    customer_credentials.setdefault('type', CredentialsType.OPENAI.value)
    return customer_credentials
4 changes: 4 additions & 0 deletions skynet/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ def tobool(val: str | None):
llama_n_gpu_layers = int(os.environ.get('LLAMA_N_GPU_LAYERS', -1 if is_mac else 40))
llama_n_batch = int(os.environ.get('LLAMA_N_BATCH', 512))

# azure openai api
# latest GA version: https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation#latest-ga-api-release
azure_openai_api_version = os.getenv('AZURE_OPENAI_API_VERSION', '2024-02-01')

# openai api
openai_api_server_path = os.environ.get('OPENAI_API_SERVER_PATH', '/app/llama.cpp/server')
openai_api_server_port = int(os.environ.get('OPENAI_API_SERVER_PORT', 8002))
Expand Down
33 changes: 22 additions & 11 deletions skynet/modules/ttt/summaries/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import time
import uuid

from skynet.auth.openai import get_credentials
from skynet.auth.openai import CredentialsType, get_credentials

from skynet.env import job_timeout, modules, redis_exp_seconds, summary_minimum_payload_length
from skynet.logs import get_logger
Expand All @@ -16,7 +16,7 @@
from skynet.modules.ttt.openai_api.app import restart as restart_openai_api

from .persistence import db
from .processor import process, process_open_ai
from .processor import process, process_azure, process_open_ai
from .v1.models import DocumentMetadata, DocumentPayload, Job, JobId, JobStatus, JobType

log = get_logger(__name__)
Expand Down Expand Up @@ -125,15 +125,26 @@ async def run_job(job: Job) -> None:
exit_task = asyncio.create_task(restart_on_timeout(job))

try:
customer_id = job.metadata.customer_id
options = get_credentials(customer_id) if customer_id else {}
api_key = options.get('secret')
model_name = options.get('model')

if api_key:
log.info(f"Forwarding inference to OpenAI for customer {customer_id}")

result = await process_open_ai(job.payload, job.type, api_key, model_name)
customer_id = 'testCustomerId'
options = get_credentials(customer_id)
secret = options.get('secret')
api_type = options.get('type')

if secret:
if api_type == CredentialsType.OPENAI.value:
log.info(f"Forwarding inference to OpenAI for customer {customer_id}")

# needed for backwards compatibility
model = options.get('model') or options.get('metadata').get('model')
result = await process_open_ai(job.payload, job.type, secret, model)

elif api_type == CredentialsType.AZURE_OPENAI.value:
log.info(f"Forwarding inference to Azure openai for customer {customer_id}")

metadata = options.get('metadata')
result = await process_azure(
job.payload, job.type, secret, metadata.get('endpoint'), metadata.get('deploymentName')
)
else:
if customer_id:
log.info(f'Customer {customer_id} has no API key configured, falling back to local processing')
Expand Down
61 changes: 50 additions & 11 deletions skynet/modules/ttt/summaries/jobs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def run_job_fixture(mocker):
mocker.patch('skynet.modules.ttt.summaries.jobs.update_job')
mocker.patch('skynet.modules.ttt.summaries.jobs.process')
mocker.patch('skynet.modules.ttt.summaries.jobs.process_open_ai')
mocker.patch('skynet.modules.ttt.summaries.jobs.process_azure')
mocker.patch('skynet.modules.ttt.summaries.jobs.db.db')

return mocker
Expand Down Expand Up @@ -86,20 +87,58 @@ async def test_run_job_with_open_ai(self, run_job_fixture):

from skynet.modules.ttt.summaries.jobs import process_open_ai, run_job

run_job_fixture.patch('skynet.modules.ttt.summaries.jobs.get_credentials', return_value={'secret': 'api_key'})
secret = 'secret'
model = 'gpt-3.5-turbo'

await run_job(
Job(
payload=DocumentPayload(
text="Andrew: Hello. Beatrix: Honey? It’s me . . . Andrew: Where are you? Beatrix: At the station. I missed my train."
),
metadata=DocumentMetadata(customer_id='test'),
type=JobType.SUMMARY,
id='job_id',
)
run_job_fixture.patch(
'skynet.modules.ttt.summaries.jobs.get_credentials',
return_value={'secret': secret, 'type': 'OPENAI', 'metadata': {'model': model}},
)

process_open_ai.assert_called_once()
job = Job(
payload=DocumentPayload(
text="Andrew: Hello. Beatrix: Honey? It’s me . . . Andrew: Where are you? Beatrix: At the station. I missed my train."
),
metadata=DocumentMetadata(customer_id='test'),
type=JobType.SUMMARY,
id='job_id',
)

await run_job(job)

process_open_ai.assert_called_once_with(job.payload, job.type, secret, model)

@pytest.mark.asyncio
async def test_run_job_with_azure_open_ai(self, run_job_fixture):
    '''A job is forwarded to Azure OpenAI when the customer credentials are of type AZURE_OPENAI.'''

    from skynet.modules.ttt.summaries.jobs import process_azure, run_job

    api_secret = 'secret'
    azure_deployment = 'gpt-3.5-turbo'
    azure_endpoint = 'https://myopenai.azure.com'

    # Pretend the customer has enabled Azure OpenAI credentials configured.
    azure_credentials = {
        'secret': api_secret,
        'type': 'AZURE_OPENAI',
        'metadata': {'deploymentName': azure_deployment, 'endpoint': azure_endpoint},
    }
    run_job_fixture.patch(
        'skynet.modules.ttt.summaries.jobs.get_credentials',
        return_value=azure_credentials,
    )

    test_job = Job(
        id='job_id',
        type=JobType.SUMMARY,
        metadata=DocumentMetadata(customer_id='test'),
        payload=DocumentPayload(
            text="Andrew: Hello. Beatrix: Honey? It’s me . . . Andrew: Where are you? Beatrix: At the station. I missed my train."
        ),
    )

    await run_job(test_job)

    # The Azure processor must receive the secret plus the endpoint/deployment metadata.
    process_azure.assert_called_once_with(test_job.payload, test_job.type, api_secret, azure_endpoint, azure_deployment)


class TestCanRunNextJob:
Expand Down
18 changes: 16 additions & 2 deletions skynet/modules/ttt/summaries/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from langchain.prompts import ChatPromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI
from langchain_openai import AzureChatOpenAI, ChatOpenAI

from skynet.env import app_uuid, llama_n_ctx, openai_api_base_url
from skynet.env import app_uuid, azure_openai_api_version, llama_n_ctx, openai_api_base_url
from skynet.logs import get_logger

from .prompts.action_items import action_items_conversation_prompt, action_items_text_prompt
Expand Down Expand Up @@ -83,3 +83,17 @@ async def process_open_ai(payload: DocumentPayload, job_type: JobType, api_key:
)

return await process(payload, job_type, llm)


async def process_azure(
    payload: DocumentPayload, job_type: JobType, api_key: str, endpoint: str, deployment_name: str
) -> str:
    """Run *payload* through an Azure-hosted OpenAI deployment.

    Builds an `AzureChatOpenAI` client for the given endpoint/deployment
    (temperature 0 for deterministic output, api version pinned via env)
    and delegates the actual summarization to `process`.
    """
    azure_llm = AzureChatOpenAI(
        temperature=0,
        api_key=api_key,
        azure_endpoint=endpoint,
        azure_deployment=deployment_name,
        api_version=azure_openai_api_version,
    )

    result = await process(payload, job_type, azure_llm)
    return result

0 comments on commit b8d8141

Please sign in to comment.