feat(ttt): refactor job processing
Extract the logic for selecting the right LLM into an "LLM selector"
entity. This simplifies the processor, which now only handles the
processing itself, while the selection logic is fully encapsulated in
LLMSelector.

If / when model selection needs to become more complex than just
deciding which processor to use, that complexity will be contained here.

While working on this I also noticed a logic error when initializing the
OCI LLM: max_tokens was taken from the first processed job and then
reused for every job after it.

Contrary to my original assessment, the transformers download is done
just once, so we don't need to reuse the ChatOCIGenAI instance and can
set the max_tokens property appropriately for each job.
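
To make the error concrete, here is a minimal sketch of the old behaviour, using the get_oci_llm() helper removed below (the token values are hypothetical):

    # Hypothetical calls against the old cached helper: only the first
    # job's limit ever takes effect, later limits are silently ignored.
    llm_a = get_oci_llm(max_tokens=1024)  # first job creates the instance
    llm_b = get_oci_llm(max_tokens=4096)  # second job gets the cached one
    assert llm_a is llm_b                 # still capped at 1024 tokens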
saghul committed Feb 20, 2025
1 parent a96150b commit a9fafb6
Showing 5 changed files with 185 additions and 177 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -4,6 +4,7 @@ models
.DS_Store
.env
.idea
.vscode
llama.log
dump.rdb
_vector_store_
111 changes: 111 additions & 0 deletions skynet/modules/ttt/llm_selector.py
@@ -0,0 +1,111 @@
from typing import Optional

from langchain_community.chat_models import ChatOCIGenAI
from langchain_core.language_models.chat_models import BaseChatModel

from langchain_openai import AzureChatOpenAI, ChatOpenAI

from skynet.auth.user_info import CredentialsType, get_credentials
from skynet.env import (
app_uuid,
azure_openai_api_version,
llama_path,
oci_auth_type,
oci_available,
oci_compartment_id,
oci_config_profile,
oci_model_id,
oci_service_endpoint,
openai_api_base_url,
use_oci,
)
from skynet.logs import get_logger
from skynet.modules.ttt.summaries.v1.models import Processors

log = get_logger(__name__)


class LLMSelector:
@staticmethod
def get_job_processor(customer_id: str) -> Processors:
options = get_credentials(customer_id)
secret = options.get('secret')
api_type = options.get('type')

if secret:
if api_type == CredentialsType.OPENAI.value:
return Processors.OPENAI
elif api_type == CredentialsType.AZURE_OPENAI.value:
return Processors.AZURE

# OCI doesn't have a secret since it's provisioned for the instance as a whole.
if use_oci or api_type == CredentialsType.OCI.value:
if oci_available:
return Processors.OCI
else:
log.warning(f'OCI is not available, falling back to local processing for customer {customer_id}')

return Processors.LOCAL

@staticmethod
def select(customer_id: str, max_completion_tokens: Optional[int] = None) -> BaseChatModel:
processor = LLMSelector.get_job_processor(customer_id)
options = get_credentials(customer_id)

if processor == Processors.OPENAI:
log.info(f'Forwarding inference to OpenAI for customer {customer_id}')

return ChatOpenAI(
api_key=options.get('secret'),
max_completion_tokens=max_completion_tokens,
model_name=options.get('metadata').get('model'),
temperature=0,
)
elif processor == Processors.AZURE:
log.info(f'Forwarding inference to Azure-OpenAI for customer {customer_id}')

metadata = options.get('metadata')

return AzureChatOpenAI(
api_key=options.get('secret'),
api_version=azure_openai_api_version,
azure_endpoint=metadata.get('endpoint'),
azure_deployment=metadata.get('deploymentName'),
max_completion_tokens=max_completion_tokens,
temperature=0,
)
elif processor == Processors.OCI:
log.info(f'Forwarding inference to OCI for customer {customer_id}')

model_kwargs = {
'temperature': 0,
'frequency_penalty': 1,
'max_tokens': max_completion_tokens,
}

return ChatOCIGenAI(
model_id=oci_model_id,
service_endpoint=oci_service_endpoint,
compartment_id=oci_compartment_id,
provider='meta',
model_kwargs=model_kwargs,
auth_type=oci_auth_type,
auth_profile=oci_config_profile,
)
else:
if customer_id:
log.info(f'Customer {customer_id} has no API key configured, falling back to local processing')

return ChatOpenAI(
model=llama_path,
api_key='placeholder', # use a placeholder value to bypass validation
base_url=f'{openai_api_base_url}/v1',
default_headers={'X-Skynet-UUID': app_uuid},
frequency_penalty=1,
max_retries=0,
temperature=0,
max_completion_tokens=max_completion_tokens,
)


llm_selector = LLMSelector()
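
A usage sketch for the new entity (the customer id, token budget, and prompt are hypothetical):

    # Hypothetical usage of the selector defined above: resolve the
    # processor, then build a fresh chat model per job so that each
    # job's max_completion_tokens is honoured.
    processor = llm_selector.get_job_processor('customer-123')
    llm = llm_selector.select('customer-123', max_completion_tokens=2048)
    answer = await llm.ainvoke('Summarize the following meeting...')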
179 changes: 18 additions & 161 deletions skynet/modules/ttt/processor.py
@@ -1,34 +1,20 @@
from operator import itemgetter
from typing import Optional

from langchain.chains.summarize import load_summarize_chain
from langchain.prompts import ChatPromptTemplate
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import FlashrankRerank
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.chat_models import ChatOCIGenAI
from langchain_core.documents import Document
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import AzureChatOpenAI, ChatOpenAI

from skynet.auth.user_info import CredentialsType, get_credentials
from skynet.env import (
app_uuid,
azure_openai_api_version,
llama_n_ctx,
llama_path,
oci_auth_type,
oci_compartment_id,
oci_config_profile,
oci_model_id,
oci_service_endpoint,
openai_api_base_url,
use_oci,
)

from skynet.env import llama_n_ctx
from skynet.logs import get_logger
from skynet.modules.ttt.assistant.constants import assistant_rag_question_extractor
from skynet.modules.ttt.assistant.utils import get_assistant_chat_messages

from skynet.modules.ttt.llm_selector import LLMSelector
from skynet.modules.ttt.rag.app import get_vector_store
from skynet.modules.ttt.summaries.prompts.action_items import (
action_items_conversation,
@@ -42,7 +28,7 @@
summary_meeting,
summary_text,
)
from skynet.modules.ttt.summaries.v1.models import DocumentPayload, HintType, JobType, Processors
from skynet.modules.ttt.summaries.v1.models import DocumentPayload, HintType, JobType

log = get_logger(__name__)

@@ -72,69 +58,10 @@ def format_docs(docs: list[Document]) -> str:
return '\n\n'.join(doc.page_content for doc in docs)


def get_job_processor(customer_id: str) -> Processors:
options = get_credentials(customer_id)
secret = options.get('secret')
api_type = options.get('type')

if secret:
if api_type == CredentialsType.OPENAI.value:
return Processors.OPENAI
elif api_type == CredentialsType.AZURE_OPENAI.value:
return Processors.AZURE

# OCI doesn't have a secret since it's provisioned for the instance as a whole.
if api_type == CredentialsType.OCI.value:
return Processors.OCI

return Processors.LOCAL


# Cached instance since it performs some initialization we'd
# like to avoid on every request.
oci_llm = None


def get_oci_llm(max_tokens):
global oci_llm

if oci_llm is None:
oci_llm = ChatOCIGenAI(
model_id=oci_model_id,
service_endpoint=oci_service_endpoint,
compartment_id=oci_compartment_id,
provider="meta",
model_kwargs={"temperature": 0, "frequency_penalty": 1, "max_tokens": max_tokens},
auth_type=oci_auth_type,
auth_profile=oci_config_profile,
)
return oci_llm


def get_local_llm(**kwargs):
# OCI hosted llama
if use_oci:
return get_oci_llm(kwargs['max_completion_tokens'])

# Locally hosted llama
return ChatOpenAI(
model=llama_path,
api_key='placeholder', # use a placeholder value to bypass validation, and allow the custom base url to be used
base_url=f'{openai_api_base_url}/v1',
default_headers={'X-Skynet-UUID': app_uuid},
frequency_penalty=1,
max_retries=0,
temperature=0,
**kwargs,
)


compressor = FlashrankRerank()


async def assist(payload: DocumentPayload, customer_id: str | None = None, model: BaseChatModel = None) -> str:
current_model = model or get_local_llm(max_completion_tokens=payload.max_completion_tokens)

async def assist(model: BaseChatModel, payload: DocumentPayload, customer_id: Optional[str] = None) -> str:
store = await get_vector_store()
vector_store = await store.get(customer_id)
config = await store.get_config(customer_id)
@@ -149,7 +76,8 @@ async def assist(payload: DocumentPayload, customer_id: str | None = None, model

if retriever and payload.text:
question_payload = DocumentPayload(**(payload.model_dump() | {'prompt': assistant_rag_question_extractor}))
question = await summarize(question_payload, JobType.SUMMARY, current_model)
# TODO: add a generic document processor and use that.
question = await summarize(model, question_payload, JobType.SUMMARY)

log.info(f'Using question: {question}')

@@ -166,15 +94,14 @@
rag_chain = (
{'context': (itemgetter('question') | retriever | format_docs) if retriever else lambda _: ''}
| template
| current_model
| model
| StrOutputParser()
)

return await rag_chain.ainvoke(input={'question': question})


async def summarize(payload: DocumentPayload, job_type: JobType, model: BaseChatModel = None) -> str:
current_model = model or get_local_llm(max_completion_tokens=payload.max_completion_tokens)
async def summarize(model: BaseChatModel, payload: DocumentPayload, job_type: JobType) -> str:
chain = None
text = payload.text

@@ -191,13 +118,14 @@ async def summarize(payload: DocumentPayload, job_type: JobType, model: BaseChat
)

# this is a rough estimate of the number of tokens in the input text, since llama models will have a different tokenization scheme
num_tokens = current_model.get_num_tokens(text)
num_tokens = model.get_num_tokens(text)

# allow some buffer for the model to generate the output
# TODO: adjust this to the actual model's context window
threshold = llama_n_ctx * 3 / 4

if num_tokens < threshold:
chain = load_summarize_chain(current_model, chain_type='stuff', prompt=prompt)
chain = load_summarize_chain(model, chain_type='stuff', prompt=prompt)
docs = [Document(page_content=text)]
else:
# split the text into roughly equal chunks
@@ -208,7 +136,7 @@ async def summarize(payload: DocumentPayload, job_type: JobType, model: BaseChat

text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(chunk_size=chunk_size, chunk_overlap=100)
docs = text_splitter.create_documents([text])
chain = load_summarize_chain(current_model, chain_type='map_reduce', combine_prompt=prompt, map_prompt=prompt)
chain = load_summarize_chain(model, chain_type='map_reduce', combine_prompt=prompt, map_prompt=prompt)

result = await chain.ainvoke(input={'input_documents': docs})
formatted_result = result['output_text'].replace('Response:', '', 1).strip()
@@ -219,83 +147,12 @@ async def summarize(payload: DocumentPayload, job_type: JobType, model: BaseChat
return formatted_result


async def process_open_ai(
payload: DocumentPayload, job_type: JobType, api_key: str, model_name=None, customer_id: str | None = None
) -> str:
llm = ChatOpenAI(
api_key=api_key,
max_completion_tokens=payload.max_completion_tokens,
model_name=model_name,
temperature=0,
)

if job_type == JobType.ASSIST:
return await assist(payload, customer_id, llm)

return await summarize(payload, job_type, llm)


async def process_azure(
payload: DocumentPayload,
job_type: JobType,
api_key: str,
endpoint: str,
deployment_name: str,
customer_id: str | None = None,
) -> str:
llm = AzureChatOpenAI(
api_key=api_key,
api_version=azure_openai_api_version,
azure_endpoint=endpoint,
azure_deployment=deployment_name,
max_completion_tokens=payload.max_completion_tokens,
temperature=0,
)

if job_type == JobType.ASSIST:
return await assist(payload, customer_id, llm)

return await summarize(payload, job_type, llm)


async def process_oci(payload: DocumentPayload, job_type: JobType, customer_id: str | None = None) -> str:
llm = get_oci_llm(payload.max_completion_tokens)

if job_type == JobType.ASSIST:
return await assist(payload, customer_id, llm)

return await summarize(payload, job_type, llm)


async def process(payload: DocumentPayload, job_type: JobType, customer_id: str | None = None) -> str:
processor = get_job_processor(customer_id)
options = get_credentials(customer_id)

secret = options.get('secret')

if processor == Processors.OPENAI:
log.info(f'Forwarding inference to OpenAI for customer {customer_id}')

model = options.get('metadata').get('model')
result = await process_open_ai(payload, job_type, secret, model, customer_id)
elif processor == Processors.AZURE:
log.info(f"Forwarding inference to Azure-OpenAI for customer {customer_id}")
llm = LLMSelector.select(customer_id, max_completion_tokens=payload.max_completion_tokens)

metadata = options.get('metadata')
result = await process_azure(
payload, job_type, secret, metadata.get('endpoint'), metadata.get('deploymentName'), customer_id
)
elif processor == Processors.OCI:
log.info(f"Forwarding inference to OCI for customer {customer_id}")

result = await process_oci(payload, job_type, customer_id)
if job_type == JobType.ASSIST:
result = await assist(llm, payload, customer_id)
else:
if customer_id:
log.info(f'Customer {customer_id} has no API key configured, falling back to local processing')

if job_type == JobType.ASSIST:
result = await assist(payload, customer_id)
else:
result = await summarize(payload, job_type)
result = await summarize(llm, payload, job_type)

return result
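
For the size check in summarize() above, a worked example of the threshold arithmetic (the context window of 8192 is an assumed value for llama_n_ctx):

    # Assuming llama_n_ctx = 8192, a quarter of the window is kept as
    # an output buffer, so inputs are 'stuffed' only below 6144 tokens.
    threshold = 8192 * 3 / 4  # 6144.0
    # num_tokens < 6144  -> 'stuff' chain over the whole text
    # num_tokens >= 6144 -> split into chunks, 'map_reduce' chain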
