This repository has been archived by the owner on Apr 29, 2024. It is now read-only.

Commit 6e77138: Cleanup
sasha370 committed Apr 12, 2024
1 parent: f862b33
Showing 13 changed files with 65 additions and 82 deletions.
1 change: 1 addition & 0 deletions api/endpoint.py
@@ -64,6 +64,7 @@ def create_embeds(EmbedRequest: EmbedRequest):
     """
     Endpoint to initiate the embedding generation and storage process in the background.
     :param EmbedRequest:
+    :return:
     """
     page_id = EmbedRequest.page_id
     thread = threading.Thread(target=vector.pages.generate_one_embedding_to_database, args=(page_id,))
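For context, a hedged sketch of the endpoint this hunk sits in, assuming a FastAPI app and a pydantic EmbedRequest model carrying a page_id field. Only the lines visible in the hunk are taken from the repository; the decorator, model definition and return value are illustrative assumptions.

import threading

from fastapi import FastAPI
from pydantic import BaseModel

import vector.pages

app = FastAPI()  # assumed application object; the real app setup is not shown in this diff


class EmbedRequest(BaseModel):  # assumed model; the diff only shows that it carries page_id
    page_id: str


@app.post("/embeds")  # assumed route path
def create_embeds(EmbedRequest: EmbedRequest):
    """
    Endpoint to initiate the embedding generation and storage process in the background.
    :param EmbedRequest:
    :return:
    """
    page_id = EmbedRequest.page_id
    # Fire-and-forget: the request returns immediately while the embedding is
    # generated and stored by a worker thread.
    thread = threading.Thread(target=vector.pages.generate_one_embedding_to_database, args=(page_id,))
    thread.start()
    return {"message": "Embedding generation started", "page_id": page_id}  # assumed response body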
9 changes: 4 additions & 5 deletions api/request.py
@@ -2,10 +2,9 @@
 import logging
 
 
-def post_request(url, payload, headers=None, data_type=None):
+def post_request(url, payload, headers=None):
     """
     Post a request to a given URL.
-    :param data_type: Type of payload data.
     :param url: The URL to post to.
     :param payload: The payload to send.
     :param headers: The headers to send.
@@ -16,8 +15,8 @@ def post_request(url, payload, headers=None, data_type=None):
     try:
         response = requests.post(url, json=payload, headers=headers)
         response.raise_for_status()
-        logging.info(f"INFO: {data_type} request submitted to {url}")
+        logging.info(f"INFO: request submitted to {url}")
     except requests.exceptions.HTTPError as e:
-        logging.error(f"ERROR: An HTTP error with {data_type} request occurred: {e}")
+        logging.error(f"ERROR: An HTTP error occurred: {e}")
     except Exception as e:
-        logging.error(f"ERROR: An error with {data_type} request occurred: {e}")
+        logging.error(f"ERROR: An error occurred: {e}")
4 changes: 2 additions & 2 deletions confluence/importer.py
@@ -2,7 +2,7 @@
 import logging
 
 from database.page_manager import PageManager
-from database.space_manager import SpaceManager
+from database.space_manager import upsert_space_info
 import vector.pages
 
 from .client import ConfluenceClient
@@ -32,7 +32,7 @@ def import_space(space_key, space_name, session):
 
     vector.pages.generate_missing_embeddings_to_database(session)
 
-    SpaceManager().upsert_space_info(
+    upsert_space_info(
         session,
         space_key=space_key,
         space_name=space_name,
42 changes: 9 additions & 33 deletions database/space_manager.py
@@ -2,42 +2,18 @@
 from models.space_info import SpaceInfo
 
 
-class SpaceManager:
-    def __init__(self):
-        pass
+def upsert_space_info(session, space_key, space_name, last_import_date):
+    """Create or update space information based on the existence of the space key."""
+    existing_space = session.query(SpaceInfo).filter_by(space_key=space_key).first()
+    last_import_date_formatted = datetime.strptime(last_import_date, '%Y-%m-%d %H:%M:%S')
 
-    def add_space_info(self, space_key, space_name, last_import_date, session):
-        """Add a new space to the database."""
+    if existing_space:
+        existing_space.last_import_date = last_import_date_formatted
+    else:
         new_space = SpaceInfo(
             space_key=space_key,
             space_name=space_name,
-            last_import_date=datetime.strptime(last_import_date, '%Y-%m-%d %H:%M:%S')
+            last_import_date=last_import_date_formatted
         )
         session.add(new_space)
-        session.commit()
-
-    def update_space_info(self, space_key, last_import_date, session):
-        """Update the last import date of an existing space."""
-        space = session.query(SpaceInfo).filter_by(space_key=space_key).first()
-
-        if space:
-            space.last_import_date = datetime.strptime(last_import_date, '%Y-%m-%d %H:%M:%S')
-            session.commit()
-        else:
-            print(f"Space with key {space_key} not found.")
-
-    def upsert_space_info(self, session, space_key, space_name, last_import_date):
-        """Create or update space information based on the existence of the space key."""
-        existing_space = session.query(SpaceInfo).filter_by(space_key=space_key).first()
-        last_import_date_formatted = datetime.strptime(last_import_date, '%Y-%m-%d %H:%M:%S')
-
-        if existing_space:
-            existing_space.last_import_date = last_import_date_formatted
-        else:
-            new_space = SpaceInfo(
-                space_key=space_key,
-                space_name=space_name,
-                last_import_date=last_import_date_formatted
-            )
-            session.add(new_space)
-        session.commit()
+    session.commit()
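To illustrate the new module-level API, a small usage sketch; get_db_session comes from database/database.py as used elsewhere in this commit, while the space key, name and date are illustrative values. Note that last_import_date is a string in the '%Y-%m-%d %H:%M:%S' format the function parses.

from datetime import datetime

from database.database import get_db_session
from database.space_manager import upsert_space_info

with get_db_session() as session:
    upsert_space_info(
        session,
        space_key="ENG",           # illustrative
        space_name="Engineering",  # illustrative
        last_import_date=datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
    )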
4 changes: 2 additions & 2 deletions interactions/identify_knowledge_gap.py
@@ -130,7 +130,7 @@ def query_assistant_with_context(context, formatted_interactions, thread_id=None
     return assistant_response, thread_id
 
 
-def process_and_store_questions(assistant_response_json, db_session):
+def process_and_store_questions(assistant_response_json, session):
     """
     Processes the JSON response from the assistant, extracts questions, stores them in the database,
     and collects the QuizQuestionDTO objects.
@@ -147,7 +147,7 @@ def process_and_store_questions(assistant_response_json, db_session):
         logging.error(f"Error decoding JSON: {e}")
         return []
 
-    quiz_question_manager = QuizQuestionManager(db_session)
+    quiz_question_manager = QuizQuestionManager(session)
 
     quiz_question_dtos = []
     for item in questions_data:
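A brief sketch of how the renamed parameter is used at a call site, assuming the assistant returns a JSON array of question objects. The response schema here is a guess; only the function, its module path and the session handling come from this commit.

import json

from database.database import get_db_session
from interactions.identify_knowledge_gap import process_and_store_questions

# Hypothetical assistant output; the real schema is defined by the assistant prompt.
assistant_response_json = json.dumps([
    {"question": "How do we rotate the Confluence API token?"},
])

with get_db_session() as session:
    quiz_question_dtos = process_and_store_questions(assistant_response_json, session)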
4 changes: 2 additions & 2 deletions main.py
@@ -10,9 +10,9 @@
 from database.database import get_db_session
 
 
-def answer_question_with_assistant(question, db_session):
+def answer_question_with_assistant(question, session):
     page_ids = vector.pages.retrieve_relevant_ids(question, count=question_context_pages_count)
-    response, thread_id = query_assistant_with_context(question, page_ids, db_session)
+    response, thread_id = query_assistant_with_context(question, page_ids, session)
     return response, thread_id
 
 
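For completeness, a sketch of calling the renamed entry point with an explicitly scoped session, matching the convention this commit applies across the codebase. The question text is illustrative.

from database.database import get_db_session
from main import answer_question_with_assistant

with get_db_session() as session:
    response, thread_id = answer_question_with_assistant(
        "Where is the on-call runbook documented?",  # illustrative question
        session,
    )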
2 changes: 1 addition & 1 deletion slack/bot.py
@@ -44,7 +44,7 @@ def start(self):
 def load_slack_bot():
     logging.basicConfig(level=logging.INFO)
     bot_user_id = get_bot_user_id(slack_bot_user_oauth_token)
-    event_handlers = [ChannelMessageHandler(), ]
+    event_handlers = [ChannelMessageHandler()]
     bot = SlackBot(slack_bot_user_oauth_token, slack_app_level_token, bot_user_id, event_handlers)
     bot.start()
 
4 changes: 2 additions & 2 deletions slack/channel_message_handler.py
@@ -115,7 +115,7 @@ def handle(self, client: SocketModeClient, req: SocketModeRequest, web_client: W
                 "channel": channel,
                 "user": user_id
             }
-            post_request(questions_endpoint, payload, data_type='Question')
+            post_request(questions_endpoint, payload)
 
         # Identify and handle feedback
         elif thread_ts in self.questions:  # Message is a reply to a question
@@ -129,7 +129,7 @@ def handle(self, client: SocketModeClient, req: SocketModeRequest, web_client: W
                 "user": user_id,
                 "parent_question": parent_question
             }
-            post_request(feedback_endpoint, payload, data_type='Feedback')
+            post_request(feedback_endpoint, payload)
 
         else:
             logging.info(f"Skipping message with ID {ts} from user {user_id}. Reason: {skip_reason}")
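For reference, a sketch of the two payload shapes posted above. The field names come from the hunks; the endpoint URLs and concrete values are illustrative assumptions.

from api.request import post_request

questions_endpoint = "http://localhost:8000/questions"  # illustrative
feedback_endpoint = "http://localhost:8000/feedback"    # illustrative

channel = "C0123456789"
user_id = "U0123456789"
parent_question = "How do we rotate the Confluence API token?"

# New question detected in a channel.
post_request(questions_endpoint, {"channel": channel, "user": user_id})

# Reply landing in a thread the handler recognises as a question.
post_request(feedback_endpoint, {"user": user_id, "parent_question": parent_question})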
65 changes: 35 additions & 30 deletions slack/event_consumer.py
@@ -45,32 +45,34 @@ def get_user_name_from_id(slack_web_client, user_id):
 
 
 class EventConsumer:
-    def __init__(self, session):
-        self.session = session
+    def __init__(self):
         self.web_client = WebClient(token=slack_bot_user_oauth_token)
-        self.interaction_manager = QAInteractionManager(session)
         logging.log(logging.DEBUG, f"Slack Event Consumer initiated successfully")
 
-    def add_question_and_response_to_database(self, question_event, response_text, assistant_thread_id):
-        self.interaction_manager.add_question_and_answer(question=question_event.text,
-                                                         answer=response_text,
-                                                         thread_id=question_event.ts,
-                                                         assistant_thread_id=assistant_thread_id,
-                                                         channel_id=question_event.channel,
-                                                         question_ts=datetime.fromtimestamp(float(question_event.ts)),
-                                                         answer_ts=datetime.now(),
-                                                         slack_user_id=question_event.user)
-
-        print(f"\n\nQuestion and answer stored in the database: question: {question_event.dict()},\nAnswer: {response_text},\nAssistant_id {assistant_thread_id}\n\n")
-
-    def process_question(self, question_event: QuestionEvent):  # TODO: Refactor this method
+    def add_question_and_response_to_database(self, session, question_event, response_text, assistant_thread_id):
+        QAInteractionManager(session).add_question_and_answer(question=question_event.text,
+                                                              answer=response_text,
+                                                              thread_id=question_event.ts,
+                                                              assistant_thread_id=assistant_thread_id,
+                                                              channel_id=question_event.channel,
+                                                              question_ts=datetime.fromtimestamp(
+                                                                  float(question_event.ts)),
+                                                              answer_ts=datetime.now(),
+                                                              slack_user_id=question_event.user)
+
+        print(
+            f"\n\nQuestion and answer stored in the database: question: {question_event.dict()},\nAnswer: {response_text},\nAssistant_id {assistant_thread_id}\n\n")
+
+    def process_question(self, session, question_event: QuestionEvent):  # TODO: Refactor this method
         channel_id = question_event.channel
         message_ts = question_event.ts
 
         try:
-            context_page_ids = vector.pages.retrieve_relevant_ids(question_event.text, count=question_context_pages_count)
+            context_page_ids = vector.pages.retrieve_relevant_ids(question_event.text,
+                                                                  count=question_context_pages_count)
             response_text, assistant_thread_id = query_assistant_with_context(question_event.text,
                                                                               context_page_ids,
-                                                                              self.session,
+                                                                              session,
                                                                               None)
         except Exception as e:
             print(f"Error processing question: {e}")
@@ -79,33 +81,36 @@ def process_question(self, question_event: QuestionEvent):  # TODO: Refactor this method
         if response_text:
             print(f"Response from assistant: {response_text}\n")
             try:
-                self.add_question_and_response_to_database(question_event,
+                self.add_question_and_response_to_database(session,
+                                                           question_event,
                                                            response_text,
                                                            assistant_thread_id=assistant_thread_id)
                 try:
-                    ScoreManager(self.session).add_or_update_score(slack_user_id=question_event.user, category='seeker')
+                    ScoreManager(session).add_or_update_score(slack_user_id=question_event.user, category='seeker')
                     print(f"Score updated for user {question_event.user}")
                 except Exception as e:
                     print(f"Error updating score for user {question_event.user}: {e}")
                 self.web_client.chat_postMessage(channel=channel_id, text=response_text, thread_ts=message_ts)
                 print(f"\nResponse posted to Slack thread: {message_ts}\n")
             except Exception as e:
-                print(f"Error registering message as processed, adding to db and responding to the question on slack: {e}")
+                print(
+                    f"Error registering message as processed, adding to db and responding to the question on slack: {e}")
 
     def generate_extended_context_query(self, existing_interaction, feedback_text):
         extended_context_query = ""
         if existing_interaction:
             extended_context_query = f"Follow up: {feedback_text}, Initial question: {existing_interaction.question_text}, Initial answer: {existing_interaction.answer_text}"
         return extended_context_query
 
-    def process_feedback(self, feedback_event: FeedbackEvent):  # TODO: Refactor this method
+    def process_feedback(self, session, feedback_event: FeedbackEvent):  # TODO: Refactor this method
         channel_id = feedback_event.channel
         message_ts = feedback_event.ts
         thread_ts = feedback_event.thread_ts
         response_text = None
+        interaction_manager = QAInteractionManager(session)
 
         try:
-            existing_interaction = self.interaction_manager.get_interaction_by_thread_id(thread_ts)
+            existing_interaction = interaction_manager.get_interaction_by_thread_id(thread_ts)
             assistant_thread_id = existing_interaction.assistant_thread_id if existing_interaction else None
             print(f"\n\nExisting interaction found: {existing_interaction}\n\n")
         except Exception as e:
@@ -117,32 +122,32 @@ def process_feedback(self, feedback_event: FeedbackEvent):  # TODO: Refactor this method
             print(f"\n\nExtended context: {extended_context_query}\n\n")
             page_ids = vector.pages.retrieve_relevant_ids(extended_context_query, count=question_context_pages_count)
             try:
-                response_text, assistant_thread_id = query_assistant_with_context(feedback_event.text, page_ids, assistant_thread_id)
+                response_text, assistant_thread_id = query_assistant_with_context(feedback_event.text, page_ids,
+                                                                                   assistant_thread_id)
             except Exception as e:
                 print(f"Error processing feedback: {e}")
                 response_text = None
 
         if response_text:
             print(f"Response from assistant: {response_text}\n")
             timestamp_str = datetime.now().isoformat()
-            comment = {"text": feedback_event.text, "user": feedback_event.user, "timestamp": timestamp_str, "assistant response": response_text}
-            self.interaction_manager.add_comment_to_interaction(thread_id=thread_ts, comment=comment)
+            comment = {"text": feedback_event.text, "user": feedback_event.user, "timestamp": timestamp_str,
+                       "assistant response": response_text}
+            interaction_manager.add_comment_to_interaction(thread_id=thread_ts, comment=comment)
             print(f"Feedback appended to the interaction in the database: {feedback_event.dict()}\n")
             self.web_client.chat_postMessage(channel=channel_id, text=response_text, thread_ts=thread_ts)
             print(f"Feedback response posted to Slack thread: {message_ts}\n")
         else:
             print(f"No response generated for feedback: {feedback_event.dict()}\n")
 
 
-# TODO: Move session call to method level
 def process_question(question_event: QuestionEvent):
     """Directly processes a question event without using the queue."""
     with get_db_session() as session:
-        EventConsumer(session).process_question(question_event)
+        EventConsumer().process_question(session, question_event)
 
 
-# TODO: Move session call to method level
 def process_feedback(feedback_event: FeedbackEvent):
     """Directly processes a feedback event without using the queue."""
     with get_db_session() as session:
-        EventConsumer(session).process_feedback(feedback_event)
+        EventConsumer().process_feedback(session, feedback_event)
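To make the new calling convention concrete, a hedged sketch of driving the consumer with an explicit session. The import path and constructor for QuestionEvent are assumptions (the diff only shows that the event exposes text, ts, channel and user); get_db_session and EventConsumer come from the hunks above.

from database.database import get_db_session
from slack.event_consumer import EventConsumer
from slack.question_event import QuestionEvent  # assumed import path

# Illustrative event; the field names match those read in process_question.
event = QuestionEvent(
    text="How do I request VPN access?",
    ts="1712908800.000100",
    channel="C0123456789",
    user="U0123456789",
)

with get_db_session() as session:
    # The session is scoped to this call instead of being stored on the consumer.
    EventConsumer().process_question(session, event)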
6 changes: 3 additions & 3 deletions slack/message_manager.py
@@ -7,7 +7,7 @@
 from slack_sdk.errors import SlackApiError
 
 
-def post_questions_to_slack(db_session, channel_id, quiz_question_dtos, user_ids):
+def post_questions_to_slack(session, channel_id, quiz_question_dtos, user_ids):
     """
     Posts a list of QuizQuestionDTO objects to a specified Slack channel, tags all users in the first reply,
     invites them to contribute to the information gathering related to their questions, asks them to tag domain
@@ -23,8 +23,8 @@ def post_questions_to_slack(db_session, channel_id, quiz_question_dtos, user_ids):
     """
 
     client = WebClient(token=slack_bot_user_oauth_token)
-    quiz_question_manager = QuizQuestionManager(db_session)
-    score_manager = ScoreManager(db_session)
+    quiz_question_manager = QuizQuestionManager(session)
+    score_manager = ScoreManager(session)
 
     for quiz_question_dto in quiz_question_dtos:
         try:
2 changes: 1 addition & 1 deletion vector/interactions/embeddings/generate_missing.py
@@ -19,7 +19,7 @@ def submit_request(interaction_id):
     Submit an API request to create interaction embeds.
     :return: None
     """
-    post_request(interaction_embeds_endpoint, {"interaction_id": interaction_id}, data_type='Interaction embed')
+    post_request(interaction_embeds_endpoint, {"interaction_id": interaction_id})
 
 
 def generate_missing_embeddings_to_database(session, retry_limit: int = 3, wait_time: int = 5) -> None:
2 changes: 1 addition & 1 deletion vector/pages/embeddings/generate_missing.py
@@ -10,7 +10,7 @@
 
 
 def submit_embedding_creation_request(page_id: str):
-    post_request(embeds_endpoint, {"page_id": page_id}, data_type='Embed')
+    post_request(embeds_endpoint, {"page_id": page_id})
 
 
 def generate_missing_embeddings_to_database(session, retry_limit: int = 3, wait_time: int = 5) -> None:
2 changes: 2 additions & 0 deletions vector/pages/retriever.py
@@ -13,8 +13,10 @@
 def retrieve_relevant_ids(question: str, count: int) -> List[int]:
     """
     Retrieve the most relevant documents for a given question using the vector database.
+    Args:
     question (str): The question to retrieve relevant documents for.
+    Returns:
     List[str]: A list of document IDs of the most relevant documents.
     """
