fix: Psl 6635 - BYOD workflow implementation (#1265)
Co-authored-by: Ajit Padhi (Persistent Systems Inc) <[email protected]>
Co-authored-by: Francia Riesco <[email protected]>
3 people authored Sep 16, 2024
1 parent 53b899c commit 6998ab0
Showing 12 changed files with 341 additions and 67 deletions.
1 change: 1 addition & 0 deletions code/backend/Admin.py
@@ -51,6 +51,7 @@ def load_css(file_path):
"""
* If you want to ingest data (pdf, websites, etc.), then use the `Ingest Data` tab
* If you want to explore how your data was chunked, check the `Explore Data` tab
* If you want to delete your data, check the `Delete Data` tab
* If you want to adapt the underlying prompts, logging settings and others, use the `Configuration` tab
"""
)
20 changes: 17 additions & 3 deletions code/backend/batch/utilities/helpers/config/config_helper.py
@@ -12,6 +12,7 @@
 from ...orchestrator import OrchestrationSettings
 from ..env_helper import EnvHelper
 from .assistant_strategy import AssistantStrategy
+from .conversation_flow import ConversationFlow

 CONFIG_CONTAINER_NAME = "config"
 CONFIG_FILE_NAME = "active.json"
@@ -90,6 +91,9 @@ def get_available_orchestration_strategies(self):
     def get_available_ai_assistant_types(self):
         return [c.value for c in AssistantStrategy]

+    def get_available_conversational_flows(self):
+        return [c.value for c in ConversationFlow]
+

 # TODO: Change to AnsweringChain or something, Prompts is not a good name
 class Prompts:
@@ -102,6 +106,7 @@ def __init__(self, prompts: dict):
         self.enable_post_answering_prompt = prompts["enable_post_answering_prompt"]
         self.enable_content_safety = prompts["enable_content_safety"]
         self.ai_assistant_type = prompts["ai_assistant_type"]
+        self.conversational_flow = prompts["conversational_flow"]


 class Example:
@@ -166,13 +171,20 @@ def _set_new_config_properties(config: dict, default_config: dict):
             config["example"] = default_config["example"]

         if config["prompts"].get("ai_assistant_type") is None:
-            config["prompts"]["ai_assistant_type"] = default_config["prompts"]["ai_assistant_type"]
+            config["prompts"]["ai_assistant_type"] = default_config["prompts"][
+                "ai_assistant_type"
+            ]

         if config.get("integrated_vectorization_config") is None:
             config["integrated_vectorization_config"] = default_config[
                 "integrated_vectorization_config"
             ]

+        if config["prompts"].get("conversational_flow") is None:
+            config["prompts"]["conversational_flow"] = default_config["prompts"][
+                "conversational_flow"
+            ]
+
     @staticmethod
     @functools.cache
     def get_active_config_or_default():
@@ -247,12 +259,14 @@ def get_default_config():
     @staticmethod
     @functools.cache
     def get_default_contract_assistant():
-        contract_file_path = os.path.join(os.path.dirname(__file__), "default_contract_assistant_prompt.txt")
+        contract_file_path = os.path.join(
+            os.path.dirname(__file__), "default_contract_assistant_prompt.txt"
+        )
         contract_assistant = ""
         with open(contract_file_path, encoding="utf-8") as f:
            contract_assistant = f.readlines()

-        return ''.join([str(elem) for elem in contract_assistant])
+        return "".join([str(elem) for elem in contract_assistant])

     @staticmethod
     def clear_config():
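The ConversationFlow enum imported above is not shown in this diff. A minimal sketch consistent with the values the commit uses elsewhere ("custom" and "byod"):

# Hypothetical sketch of conversation_flow.py; the real module is not part of this diff.
from enum import Enum


class ConversationFlow(Enum):
    CUSTOM = "custom"  # answered by the project's own orchestrator pipeline
    BYOD = "byod"  # answered by Azure OpenAI "on your data"
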
3 changes: 2 additions & 1 deletion code/backend/batch/utilities/helpers/config/default.json
@@ -8,7 +8,8 @@
"use_on_your_data_format": true,
"enable_post_answering_prompt": false,
"ai_assistant_type": "default",
"enable_content_safety": true
"enable_content_safety": true,
"conversational_flow": "custom"
},
"example": {
"documents": "{\n \"retrieved_documents\": [\n {\n \"[doc1]\": {\n \"content\": \"Dual Transformer Encoder (DTE) DTE (https://dev.azure.com/TScience/TSciencePublic/_wiki/wikis/TSciencePublic.wiki/82/Dual-Transformer-Encoder) DTE is a general pair-oriented sentence representation learning framework based on transformers. It provides training, inference and evaluation for sentence similarity models. Model Details DTE can be used to train a model for sentence similarity with the following features: - Build upon existing transformer-based text representations (e.g.TNLR, BERT, RoBERTa, BAG-NLR) - Apply smoothness inducing technology to improve the representation robustness - SMART (https://arxiv.org/abs/1911.03437) SMART - Apply NCE (Noise Contrastive Estimation) based similarity learning to speed up training of 100M pairs We use pretrained DTE model\"\n }\n },\n {\n \"[doc2]\": {\n \"content\": \"trained on internal data. You can find more details here - Models.md (https://dev.azure.com/TScience/_git/TSciencePublic?path=%2FDualTransformerEncoder%2FMODELS.md&version=GBmaster&_a=preview) Models.md DTE-pretrained for In-context Learning Research suggests that finetuned transformers can be used to retrieve semantically similar exemplars for e.g. KATE (https://arxiv.org/pdf/2101.06804.pdf) KATE . They show that finetuned models esp. tuned on related tasks give the maximum boost to GPT-3 in-context performance. DTE have lot of pretrained models that are trained on intent classification tasks. We can use these model embedding to find natural language utterances which are similar to our test utterances at test time. The steps are: 1. Embed\"\n }\n },\n {\n \"[doc3]\": {\n \"content\": \"train and test utterances using DTE model 2. For each test embedding, find K-nearest neighbors. 3. Prefix the prompt with nearest embeddings. The following diagram from the above paper (https://arxiv.org/pdf/2101.06804.pdf) the above paper visualizes this process: DTE-Finetuned This is an extension of DTE-pretrained method where we further finetune the embedding models for prompt crafting task. In summary, we sample random prompts from our training data and use them for GPT-3 inference for the another part of training data. Some prompts work better and lead to right results whereas other prompts lead\"\n }\n },\n {\n \"[doc4]\": {\n \"content\": \"to wrong completions. We finetune the model on the downstream task of whether a prompt is good or not based on whether it leads to right or wrong completion. This approach is similar to this paper: Learning To Retrieve Prompts for In-Context Learning (https://arxiv.org/pdf/2112.08633.pdf) this paper: Learning To Retrieve Prompts for In-Context Learning . This method is very general but it may require a lot of data to actually finetune a model to learn how to retrieve examples suitable for the downstream inference model like GPT-3.\"\n }\n }\n ]\n}",
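Together with the _set_new_config_properties hunk above, a previously saved active.json that predates this field is backfilled from these defaults. A minimal illustration with hypothetical dicts:

# Toy dicts for illustration only; real configs carry many more keys.
stored = {"prompts": {"ai_assistant_type": "default"}}  # older saved config, no flow key
default = {"prompts": {"ai_assistant_type": "default", "conversational_flow": "custom"}}

if stored["prompts"].get("conversational_flow") is None:
    stored["prompts"]["conversational_flow"] = default["prompts"]["conversational_flow"]

print(stored["prompts"]["conversational_flow"])  # -> "custom"
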
13 changes: 6 additions & 7 deletions code/backend/batch/utilities/helpers/env_helper.py
@@ -4,7 +4,6 @@
 from dotenv import load_dotenv
 from azure.identity import DefaultAzureCredential, get_bearer_token_provider
 from azure.keyvault.secrets import SecretClient
-from .config.conversation_flow import ConversationFlow

 logger = logging.getLogger(__name__)

@@ -69,9 +68,13 @@ def __load_config(self, **kwargs) -> None:
         self.AZURE_SEARCH_FIELDS_METADATA = os.getenv(
             "AZURE_SEARCH_FIELDS_METADATA", "metadata"
         )
-        self.AZURE_SEARCH_SOURCE_COLUMN = os.getenv("AZURE_SEARCH_SOURCE_COLUMN", "source")
+        self.AZURE_SEARCH_SOURCE_COLUMN = os.getenv(
+            "AZURE_SEARCH_SOURCE_COLUMN", "source"
+        )
         self.AZURE_SEARCH_CHUNK_COLUMN = os.getenv("AZURE_SEARCH_CHUNK_COLUMN", "chunk")
-        self.AZURE_SEARCH_OFFSET_COLUMN = os.getenv("AZURE_SEARCH_OFFSET_COLUMN", "offset")
+        self.AZURE_SEARCH_OFFSET_COLUMN = os.getenv(
+            "AZURE_SEARCH_OFFSET_COLUMN", "offset"
+        )
         self.AZURE_SEARCH_CONVERSATIONS_LOG_INDEX = os.getenv(
             "AZURE_SEARCH_CONVERSATIONS_LOG_INDEX", "conversations"
         )
@@ -211,10 +214,6 @@ def __load_config(self, **kwargs) -> None:
         self.ORCHESTRATION_STRATEGY = os.getenv(
             "ORCHESTRATION_STRATEGY", "openai_function"
         )
-        # Conversation Type - which chooses between custom or byod
-        self.CONVERSATION_FLOW = os.getenv(
-            "CONVERSATION_FLOW", ConversationFlow.CUSTOM.value
-        )
         # Speech Service
         self.AZURE_SPEECH_SERVICE_NAME = os.getenv("AZURE_SPEECH_SERVICE_NAME", "")
         self.AZURE_SPEECH_SERVICE_REGION = os.getenv("AZURE_SPEECH_SERVICE_REGION")
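With these deletions the flow is no longer an environment setting: it is resolved from the active config at request time, as the create_app.py hunk further down shows. Roughly:

# Sketch of the replacement lookup (mirrors the conversation() route below).
from backend.batch.utilities.helpers.config.config_helper import ConfigHelper

config = ConfigHelper.get_active_config_or_default()
conversation_flow = config.prompts.conversational_flow  # "custom" or "byod"
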
21 changes: 21 additions & 0 deletions code/backend/pages/04_Configuration.py
@@ -7,6 +7,7 @@
 from batch.utilities.helpers.config.config_helper import ConfigHelper
 from azure.core.exceptions import ResourceNotFoundError
 from batch.utilities.helpers.config.assistant_strategy import AssistantStrategy
+from batch.utilities.helpers.config.conversation_flow import ConversationFlow

 sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
 env_helper: EnvHelper = EnvHelper()
@@ -65,6 +66,8 @@ def load_css(file_path):
     st.session_state["orchestrator_strategy"] = config.orchestrator.strategy.value
 if "ai_assistant_type" not in st.session_state:
     st.session_state["ai_assistant_type"] = config.prompts.ai_assistant_type
+if "conversational_flow" not in st.session_state:
+    st.session_state["conversational_flow"] = config.prompts.conversational_flow

 if env_helper.AZURE_SEARCH_USE_INTEGRATED_VECTORIZATION:
     if "max_page_length" not in st.session_state:
@@ -163,13 +166,30 @@ def validate_documents():


 try:
+    conversational_flow_help = "Whether to use the custom conversational flow or byod conversational flow. Refer to the Conversational flow options README for more details."
+    with st.expander("Conversational flow configuration", expanded=True):
+        cols = st.columns([2, 4])
+        with cols[0]:
+            conv_flow = st.selectbox(
+                "Conversational flow",
+                key="conversational_flow",
+                options=config.get_available_conversational_flows(),
+                help=conversational_flow_help,
+            )
+
     with st.expander("Orchestrator configuration", expanded=True):
         cols = st.columns([2, 4])
         with cols[0]:
             st.selectbox(
                 "Orchestrator strategy",
                 key="orchestrator_strategy",
                 options=config.get_available_orchestration_strategies(),
+                disabled=(
+                    True
+                    if st.session_state["conversational_flow"]
+                    == ConversationFlow.BYOD.value
+                    else False
+                ),
             )

     # # # condense_question_prompt_help = "This prompt is used to convert the user's input to a standalone question, using the context of the chat history."
@@ -377,6 +397,7 @@ def validate_documents():
                 ],
                 "enable_content_safety": st.session_state["enable_content_safety"],
                 "ai_assistant_type": st.session_state["ai_assistant_type"],
+                "conversational_flow": st.session_state["conversational_flow"],
             },
             "messages": {
                 "post_answering_filter": st.session_state[
62 changes: 54 additions & 8 deletions code/create_app.py
@@ -8,23 +8,67 @@
 import mimetypes
 from os import path
 import sys
+import re
 import requests
 from openai import AzureOpenAI, Stream, APIStatusError
 from openai.types.chat import ChatCompletionChunk
 from flask import Flask, Response, request, Request, jsonify
 from dotenv import load_dotenv
+from urllib.parse import quote
 from backend.batch.utilities.helpers.env_helper import EnvHelper
 from backend.batch.utilities.helpers.orchestrator_helper import Orchestrator
 from backend.batch.utilities.helpers.config.config_helper import ConfigHelper
 from backend.batch.utilities.helpers.config.conversation_flow import ConversationFlow
 from azure.mgmt.cognitiveservices import CognitiveServicesManagementClient
 from azure.identity import DefaultAzureCredential
+from backend.batch.utilities.helpers.azure_blob_storage_client import (
+    AzureBlobStorageClient,
+)

 ERROR_429_MESSAGE = "We're currently experiencing a high number of requests for the service you're trying to access. Please wait a moment and try again."
 ERROR_GENERIC_MESSAGE = "An error occurred. Please try again. If the problem persists, please contact the site administrator."
 logger = logging.getLogger(__name__)


+def get_markdown_url(source, title, container_sas):
+    """Get Markdown URL for a citation"""
+
+    url = quote(source, safe=":/")
+    if "_SAS_TOKEN_PLACEHOLDER_" in url:
+        url = url.replace("_SAS_TOKEN_PLACEHOLDER_", container_sas)
+    return f"[{title}]({url})"
+
+
+def get_citations(citation_list):
+    """Returns Formatted Citations"""
+    blob_client = AzureBlobStorageClient()
+    container_sas = blob_client.get_container_sas()
+    citations_dict = {"citations": []}
+    for citation in citation_list.get("citations"):
+        metadata = (
+            json.loads(citation["url"])
+            if isinstance(citation["url"], str)
+            else citation["url"]
+        )
+        title = citation["title"]
+        url = get_markdown_url(metadata["source"], title, container_sas)
+        citations_dict["citations"].append(
+            {
+                "content": url + "\n\n\n" + citation["content"],
+                "id": metadata["id"],
+                "chunk_id": (
+                    re.findall(r"\d+", metadata["chunk_id"])[-1]
+                    if metadata["chunk_id"] is not None
+                    else metadata["chunk"]
+                ),
+                "title": title,
+                "filepath": title.split("/")[-1],
+                "url": url,
+            }
+        )
+    return citations_dict
+
+
 def stream_with_data(response: Stream[ChatCompletionChunk]):
     """This function streams the response from Azure OpenAI with data."""
     response_obj = {
@@ -67,8 +111,9 @@ def stream_with_data(response: Stream[ChatCompletionChunk]):
                 role = delta.role

                 if role == "assistant":
+                    citations = get_citations(delta.model_extra["context"])
                     response_obj["choices"][0]["messages"][0]["content"] = json.dumps(
-                        delta.model_extra["context"],
+                        citations,
                         ensure_ascii=False,
                     )
                 else:
@@ -135,7 +180,8 @@ def conversation_with_data(conversation: Request, env_helper: EnvHelper):
                             env_helper.AZURE_SEARCH_CONTENT_VECTOR_COLUMN
                         ],
                         "title_field": env_helper.AZURE_SEARCH_TITLE_COLUMN or None,
-                        "url_field": env_helper.AZURE_SEARCH_URL_COLUMN or None,
+                        "url_field": env_helper.AZURE_SEARCH_FIELDS_METADATA
+                        or None,
                         "filepath_field": (
                             env_helper.AZURE_SEARCH_FILENAME_COLUMN or None
                         ),
@@ -166,6 +212,7 @@ def conversation_with_data(conversation: Request, env_helper: EnvHelper):
     )

     if not env_helper.SHOULD_STREAM:
+        citations = get_citations(response.choices[0].message.model_extra["context"])
         response_obj = {
             "id": response.id,
             "model": response.model,
@@ -176,7 +223,7 @@
             "messages": [
                 {
                     "content": json.dumps(
-                        response.choices[0].message.model_extra["context"],
+                        citations,
                         ensure_ascii=False,
                     ),
                     "end_turn": False,
@@ -194,10 +241,7 @@ def conversation_with_data(conversation: Request, env_helper: EnvHelper):

         return response_obj

-    return Response(
-        stream_with_data(response),
-        mimetype="application/json-lines",
-    )
+    return Response(stream_with_data(response), mimetype="application/json-lines")


 def stream_without_data(response: Stream[ChatCompletionChunk]):
@@ -405,7 +449,9 @@ async def conversation_custom():

 @app.route("/api/conversation", methods=["POST"])
 async def conversation():
-    conversation_flow = env_helper.CONVERSATION_FLOW
+    ConfigHelper.get_active_config_or_default.cache_clear()
+    result = ConfigHelper.get_active_config_or_default()
+    conversation_flow = result.prompts.conversational_flow
     if conversation_flow == ConversationFlow.CUSTOM.value:
         return await conversation_custom()
     elif conversation_flow == ConversationFlow.BYOD.value:
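For orientation, a rough sketch of the shape get_citations consumes and produces. The payload below is hypothetical; the real helper also resolves a container SAS through AzureBlobStorageClient:

# Hypothetical "on your data" citation payload, not taken from this diff.
citation_list = {
    "citations": [
        {
            "title": "documents/report.pdf",
            "content": "Relevant passage from the chunk...",
            # "url" carries JSON-encoded metadata written at ingestion time
            "url": '{"source": "https://account.blob.core.windows.net/documents/report.pdf_SAS_TOKEN_PLACEHOLDER_", "id": "42", "chunk_id": "chunk_3"}',
        }
    ]
}

# get_citations(citation_list) then returns, roughly:
# {"citations": [{
#     "content": "[documents/report.pdf](<url with SAS>)\n\n\nRelevant passage from the chunk...",
#     "id": "42",
#     "chunk_id": "3",  # last number extracted from "chunk_3"
#     "title": "documents/report.pdf",
#     "filepath": "report.pdf",
#     "url": "[documents/report.pdf](<url with SAS>)",
# }]}
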
2 changes: 1 addition & 1 deletion code/frontend/src/components/Answer/AnswerParser.tsx
@@ -11,7 +11,7 @@ let filteredCitations = [] as Citation[];

 // Define a function to check if a citation with the same Chunk_Id already exists in filteredCitations
 const isDuplicate = (citation: Citation,citationIndex:string) => {
-  return filteredCitations.some((c) => c.chunk_id === citation.chunk_id) && !filteredCitations.find((c) => c.id === citationIndex) ;
+  return filteredCitations.some((c) => c.chunk_id === citation.chunk_id) ;
 };

 export function parseAnswer(answer: AskResponse): ParsedAnswer {
(file header not captured in this extract: a functional backend API test module)
@@ -2,6 +2,7 @@
 import re
 import pytest
 from pytest_httpserver import HTTPServer
+from unittest.mock import patch
 import requests

 from tests.request_matching import (
@@ -176,7 +177,9 @@


 def test_post_makes_correct_calls_to_openai_embeddings_to_embed_question_to_store_in_conversation_log(
-    app_url: str, app_config: AppConfig, httpserver: HTTPServer
+    app_url: str,
+    app_config: AppConfig,
+    httpserver: HTTPServer,
 ):
     # when
     requests.post(f"{app_url}{path}", json=body)
@@ -649,9 +652,15 @@ def test_post_makes_correct_call_to_store_conversation_in_search(
 )


+@patch(
+    "backend.batch.utilities.helpers.config.config_helper.ConfigHelper.get_active_config_or_default"
+)
 def test_post_returns_error_when_downstream_fails(
-    app_url: str, app_config: AppConfig, httpserver: HTTPServer
+    get_active_config_or_default_mock, app_url: str, httpserver: HTTPServer
 ):
+    get_active_config_or_default_mock.return_value.prompts.conversational_flow = (
+        "custom"
+    )
     httpserver.expect_oneshot_request(
         re.compile(".*"),
     ).respond_with_json({}, status=403)