From b3e3f4bf446d1392693870042986ccd27e8a417e Mon Sep 17 00:00:00 2001 From: Chandrasekharan M <117059509+chandrasekharan-zipstack@users.noreply.github.com> Date: Tue, 20 Aug 2024 10:25:52 +0530 Subject: [PATCH] fix: Encryption key error handling fix (#567) * Encryption key error handling fix, platform service error handling and logging improvement * Minor typehint fix * Fix for subscription expiry behaviour * Fixed merge related conflicts --- backend/adapter_processor/exceptions.py | 8 -- backend/adapter_processor/models.py | 23 +++--- backend/adapter_processor/serializers.py | 12 +-- backend/connector/models.py | 16 ++-- backend/connector/serializers.py | 7 +- .../prompt_studio_helper.py | 2 +- backend/utils/exceptions.py | 34 +++++++++ frontend/src/hooks/useExceptionHandler.jsx | 2 +- frontend/src/hooks/useSessionValid.js | 4 + .../platform_service/controller/__init__.py | 74 +++++++++++++++++-- .../platform_service/controller/platform.py | 26 +++++-- .../unstract/platform_service/exceptions.py | 51 +++++++++++-- .../helper/adapter_instance.py | 4 +- .../platform_service/helper/prompt_studio.py | 4 +- .../src/unstract/prompt_service/helper.py | 3 +- 15 files changed, 206 insertions(+), 64 deletions(-) create mode 100644 backend/utils/exceptions.py diff --git a/backend/adapter_processor/exceptions.py b/backend/adapter_processor/exceptions.py index a0f18febc..876775ca5 100644 --- a/backend/adapter_processor/exceptions.py +++ b/backend/adapter_processor/exceptions.py @@ -19,14 +19,6 @@ class InValidAdapterId(APIException): default_detail = "Adapter ID is not Valid." -class InvalidEncryptionKey(APIException): - status_code = 403 - default_detail = ( - "Platform encryption key for storing adapter credentials has changed! " - "Please inform the organization admin to contact support." - ) - - class InternalServiceError(APIException): status_code = 500 default_detail = "Internal Service error" diff --git a/backend/adapter_processor/models.py b/backend/adapter_processor/models.py index f37198d1b..f856bcf6b 100644 --- a/backend/adapter_processor/models.py +++ b/backend/adapter_processor/models.py @@ -4,13 +4,14 @@ from typing import Any from account.models import User -from cryptography.fernet import Fernet +from cryptography.fernet import Fernet, InvalidToken from django.conf import settings from django.db import models from django.db.models import QuerySet from unstract.sdk.adapters.adapterkit import Adapterkit from unstract.sdk.adapters.enums import AdapterTypes from unstract.sdk.adapters.exceptions import AdapterError +from utils.exceptions import InvalidEncryptionKey from utils.models.base_model import BaseModel ADAPTER_NAME_SIZE = 128 @@ -135,22 +136,24 @@ def create_adapter(self) -> None: self.save() - def get_adapter_meta_data(self) -> Any: - encryption_secret: str = settings.ENCRYPTION_KEY - f: Fernet = Fernet(encryption_secret.encode("utf-8")) + @property + def metadata(self) -> Any: + try: + encryption_secret: str = settings.ENCRYPTION_KEY + f: Fernet = Fernet(encryption_secret.encode("utf-8")) - adapter_metadata = json.loads( - f.decrypt(bytes(self.adapter_metadata_b).decode("utf-8")) - ) + adapter_metadata = json.loads( + f.decrypt(bytes(self.adapter_metadata_b).decode("utf-8")) + ) + except InvalidToken: + raise InvalidEncryptionKey(entity=InvalidEncryptionKey.Entity.ADAPTER) return adapter_metadata def get_context_window_size(self) -> int: - - adapter_metadata = self.get_adapter_meta_data() # Get the adapter_instance adapter_class = Adapterkit().get_adapter_class_by_adapter_id(self.adapter_id) try: - adapter_instance = adapter_class(adapter_metadata) + adapter_instance = adapter_class(self.metadata) return adapter_instance.get_context_window_size() except AdapterError as e: logger.warning(f"Unable to retrieve context window size - {e}") diff --git a/backend/adapter_processor/serializers.py b/backend/adapter_processor/serializers.py index aecb16a76..1e74c19ad 100644 --- a/backend/adapter_processor/serializers.py +++ b/backend/adapter_processor/serializers.py @@ -4,8 +4,7 @@ from account.serializer import UserSerializer from adapter_processor.adapter_processor import AdapterProcessor from adapter_processor.constants import AdapterKeys -from adapter_processor.exceptions import InvalidEncryptionKey -from cryptography.fernet import Fernet, InvalidToken +from cryptography.fernet import Fernet from django.conf import settings from rest_framework import serializers from rest_framework.serializers import ModelSerializer @@ -62,11 +61,7 @@ def to_representation(self, instance: AdapterInstance) -> dict[str, str]: rep: dict[str, str] = super().to_representation(instance) rep.pop(AdapterKeys.ADAPTER_METADATA_B) - - try: - adapter_metadata = instance.get_adapter_meta_data() - except InvalidToken: - raise InvalidEncryptionKey + adapter_metadata = instance.metadata rep[AdapterKeys.ADAPTER_METADATA] = adapter_metadata # Retrieve context window if adapter is a LLM # For other adapter types, context_window is not relevant. @@ -124,8 +119,7 @@ def to_representation(self, instance: AdapterInstance) -> dict[str, str]: rep[common.ICON] = AdapterProcessor.get_adapter_data_with_key( instance.adapter_id, common.ICON ) - adapter_metadata = instance.get_adapter_meta_data() - model = adapter_metadata.get("model") + model = instance.metadata.get("model") if model: rep["model"] = model diff --git a/backend/connector/models.py b/backend/connector/models.py index 55c228e47..958946efe 100644 --- a/backend/connector/models.py +++ b/backend/connector/models.py @@ -7,10 +7,11 @@ from connector_auth.models import ConnectorAuth from connector_processor.connector_processor import ConnectorProcessor from connector_processor.constants import ConnectorKeys -from cryptography.fernet import Fernet +from cryptography.fernet import Fernet, InvalidToken from django.conf import settings from django.db import models from project.models import Project +from utils.exceptions import InvalidEncryptionKey from utils.models.base_model import BaseModel from workflow_manager.workflow.models import Workflow @@ -111,11 +112,14 @@ def __str__(self) -> str: @property def metadata(self) -> Any: - encryption_secret: str = settings.ENCRYPTION_KEY - cipher_suite: Fernet = Fernet(encryption_secret.encode("utf-8")) - decrypted_value = cipher_suite.decrypt( - bytes(self.connector_metadata_b).decode("utf-8") - ) + try: + encryption_secret: str = settings.ENCRYPTION_KEY + cipher_suite: Fernet = Fernet(encryption_secret.encode("utf-8")) + decrypted_value = cipher_suite.decrypt( + bytes(self.connector_metadata_b).decode("utf-8") + ) + except InvalidToken: + raise InvalidEncryptionKey(entity=InvalidEncryptionKey.Entity.CONNECTOR) return json.loads(decrypted_value) class Meta: diff --git a/backend/connector/serializers.py b/backend/connector/serializers.py index 13fea7ae8..c5ce1054f 100644 --- a/backend/connector/serializers.py +++ b/backend/connector/serializers.py @@ -75,13 +75,8 @@ def to_representation(self, instance: ConnectorInstance) -> dict[str, str]: rep[ConnectorKeys.ICON] = ConnectorProcessor.get_connector_data_with_key( instance.connector_id, ConnectorKeys.ICON ) - encryption_secret: str = settings.ENCRYPTION_KEY - f: Fernet = Fernet(encryption_secret.encode("utf-8")) rep.pop(CIKey.CONNECTOR_METADATA_B) if instance.connector_metadata_b: - adapter_metadata = json.loads( - f.decrypt(bytes(instance.connector_metadata_b).decode("utf-8")) - ) - rep[CIKey.CONNECTOR_METADATA] = adapter_metadata + rep[CIKey.CONNECTOR_METADATA] = instance.metadata return rep diff --git a/backend/prompt_studio/prompt_studio_core/prompt_studio_helper.py b/backend/prompt_studio/prompt_studio_core/prompt_studio_helper.py index 4ed32ca04..f17c4fa37 100644 --- a/backend/prompt_studio/prompt_studio_core/prompt_studio_helper.py +++ b/backend/prompt_studio/prompt_studio_core/prompt_studio_helper.py @@ -1072,7 +1072,7 @@ def _fetch_single_pass_response( if answer["status"] == "ERROR": error_message = answer.get("error", None) raise AnswerFetchError( - f"Error while fetching response for prompt. {error_message}" + f"Error while fetching response for prompt(s). {error_message}" ) output_response = json.loads(answer["structure_output"]) return output_response diff --git a/backend/utils/exceptions.py b/backend/utils/exceptions.py new file mode 100644 index 000000000..b2cd4fd74 --- /dev/null +++ b/backend/utils/exceptions.py @@ -0,0 +1,34 @@ +from enum import Enum +from typing import Optional + +from rest_framework.exceptions import APIException + + +class InvalidEncryptionKey(APIException): + status_code = 403 + default_detail = ( + "Platform encryption key for storing sensitive credentials has changed! " + "All encrypted entities are inaccessible. Please inform the " + "platform admin immediately." + ) + + class Entity(Enum): + ADAPTER = "adapter" + CONNECTOR = "connector" + + def __init__( + self, + entity: Optional[Entity] = None, + detail: Optional[str] = None, + code: Optional[str] = None, + ) -> None: + if entity == self.Entity.ADAPTER: + detail = self.default_detail.replace("sensitive", "adapter").replace( + "encrypted entities", "adapters" + ) + elif entity == self.Entity.CONNECTOR: + detail = self.default_detail.replace("sensitive", "connector").replace( + "encrypted entities", "connectors" + ) + + super().__init__(detail, code) diff --git a/frontend/src/hooks/useExceptionHandler.jsx b/frontend/src/hooks/useExceptionHandler.jsx index 4f7835e54..e11f8f5e8 100644 --- a/frontend/src/hooks/useExceptionHandler.jsx +++ b/frontend/src/hooks/useExceptionHandler.jsx @@ -40,7 +40,7 @@ const useExceptionHandler = () => { return { title: title, type: "error", - content: errors?.[0]?.detail ? errors[0].detail : errMessage, + content: errors, duration: duration, }; case "client_error": diff --git a/frontend/src/hooks/useSessionValid.js b/frontend/src/hooks/useSessionValid.js index fe34e91f2..30264b47c 100644 --- a/frontend/src/hooks/useSessionValid.js +++ b/frontend/src/hooks/useSessionValid.js @@ -7,6 +7,7 @@ import { useExceptionHandler } from "../hooks/useExceptionHandler.jsx"; import { useSessionStore } from "../store/session-store"; import { useUserSession } from "./useUserSession.js"; import { listFlags } from "../helpers/FeatureFlagsData.js"; +import { useAlertStore } from "../store/alert-store"; let getTrialDetails; let isPlatformAdmin; @@ -30,6 +31,7 @@ try { function useSessionValid() { const setSessionDetails = useSessionStore((state) => state.setSessionDetails); const handleException = useExceptionHandler(); + const { setAlertDetails } = useAlertStore(); const navigate = useNavigate(); const userSession = useUserSession(); @@ -134,6 +136,7 @@ function useSessionValid() { setSessionDetails(getSessionData(userAndOrgDetails)); } catch (err) { // TODO: Throw popup error message + // REVIEW: Add condition to check for trial period status if (err.response?.status === 402) { handleException(err); } @@ -145,6 +148,7 @@ function useSessionValid() { window.location.href = `/error?code=${code}&domain=${domainName}`; // May be need a logout button there or auto logout } + setAlertDetails(handleException(err)); } }; } diff --git a/platform-service/src/unstract/platform_service/controller/__init__.py b/platform-service/src/unstract/platform_service/controller/__init__.py index adfbe6331..95e7074ff 100644 --- a/platform-service/src/unstract/platform_service/controller/__init__.py +++ b/platform-service/src/unstract/platform_service/controller/__init__.py @@ -1,7 +1,10 @@ -from typing import Any +import json +import traceback +from typing import Union -from flask import Blueprint, Response, jsonify -from unstract.platform_service.exceptions import CustomException +from flask import Blueprint, Response, jsonify, request +from unstract.platform_service.exceptions import APIError, ErrorResponse +from werkzeug.exceptions import HTTPException from .health import health_bp from .platform import platform_bp @@ -11,8 +14,63 @@ api.register_blueprint(health_bp) -@api.errorhandler(CustomException) -def handle_custom_exception(error: Any) -> tuple[Response, Any]: - response = jsonify({"error": error.message}) - response.status_code = error.code # You can customize the HTTP status code - return jsonify(response), error.code +def log_exceptions(e: HTTPException) -> None: + """Helper method to log exceptions. + + Args: + e (HTTPException): Exception to log + """ + code = 500 + if hasattr(e, "code"): + code = e.code or code + + if code >= 500: + message = "{method} {url} {status}\n\n{error}\n\n````{tb}````".format( + method=request.method, + url=request.url, + status=code, + error=str(e), + tb=traceback.format_exc(), + ) + else: + message = "{method} {url} {status} {error}".format( + method=request.method, + url=request.url, + status=code, + error=str(e), + ) + + # Avoids circular import errors while initializing app context + from flask import current_app as app + + app.logger.error(message) + + +@api.errorhandler(HTTPException) +def handle_http_exception(e: HTTPException) -> Union[Response, tuple[Response, int]]: + """Return JSON instead of HTML for HTTP errors.""" + log_exceptions(e) + if isinstance(e, APIError): + return jsonify(e.to_dict()), e.code + else: + response = e.get_response() + response.data = json.dumps( + ErrorResponse(error=e.description, name=e.name, code=e.code) + ) + response.content_type = "application/json" + return response + + +@api.errorhandler(Exception) +def handle_uncaught_exception(e: Exception) -> Union[Response, tuple[Response, int]]: + """Handler for uncaught exceptions. + + Args: + e (Exception): Any uncaught exception + """ + # pass through HTTP errors + if isinstance(e, HTTPException): + return handle_http_exception(e) + + log_exceptions(e) + return handle_http_exception(APIError()) diff --git a/platform-service/src/unstract/platform_service/controller/platform.py b/platform-service/src/unstract/platform_service/controller/platform.py index 5cd3a4f4d..dbed388da 100644 --- a/platform-service/src/unstract/platform_service/controller/platform.py +++ b/platform-service/src/unstract/platform_service/controller/platform.py @@ -1,15 +1,17 @@ +import json import uuid from datetime import datetime from typing import Any, Optional import peewee import redis -from cryptography.fernet import Fernet +from cryptography.fernet import Fernet, InvalidToken from flask import Blueprint, Request from flask import current_app as app -from flask import json, jsonify, make_response, request +from flask import jsonify, make_response, request from unstract.platform_service.constants import DBTable, DBTableV2, FeatureFlag from unstract.platform_service.env import Env +from unstract.platform_service.exceptions import APIError from unstract.platform_service.helper.adapter_instance import ( AdapterInstanceRequestHelper, ) @@ -452,11 +454,21 @@ def adapter_instance() -> Any: ) return jsonify(data_dict) + except InvalidToken: + msg = ( + "Platform encryption key for storing adapter credentials has " + "changed! All adapters are inaccessible. Please inform " + "the platform admin immediately." + ) + app.logger.error( + f"Error while getting db adapter settings for: " + f"{adapter_instance_id}, Error: {msg}" + ) + raise APIError(message=msg, code=403) except Exception as e: - print(e) app.logger.error( f"Error while getting db adapter settings for: " - f"{adapter_instance_id} Error: {str(e)}" + f"{adapter_instance_id}, Error: {str(e)}" ) return "Internal Server Error", 500 return "Method Not Allowed", 405 @@ -492,10 +504,14 @@ def custom_tool_instance() -> Any: ) return jsonify(data_dict) except Exception as e: - print(e) app.logger.error( f"Error while getting db adapter settings for: " f"{prompt_registry_id} Error: {str(e)}" ) return "Internal Server Error", 500 return "Method Not Allowed", 405 + + +if __name__ == "__main__": + # Start the server + app.run(host="0.0.0.0", port="3001") diff --git a/platform-service/src/unstract/platform_service/exceptions.py b/platform-service/src/unstract/platform_service/exceptions.py index c9a0c89c7..97832719c 100644 --- a/platform-service/src/unstract/platform_service/exceptions.py +++ b/platform-service/src/unstract/platform_service/exceptions.py @@ -1,5 +1,46 @@ -class CustomException(Exception): - def __init__(self, message: str = "An error occurred", code: int = 500) -> None: - self.message = message - self.code = code - super().__init__(self.message) +from dataclasses import asdict, dataclass +from typing import Any, Optional + +from werkzeug.exceptions import HTTPException + +DEFAULT_ERR_MESSAGE = "Something went wrong" + + +@dataclass +class ErrorResponse: + """Represents error response from platform service.""" + + error: str = DEFAULT_ERR_MESSAGE + name: str = "PlatformServiceError" + code: int = 500 + payload: Optional[Any] = None + + +class APIError(HTTPException): + code = 500 + message = DEFAULT_ERR_MESSAGE + + def __init__( + self, + message: Optional[str] = None, + code: Optional[int] = None, + payload: Any = None, + ): + if message: + self.message = message + if code: + self.code = code + self.payload = payload + super().__init__(description=message) + + def to_dict(self) -> dict[str, Any]: + err = ErrorResponse( + error=self.message, + code=self.code, + payload=self.payload, + name=self.__class__.__name__, + ) + return asdict(err) + + def __str__(self) -> str: + return str(self.message) diff --git a/platform-service/src/unstract/platform_service/helper/adapter_instance.py b/platform-service/src/unstract/platform_service/helper/adapter_instance.py index 3a4c6aa25..b8bf5a418 100644 --- a/platform-service/src/unstract/platform_service/helper/adapter_instance.py +++ b/platform-service/src/unstract/platform_service/helper/adapter_instance.py @@ -2,7 +2,7 @@ import peewee from unstract.platform_service.constants import DBTableV2, FeatureFlag -from unstract.platform_service.exceptions import CustomException +from unstract.platform_service.exceptions import APIError from unstract.flags.feature_flag import check_feature_flag_status @@ -41,7 +41,7 @@ def get_adapter_instance_from_db( cursor = db_instance.execute_sql(query) result_row = cursor.fetchone() if not result_row: - raise CustomException(message="Adapter not found", code=404) + raise APIError(message="Adapter not found", code=404) columns = [desc[0] for desc in cursor.description] data_dict: dict[str, Any] = dict(zip(columns, result_row)) cursor.close() diff --git a/platform-service/src/unstract/platform_service/helper/prompt_studio.py b/platform-service/src/unstract/platform_service/helper/prompt_studio.py index 1a05f1468..5de4b2cae 100644 --- a/platform-service/src/unstract/platform_service/helper/prompt_studio.py +++ b/platform-service/src/unstract/platform_service/helper/prompt_studio.py @@ -2,7 +2,7 @@ import peewee from unstract.platform_service.constants import DBTableV2, FeatureFlag -from unstract.platform_service.exceptions import CustomException +from unstract.platform_service.exceptions import APIError from unstract.flags.feature_flag import check_feature_flag_status @@ -41,7 +41,7 @@ def get_prompt_instance_from_db( cursor = db_instance.execute_sql(query) result_row = cursor.fetchone() if not result_row: - raise CustomException(message="Custom Tool not found", code=404) + raise APIError(message="Custom Tool not found", code=404) columns = [desc[0] for desc in cursor.description] data_dict: dict[str, Any] = dict(zip(columns, result_row)) cursor.close() diff --git a/prompt-service/src/unstract/prompt_service/helper.py b/prompt-service/src/unstract/prompt_service/helper.py index 764502e29..82385c791 100644 --- a/prompt-service/src/unstract/prompt_service/helper.py +++ b/prompt-service/src/unstract/prompt_service/helper.py @@ -12,6 +12,7 @@ from unstract.prompt_service.constants import PromptServiceContants as PSKeys from unstract.prompt_service.exceptions import APIError, RateLimitError from unstract.sdk.exceptions import RateLimitError as SdkRateLimitError +from unstract.sdk.exceptions import SdkError from unstract.sdk.llm import LLM load_dotenv() @@ -292,7 +293,7 @@ def run_completion( # TODO: Catch and handle specific exception here except SdkRateLimitError as e: raise RateLimitError(f"Rate limit error. {str(e)}") from e - except Exception as e: + except SdkError as e: logger.error(f"Error fetching response for prompt: {e}.") # TODO: Publish this error as a FE update raise APIError(str(e)) from e