Skip to content

Commit

Permalink
fix: Encryption key error handling fix (#567)
Browse files Browse the repository at this point in the history
* Encryption key error handling fix, platform service error handling and logging improvement

* Minor typehint fix

* Fix for subscription expiry behaviour

* Fixed merge related conflicts
  • Loading branch information
chandrasekharan-zipstack authored Aug 20, 2024
1 parent ae45e2a commit b3e3f4b
Show file tree
Hide file tree
Showing 15 changed files with 206 additions and 64 deletions.
8 changes: 0 additions & 8 deletions backend/adapter_processor/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,6 @@ class InValidAdapterId(APIException):
default_detail = "Adapter ID is not Valid."


class InvalidEncryptionKey(APIException):
status_code = 403
default_detail = (
"Platform encryption key for storing adapter credentials has changed! "
"Please inform the organization admin to contact support."
)


class InternalServiceError(APIException):
status_code = 500
default_detail = "Internal Service error"
Expand Down
23 changes: 13 additions & 10 deletions backend/adapter_processor/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
from typing import Any

from account.models import User
from cryptography.fernet import Fernet
from cryptography.fernet import Fernet, InvalidToken
from django.conf import settings
from django.db import models
from django.db.models import QuerySet
from unstract.sdk.adapters.adapterkit import Adapterkit
from unstract.sdk.adapters.enums import AdapterTypes
from unstract.sdk.adapters.exceptions import AdapterError
from utils.exceptions import InvalidEncryptionKey
from utils.models.base_model import BaseModel

ADAPTER_NAME_SIZE = 128
Expand Down Expand Up @@ -135,22 +136,24 @@ def create_adapter(self) -> None:

self.save()

def get_adapter_meta_data(self) -> Any:
encryption_secret: str = settings.ENCRYPTION_KEY
f: Fernet = Fernet(encryption_secret.encode("utf-8"))
@property
def metadata(self) -> Any:
try:
encryption_secret: str = settings.ENCRYPTION_KEY
f: Fernet = Fernet(encryption_secret.encode("utf-8"))

adapter_metadata = json.loads(
f.decrypt(bytes(self.adapter_metadata_b).decode("utf-8"))
)
adapter_metadata = json.loads(
f.decrypt(bytes(self.adapter_metadata_b).decode("utf-8"))
)
except InvalidToken:
raise InvalidEncryptionKey(entity=InvalidEncryptionKey.Entity.ADAPTER)
return adapter_metadata

def get_context_window_size(self) -> int:

adapter_metadata = self.get_adapter_meta_data()
# Get the adapter_instance
adapter_class = Adapterkit().get_adapter_class_by_adapter_id(self.adapter_id)
try:
adapter_instance = adapter_class(adapter_metadata)
adapter_instance = adapter_class(self.metadata)
return adapter_instance.get_context_window_size()
except AdapterError as e:
logger.warning(f"Unable to retrieve context window size - {e}")
Expand Down
12 changes: 3 additions & 9 deletions backend/adapter_processor/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
from account.serializer import UserSerializer
from adapter_processor.adapter_processor import AdapterProcessor
from adapter_processor.constants import AdapterKeys
from adapter_processor.exceptions import InvalidEncryptionKey
from cryptography.fernet import Fernet, InvalidToken
from cryptography.fernet import Fernet
from django.conf import settings
from rest_framework import serializers
from rest_framework.serializers import ModelSerializer
Expand Down Expand Up @@ -62,11 +61,7 @@ def to_representation(self, instance: AdapterInstance) -> dict[str, str]:
rep: dict[str, str] = super().to_representation(instance)

rep.pop(AdapterKeys.ADAPTER_METADATA_B)

try:
adapter_metadata = instance.get_adapter_meta_data()
except InvalidToken:
raise InvalidEncryptionKey
adapter_metadata = instance.metadata
rep[AdapterKeys.ADAPTER_METADATA] = adapter_metadata
# Retrieve context window if adapter is a LLM
# For other adapter types, context_window is not relevant.
Expand Down Expand Up @@ -124,8 +119,7 @@ def to_representation(self, instance: AdapterInstance) -> dict[str, str]:
rep[common.ICON] = AdapterProcessor.get_adapter_data_with_key(
instance.adapter_id, common.ICON
)
adapter_metadata = instance.get_adapter_meta_data()
model = adapter_metadata.get("model")
model = instance.metadata.get("model")
if model:
rep["model"] = model

Expand Down
16 changes: 10 additions & 6 deletions backend/connector/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
from connector_auth.models import ConnectorAuth
from connector_processor.connector_processor import ConnectorProcessor
from connector_processor.constants import ConnectorKeys
from cryptography.fernet import Fernet
from cryptography.fernet import Fernet, InvalidToken
from django.conf import settings
from django.db import models
from project.models import Project
from utils.exceptions import InvalidEncryptionKey
from utils.models.base_model import BaseModel
from workflow_manager.workflow.models import Workflow

Expand Down Expand Up @@ -111,11 +112,14 @@ def __str__(self) -> str:

@property
def metadata(self) -> Any:
encryption_secret: str = settings.ENCRYPTION_KEY
cipher_suite: Fernet = Fernet(encryption_secret.encode("utf-8"))
decrypted_value = cipher_suite.decrypt(
bytes(self.connector_metadata_b).decode("utf-8")
)
try:
encryption_secret: str = settings.ENCRYPTION_KEY
cipher_suite: Fernet = Fernet(encryption_secret.encode("utf-8"))
decrypted_value = cipher_suite.decrypt(
bytes(self.connector_metadata_b).decode("utf-8")
)
except InvalidToken:
raise InvalidEncryptionKey(entity=InvalidEncryptionKey.Entity.CONNECTOR)
return json.loads(decrypted_value)

class Meta:
Expand Down
7 changes: 1 addition & 6 deletions backend/connector/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,8 @@ def to_representation(self, instance: ConnectorInstance) -> dict[str, str]:
rep[ConnectorKeys.ICON] = ConnectorProcessor.get_connector_data_with_key(
instance.connector_id, ConnectorKeys.ICON
)
encryption_secret: str = settings.ENCRYPTION_KEY
f: Fernet = Fernet(encryption_secret.encode("utf-8"))

rep.pop(CIKey.CONNECTOR_METADATA_B)
if instance.connector_metadata_b:
adapter_metadata = json.loads(
f.decrypt(bytes(instance.connector_metadata_b).decode("utf-8"))
)
rep[CIKey.CONNECTOR_METADATA] = adapter_metadata
rep[CIKey.CONNECTOR_METADATA] = instance.metadata
return rep
Original file line number Diff line number Diff line change
Expand Up @@ -1072,7 +1072,7 @@ def _fetch_single_pass_response(
if answer["status"] == "ERROR":
error_message = answer.get("error", None)
raise AnswerFetchError(
f"Error while fetching response for prompt. {error_message}"
f"Error while fetching response for prompt(s). {error_message}"
)
output_response = json.loads(answer["structure_output"])
return output_response
Expand Down
34 changes: 34 additions & 0 deletions backend/utils/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from enum import Enum
from typing import Optional

from rest_framework.exceptions import APIException


class InvalidEncryptionKey(APIException):
status_code = 403
default_detail = (
"Platform encryption key for storing sensitive credentials has changed! "
"All encrypted entities are inaccessible. Please inform the "
"platform admin immediately."
)

class Entity(Enum):
ADAPTER = "adapter"
CONNECTOR = "connector"

def __init__(
self,
entity: Optional[Entity] = None,
detail: Optional[str] = None,
code: Optional[str] = None,
) -> None:
if entity == self.Entity.ADAPTER:
detail = self.default_detail.replace("sensitive", "adapter").replace(
"encrypted entities", "adapters"
)
elif entity == self.Entity.CONNECTOR:
detail = self.default_detail.replace("sensitive", "connector").replace(
"encrypted entities", "connectors"
)

super().__init__(detail, code)
2 changes: 1 addition & 1 deletion frontend/src/hooks/useExceptionHandler.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ const useExceptionHandler = () => {
return {
title: title,
type: "error",
content: errors?.[0]?.detail ? errors[0].detail : errMessage,
content: errors,
duration: duration,
};
case "client_error":
Expand Down
4 changes: 4 additions & 0 deletions frontend/src/hooks/useSessionValid.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { useExceptionHandler } from "../hooks/useExceptionHandler.jsx";
import { useSessionStore } from "../store/session-store";
import { useUserSession } from "./useUserSession.js";
import { listFlags } from "../helpers/FeatureFlagsData.js";
import { useAlertStore } from "../store/alert-store";

let getTrialDetails;
let isPlatformAdmin;
Expand All @@ -30,6 +31,7 @@ try {
function useSessionValid() {
const setSessionDetails = useSessionStore((state) => state.setSessionDetails);
const handleException = useExceptionHandler();
const { setAlertDetails } = useAlertStore();
const navigate = useNavigate();
const userSession = useUserSession();

Expand Down Expand Up @@ -134,6 +136,7 @@ function useSessionValid() {
setSessionDetails(getSessionData(userAndOrgDetails));
} catch (err) {
// TODO: Throw popup error message
// REVIEW: Add condition to check for trial period status
if (err.response?.status === 402) {
handleException(err);
}
Expand All @@ -145,6 +148,7 @@ function useSessionValid() {
window.location.href = `/error?code=${code}&domain=${domainName}`;
// May be need a logout button there or auto logout
}
setAlertDetails(handleException(err));
}
};
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Any
import json
import traceback
from typing import Union

from flask import Blueprint, Response, jsonify
from unstract.platform_service.exceptions import CustomException
from flask import Blueprint, Response, jsonify, request
from unstract.platform_service.exceptions import APIError, ErrorResponse
from werkzeug.exceptions import HTTPException

from .health import health_bp
from .platform import platform_bp
Expand All @@ -11,8 +14,63 @@
api.register_blueprint(health_bp)


@api.errorhandler(CustomException)
def handle_custom_exception(error: Any) -> tuple[Response, Any]:
response = jsonify({"error": error.message})
response.status_code = error.code # You can customize the HTTP status code
return jsonify(response), error.code
def log_exceptions(e: HTTPException) -> None:
"""Helper method to log exceptions.
Args:
e (HTTPException): Exception to log
"""
code = 500
if hasattr(e, "code"):
code = e.code or code

if code >= 500:
message = "{method} {url} {status}\n\n{error}\n\n````{tb}````".format(
method=request.method,
url=request.url,
status=code,
error=str(e),
tb=traceback.format_exc(),
)
else:
message = "{method} {url} {status} {error}".format(
method=request.method,
url=request.url,
status=code,
error=str(e),
)

# Avoids circular import errors while initializing app context
from flask import current_app as app

app.logger.error(message)


@api.errorhandler(HTTPException)
def handle_http_exception(e: HTTPException) -> Union[Response, tuple[Response, int]]:
"""Return JSON instead of HTML for HTTP errors."""
log_exceptions(e)
if isinstance(e, APIError):
return jsonify(e.to_dict()), e.code
else:
response = e.get_response()
response.data = json.dumps(
ErrorResponse(error=e.description, name=e.name, code=e.code)
)
response.content_type = "application/json"
return response


@api.errorhandler(Exception)
def handle_uncaught_exception(e: Exception) -> Union[Response, tuple[Response, int]]:
"""Handler for uncaught exceptions.
Args:
e (Exception): Any uncaught exception
"""
# pass through HTTP errors
if isinstance(e, HTTPException):
return handle_http_exception(e)

log_exceptions(e)
return handle_http_exception(APIError())
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import json
import uuid
from datetime import datetime
from typing import Any, Optional

import peewee
import redis
from cryptography.fernet import Fernet
from cryptography.fernet import Fernet, InvalidToken
from flask import Blueprint, Request
from flask import current_app as app
from flask import json, jsonify, make_response, request
from flask import jsonify, make_response, request
from unstract.platform_service.constants import DBTable, DBTableV2, FeatureFlag
from unstract.platform_service.env import Env
from unstract.platform_service.exceptions import APIError
from unstract.platform_service.helper.adapter_instance import (
AdapterInstanceRequestHelper,
)
Expand Down Expand Up @@ -452,11 +454,21 @@ def adapter_instance() -> Any:
)

return jsonify(data_dict)
except InvalidToken:
msg = (
"Platform encryption key for storing adapter credentials has "
"changed! All adapters are inaccessible. Please inform "
"the platform admin immediately."
)
app.logger.error(
f"Error while getting db adapter settings for: "
f"{adapter_instance_id}, Error: {msg}"
)
raise APIError(message=msg, code=403)
except Exception as e:
print(e)
app.logger.error(
f"Error while getting db adapter settings for: "
f"{adapter_instance_id} Error: {str(e)}"
f"{adapter_instance_id}, Error: {str(e)}"
)
return "Internal Server Error", 500
return "Method Not Allowed", 405
Expand Down Expand Up @@ -492,10 +504,14 @@ def custom_tool_instance() -> Any:
)
return jsonify(data_dict)
except Exception as e:
print(e)
app.logger.error(
f"Error while getting db adapter settings for: "
f"{prompt_registry_id} Error: {str(e)}"
)
return "Internal Server Error", 500
return "Method Not Allowed", 405


if __name__ == "__main__":
# Start the server
app.run(host="0.0.0.0", port="3001")
Loading

0 comments on commit b3e3f4b

Please sign in to comment.