Skip to content

Commit

Permalink
fix: fixed lint issue (#1396)
Browse files Browse the repository at this point in the history
  • Loading branch information
AjitPadhi-Microsoft authored Oct 10, 2024
1 parent c158dd3 commit 28e628a
Show file tree
Hide file tree
Showing 12 changed files with 168 additions and 148 deletions.
24 changes: 12 additions & 12 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -77,16 +77,16 @@ jobs:
fi
echo "MIN_COVERAGE=$MIN_COVERAGE" >> "$GITHUB_OUTPUT"
# - name: Run Python Tests
# run: make python-test optional_args="--junitxml=coverage-junit.xml --cov=. --cov-report xml:coverage.xml --cov-fail-under ${{ steps.coverage-value.outputs.MIN_COVERAGE }}"
# - uses: actions/upload-artifact@v4
# if: ${{ !cancelled() }}
# with:
# name: coverage
# path: |
# coverage-junit.xml
# coverage.xml
# if-no-files-found: error
- name: Run Python Tests
run: make python-test optional_args="--junitxml=coverage-junit.xml --cov=. --cov-report xml:coverage.xml --cov-fail-under ${{ steps.coverage-value.outputs.MIN_COVERAGE }}"
- uses: actions/upload-artifact@v4
if: ${{ !cancelled() }}
with:
name: coverage
path: |
coverage-junit.xml
coverage.xml
if-no-files-found: error
- name: Setup node
uses: actions/setup-node@v4
with:
Expand All @@ -95,5 +95,5 @@ jobs:
cache-dependency-path: "code/frontend/package-lock.json"
- name: Run frontend unit tests
run: make unittest-frontend
# - name: Lint
# run: make lint
- name: Lint
run: make lint
31 changes: 17 additions & 14 deletions code/backend/batch/utilities/chat_history/auth_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,39 +2,42 @@
import json
import logging


def get_authenticated_user_details(request_headers):
user_object = {}

## check the headers for the Principal-Id (the guid of the signed in user)
# check the headers for the Principal-Id (the guid of the signed in user)
if "X-Ms-Client-Principal-Id" not in request_headers.keys():
## if it's not, assume we're in development mode and return a default user
# if it's not, assume we're in development mode and return a default user
from . import sample_user

raw_user_object = sample_user.sample_user
else:
## if it is, get the user details from the EasyAuth headers
raw_user_object = {k:v for k,v in request_headers.items()}
# if it is, get the user details from the EasyAuth headers
raw_user_object = {k: v for k, v in request_headers.items()}

user_object['user_principal_id'] = raw_user_object.get('X-Ms-Client-Principal-Id')
user_object['user_name'] = raw_user_object.get('X-Ms-Client-Principal-Name')
user_object['auth_provider'] = raw_user_object.get('X-Ms-Client-Principal-Idp')
user_object['auth_token'] = raw_user_object.get('X-Ms-Token-Aad-Id-Token')
user_object['client_principal_b64'] = raw_user_object.get('X-Ms-Client-Principal')
user_object['aad_id_token'] = raw_user_object.get('X-Ms-Token-Aad-Id-Token')
user_object["user_principal_id"] = raw_user_object.get("X-Ms-Client-Principal-Id")
user_object["user_name"] = raw_user_object.get("X-Ms-Client-Principal-Name")
user_object["auth_provider"] = raw_user_object.get("X-Ms-Client-Principal-Idp")
user_object["auth_token"] = raw_user_object.get("X-Ms-Token-Aad-Id-Token")
user_object["client_principal_b64"] = raw_user_object.get("X-Ms-Client-Principal")
user_object["aad_id_token"] = raw_user_object.get("X-Ms-Token-Aad-Id-Token")

return user_object


def get_tenantid(client_principal_b64):
logger = logging.getLogger(__name__)
tenant_id = ''
tenant_id = ""
if client_principal_b64:
try:
# Decode the base64 header to get the JSON string
decoded_bytes = base64.b64decode(client_principal_b64)
decoded_string = decoded_bytes.decode('utf-8')
decoded_string = decoded_bytes.decode("utf-8")
# Convert the JSON string1into a Python dictionary
user_info = json.loads(decoded_string)
# Extract the tenant ID
tenant_id = user_info.get('tid') # 'tid' typically holds the tenant ID
tenant_id = user_info.get("tid") # 'tid' typically holds the tenant ID
except Exception as ex:
logger.exception(ex)
logger.exception(ex)
return tenant_id
138 changes: 76 additions & 62 deletions code/backend/batch/utilities/chat_history/cosmosdb.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,65 @@
import uuid
from datetime import datetime
from azure.cosmos.aio import CosmosClient
from azure.cosmos import exceptions

class CosmosConversationClient():

def __init__(self, cosmosdb_endpoint: str, credential: any, database_name: str, container_name: str, enable_message_feedback: bool = False):
class CosmosConversationClient:

def __init__(
self,
cosmosdb_endpoint: str,
credential: any,
database_name: str,
container_name: str,
enable_message_feedback: bool = False,
):
self.cosmosdb_endpoint = cosmosdb_endpoint
self.credential = credential
self.database_name = database_name
self.container_name = container_name
self.enable_message_feedback = enable_message_feedback
try:
self.cosmosdb_client = CosmosClient(self.cosmosdb_endpoint, credential=credential)
self.cosmosdb_client = CosmosClient(
self.cosmosdb_endpoint, credential=credential
)
except exceptions.CosmosHttpResponseError as e:
if e.status_code == 401:
raise ValueError("Invalid credentials") from e
else:
raise ValueError("Invalid CosmosDB endpoint") from e

try:
self.database_client = self.cosmosdb_client.get_database_client(database_name)
self.database_client = self.cosmosdb_client.get_database_client(
database_name
)
except exceptions.CosmosResourceNotFoundError:
raise ValueError("Invalid CosmosDB database name")

try:
self.container_client = self.database_client.get_container_client(container_name)
self.container_client = self.database_client.get_container_client(
container_name
)
except exceptions.CosmosResourceNotFoundError:
raise ValueError("Invalid CosmosDB container name")


async def ensure(self):
if not self.cosmosdb_client or not self.database_client or not self.container_client:
if (
not self.cosmosdb_client
or not self.database_client
or not self.container_client
):
return False, "CosmosDB client not initialized correctly"
try:
database_info = await self.database_client.read()
except:
return False, f"CosmosDB database {self.database_name} on account {self.cosmosdb_endpoint} not found"
await self.database_client.read()
except Exception:
return (
False,
f"CosmosDB database {self.database_name} on account {self.cosmosdb_endpoint} not found",
)

try:
container_info = await self.container_client.read()
except:
await self.container_client.read()
except Exception:
return False, f"CosmosDB container {self.container_name} not found"

return True, "CosmosDB client initialized successfully"
Expand All @@ -55,7 +74,7 @@ async def create_conversation(self, user_id, conversation_id, title=""):
"title": title,
"conversationId": conversation_id,
}
## TODO: add some error handling based on the output of the upsert_item call
# TODO: add some error handling based on the output of the upsert_item call
resp = await self.container_client.upsert_item(conversation)
if resp:
return resp
Expand All @@ -70,114 +89,109 @@ async def upsert_conversation(self, conversation):
return False

async def delete_conversation(self, user_id, conversation_id):
conversation = await self.container_client.read_item(item=conversation_id, partition_key=user_id)
conversation = await self.container_client.read_item(
item=conversation_id, partition_key=user_id
)
if conversation:
resp = await self.container_client.delete_item(item=conversation_id, partition_key=user_id)
resp = await self.container_client.delete_item(
item=conversation_id, partition_key=user_id
)
return resp
else:
return True


async def delete_messages(self, conversation_id, user_id):
## get a list of all the messages in the conversation
# get a list of all the messages in the conversation
messages = await self.get_messages(user_id, conversation_id)
response_list = []
if messages:
for message in messages:
resp = await self.container_client.delete_item(item=message['id'], partition_key=user_id)
resp = await self.container_client.delete_item(
item=message["id"], partition_key=user_id
)
response_list.append(resp)
return response_list


async def get_conversations(self, user_id, limit, sort_order = 'DESC', offset = 0):
parameters = [
{
'name': '@userId',
'value': user_id
}
]
async def get_conversations(self, user_id, limit, sort_order="DESC", offset=0):
parameters = [{"name": "@userId", "value": user_id}]
query = f"SELECT * FROM c where c.userId = @userId and c.type='conversation' order by c.updatedAt {sort_order}"
if limit is not None:
query += f" offset {offset} limit {limit}"

conversations = []
async for item in self.container_client.query_items(query=query, parameters=parameters):
async for item in self.container_client.query_items(
query=query, parameters=parameters
):
conversations.append(item)

return conversations

async def get_conversation(self, user_id, conversation_id):
parameters = [
{
'name': '@conversationId',
'value': conversation_id
},
{
'name': '@userId',
'value': user_id
}
{"name": "@conversationId", "value": conversation_id},
{"name": "@userId", "value": user_id},
]
query = f"SELECT * FROM c where c.id = @conversationId and c.type='conversation' and c.userId = @userId"
query = "SELECT * FROM c where c.id = @conversationId and c.type='conversation' and c.userId = @userId"
conversations = []
async for item in self.container_client.query_items(query=query, parameters=parameters):
async for item in self.container_client.query_items(
query=query, parameters=parameters
):
conversations.append(item)

## if no conversations are found, return None
# if no conversations are found, return None
if len(conversations) == 0:
return None
else:
return conversations[0]

async def create_message(self, uuid, conversation_id, user_id, input_message: dict):
message = {
'id': uuid,
'type': 'message',
'userId' : user_id,
'createdAt': datetime.utcnow().isoformat(),
'updatedAt': datetime.utcnow().isoformat(),
'conversationId' : conversation_id,
'role': input_message['role'],
'content': input_message['content']
"id": uuid,
"type": "message",
"userId": user_id,
"createdAt": datetime.utcnow().isoformat(),
"updatedAt": datetime.utcnow().isoformat(),
"conversationId": conversation_id,
"role": input_message["role"],
"content": input_message["content"],
}

if self.enable_message_feedback:
message['feedback'] = ''
message["feedback"] = ""

resp = await self.container_client.upsert_item(message)
if resp:
## update the parent conversations's updatedAt field with the current message's createdAt datetime value
# update the parent conversations's updatedAt field with the current message's createdAt datetime value
conversation = await self.get_conversation(user_id, conversation_id)
if not conversation:
return "Conversation not found"
conversation['updatedAt'] = message['createdAt']
conversation["updatedAt"] = message["createdAt"]
await self.upsert_conversation(conversation)
return resp
else:
return False

async def update_message_feedback(self, user_id, message_id, feedback):
message = await self.container_client.read_item(item=message_id, partition_key=user_id)
message = await self.container_client.read_item(
item=message_id, partition_key=user_id
)
if message:
message['feedback'] = feedback
message["feedback"] = feedback
resp = await self.container_client.upsert_item(message)
return resp
else:
return False

async def get_messages(self, user_id, conversation_id):
parameters = [
{
'name': '@conversationId',
'value': conversation_id
},
{
'name': '@userId',
'value': user_id
}
{"name": "@conversationId", "value": conversation_id},
{"name": "@userId", "value": user_id},
]
query = f"SELECT * FROM c WHERE c.conversationId = @conversationId AND c.type='message' AND c.userId = @userId ORDER BY c.timestamp ASC"
query = "SELECT * FROM c WHERE c.conversationId = @conversationId AND c.type='message' AND c.userId = @userId ORDER BY c.timestamp ASC"
messages = []
async for item in self.container_client.query_items(query=query, parameters=parameters):
async for item in self.container_client.query_items(
query=query, parameters=parameters
):
messages.append(item)

return messages
Loading

0 comments on commit 28e628a

Please sign in to comment.