diff --git a/libs/community/langchain_community/chains/pebblo_retrieval/base.py b/libs/community/langchain_community/chains/pebblo_retrieval/base.py index 2d4b550f1a999..8c8b30644a60f 100644 --- a/libs/community/langchain_community/chains/pebblo_retrieval/base.py +++ b/libs/community/langchain_community/chains/pebblo_retrieval/base.py @@ -22,7 +22,7 @@ from langchain_community.chains.pebblo_retrieval.enforcement_filters import ( SUPPORTED_VECTORSTORES, - set_enforcement_filters, + update_enforcement_filters, ) from langchain_community.chains.pebblo_retrieval.models import ( App, @@ -102,7 +102,10 @@ def _call( _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() question = inputs[self.input_key] auth_context = inputs.get(self.auth_context_key) - semantic_context = inputs.get(self.semantic_context_key) + is_privileged_user = self.pb_client.is_privileged_user(auth_context) + semantic_context = self.determine_semantic_context( + is_privileged_user, auth_context, inputs + ) _, prompt_entities = self.pb_client.check_prompt_validity(question) accepts_run_manager = ( @@ -110,10 +113,16 @@ def _call( ) if accepts_run_manager: docs = self._get_docs( - question, auth_context, semantic_context, run_manager=_run_manager + question, + auth_context, + semantic_context, + is_privileged_user, + run_manager=_run_manager, ) else: - docs = self._get_docs(question, auth_context, semantic_context) # type: ignore[call-arg] + docs = self._get_docs( + question, auth_context, semantic_context, is_privileged_user + ) # type: ignore[call-arg] answer = self.combine_documents_chain.run( input_documents=docs, question=question, callbacks=_run_manager.get_child() ) @@ -155,7 +164,10 @@ async def _acall( _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() question = inputs[self.input_key] auth_context = inputs.get(self.auth_context_key) - semantic_context = inputs.get(self.semantic_context_key) + is_privileged_user = self.pb_client.is_privileged_user(auth_context) + semantic_context = self.determine_semantic_context( + is_privileged_user, auth_context, inputs + ) accepts_run_manager = ( "run_manager" in inspect.signature(self._aget_docs).parameters ) @@ -164,10 +176,16 @@ async def _acall( if accepts_run_manager: docs = await self._aget_docs( - question, auth_context, semantic_context, run_manager=_run_manager + question, + auth_context, + semantic_context, + is_privileged_user, + run_manager=_run_manager, ) else: - docs = await self._aget_docs(question, auth_context, semantic_context) # type: ignore[call-arg] + docs = await self._aget_docs( + question, auth_context, semantic_context, is_privileged_user + ) # type: ignore[call-arg] answer = await self.combine_documents_chain.arun( input_documents=docs, question=question, callbacks=_run_manager.get_child() ) @@ -254,6 +272,7 @@ def from_chain_type( api_key=api_key, classifier_location=classifier_location, classifier_url=classifier_url, + app_name=app_name, ) # send app discovery request pb_client.send_app_discover(app) @@ -289,11 +308,14 @@ def _get_docs( question: str, auth_context: Optional[AuthContext], semantic_context: Optional[SemanticContext], + is_privileged_user: bool = False, *, run_manager: CallbackManagerForChainRun, ) -> List[Document]: """Get docs.""" - set_enforcement_filters(self.retriever, auth_context, semantic_context) + update_enforcement_filters( + self.retriever, auth_context, semantic_context, is_privileged_user + ) return self.retriever.get_relevant_documents( question, callbacks=run_manager.get_child() ) @@ -303,15 +325,43 @@ async def _aget_docs( question: str, auth_context: Optional[AuthContext], semantic_context: Optional[SemanticContext], + is_privileged_user: bool = False, *, run_manager: AsyncCallbackManagerForChainRun, ) -> List[Document]: """Get docs.""" - set_enforcement_filters(self.retriever, auth_context, semantic_context) + update_enforcement_filters( + self.retriever, auth_context, semantic_context, is_privileged_user + ) return await self.retriever.aget_relevant_documents( question, callbacks=run_manager.get_child() ) + def determine_semantic_context( + self, + is_privileged_user: bool, + auth_context: Optional[AuthContext], + inputs: Dict[str, Any], + ) -> Optional[SemanticContext]: + """ + Determine semantic context based on the auth_context or inputs. + + Args: + is_privileged_user (bool): If the user is a privileged user. + auth_context (Optional[AuthContext]): Authentication context. + inputs (Dict[str, Any]): Input dictionary containing various parameters. + + Returns: + Optional[SemanticContext]: Resolved semantic context. + """ + semantic_context = None + if not is_privileged_user: + # Get semantic context from policy if present otherwise from inputs + semantic_context = self.pb_client.get_semantic_context( + auth_context + ) or inputs.get(self.semantic_context_key) + return semantic_context + @staticmethod def _get_app_details( # type: ignore app_name: str, owner: str, description: str, llm: BaseLanguageModel, **kwargs diff --git a/libs/community/langchain_community/chains/pebblo_retrieval/enforcement_filters.py b/libs/community/langchain_community/chains/pebblo_retrieval/enforcement_filters.py index 570cbdfa783f8..df58543250065 100644 --- a/libs/community/langchain_community/chains/pebblo_retrieval/enforcement_filters.py +++ b/libs/community/langchain_community/chains/pebblo_retrieval/enforcement_filters.py @@ -50,20 +50,23 @@ def clear_enforcement_filters(retriever: VectorStoreRetriever) -> None: ) -def set_enforcement_filters( +def update_enforcement_filters( retriever: VectorStoreRetriever, auth_context: Optional[AuthContext], semantic_context: Optional[SemanticContext], + is_privileged_user: bool = False, ) -> None: """ - Set identity and semantic enforcement filters in the retriever. + Update identity and semantic enforcement filters in the retriever. """ # Clear existing enforcement filters clear_enforcement_filters(retriever) - if auth_context is not None: - _set_identity_enforcement_filter(retriever, auth_context) - if semantic_context is not None: - _set_semantic_enforcement_filter(retriever, semantic_context) + # Set new enforcement filters if not a privileged user + if not is_privileged_user: + if auth_context is not None: + _set_identity_enforcement_filter(retriever, auth_context) + if semantic_context is not None: + _set_semantic_enforcement_filter(retriever, semantic_context) def _apply_qdrant_semantic_filter( diff --git a/libs/community/langchain_community/chains/pebblo_retrieval/models.py b/libs/community/langchain_community/chains/pebblo_retrieval/models.py index 97e29769ced6f..9ce9411d4ea7d 100644 --- a/libs/community/langchain_community/chains/pebblo_retrieval/models.py +++ b/libs/community/langchain_community/chains/pebblo_retrieval/models.py @@ -1,6 +1,7 @@ """Models for the PebbloRetrievalQA chain.""" -from typing import Any, List, Optional, Union +from enum import Enum +from typing import Any, List, Optional, Set, Union from pydantic import BaseModel @@ -10,7 +11,7 @@ class AuthContext(BaseModel): name: Optional[str] = None user_id: str - user_auth: List[str] + user_auth: List[str] = [] """List of user authorizations, which may include their User ID and the groups they are part of""" @@ -149,3 +150,54 @@ class Qa(BaseModel): user: str user_identities: Optional[List[str]] classifier_location: str + + +class PolicyType(Enum): + """Enums for policy types""" + + IDENTITY = "identity" + APPLICATION = "application" + COST = "cost" + + +class SemanticGuardrail(BaseModel): + """ + Semantic Guardrail for Entities and Topics (Restrictions). + + Attributes: + entities (Optional[Set[str]]): A set of entity restrictions. + topics (Optional[Set[str]]): A set of topic restrictions. + """ + + entities: Set[str] = set() + topics: Set[str] = set() + + +class Policy(BaseModel): + """ + Policy base class. + + Attributes: + schema_version (int): The schema version of the policy. + type (PolicyType): The type of policy. + """ + + schema_version: int = 1 + type: PolicyType + + class Config: + extra = "ignore" + + +class IdentityPolicy(Policy): + """ + Policy for access control. + + Attributes: + privileged_identities (Set[str]): List of identities with privileged access. + user_semantic_guardrail (dict[str, SemanticGuardrail]): Mapping of identities to + semantic guardrail restrictions. + """ + + privileged_identities: Set[str] = set() + user_semantic_guardrail: dict[str, SemanticGuardrail] = {} diff --git a/libs/community/langchain_community/chains/pebblo_retrieval/utilities.py b/libs/community/langchain_community/chains/pebblo_retrieval/utilities.py index e6e36a505a947..3e90cecaaf9a4 100644 --- a/libs/community/langchain_community/chains/pebblo_retrieval/utilities.py +++ b/libs/community/langchain_community/chains/pebblo_retrieval/utilities.py @@ -2,6 +2,8 @@ import logging import os import platform +import threading +import time from enum import Enum from http import HTTPStatus from typing import Any, Dict, List, Optional, Tuple @@ -21,9 +23,11 @@ AuthContext, Context, Framework, + IdentityPolicy, Prompt, Qa, Runtime, + SemanticContext, ) logger = logging.getLogger(__name__) @@ -32,6 +36,7 @@ _DEFAULT_CLASSIFIER_URL = "http://localhost:8000" _DEFAULT_PEBBLO_CLOUD_URL = "https://api.daxa.ai" +_POLICY_REFRESH_INTERVAL_SEC = 300 # 5 minutes class Routes(str, Enum): @@ -40,6 +45,14 @@ class Routes(str, Enum): retrieval_app_discover = "/v1/app/discover" prompt = "/v1/prompt" prompt_governance = "/v1/prompt/governance" + policy = "/v1/app/policy" + + +class PolicyType(str, Enum): + """Policy type enumerator.""" + + IDENTITY = "identity" + APPLICATION = "application" def get_runtime() -> Tuple[Framework, Runtime]: @@ -100,6 +113,10 @@ class PebbloRetrievalAPIWrapper(BaseModel): """URL of the Pebblo Classifier""" cloud_url: Optional[str] """URL of the Pebblo Cloud""" + app_name: str + """Name of the app""" + policy_cache: Optional[IdentityPolicy] = None + """Local cache for the policy""" def __init__(self, **kwargs: Any): """Validate that api key in environment.""" @@ -113,6 +130,9 @@ def __init__(self, **kwargs: Any): kwargs, "cloud_url", "PEBBLO_CLOUD_URL", _DEFAULT_PEBBLO_CLOUD_URL ) super().__init__(**kwargs) + if self.api_key: + # Start a thread to fetch policy from the Pebblo cloud + self._start_policy_refresh_thread() def send_app_discover(self, app: App) -> None: """ @@ -122,7 +142,7 @@ def send_app_discover(self, app: App) -> None: app (App): App instance to be discovered. """ pebblo_resp = None - payload = app.dict(exclude_unset=True) + payload = app.model_dump(exclude_unset=True) if self.classifier_location == "local": # Send app details to local classifier @@ -334,6 +354,150 @@ async def acheck_prompt_validity( prompt_entities["entityCount"] = pebblo_resp.get("entityCount", 0) return is_valid_prompt, prompt_entities + def is_privileged_user(self, auth_context: Optional[AuthContext]) -> bool: + """ + Check if the user is a privileged user. + + Args: + auth_context (Optional[AuthContext]): Authentication context. + + Returns: + bool: True if the user is a privileged user, False otherwise. + """ + if not auth_context or not self.policy_cache: + return False + + # Get privileged_identities from the policy + privileged_identities = self.policy_cache.privileged_identities + if not privileged_identities: + logger.debug("Privileged identities not found in the policy.") + return False + + # Check if user is a privileged user + user_auth = auth_context.user_auth if auth_context.user_auth else [] + is_privileged_user = any( + _identity in privileged_identities for _identity in user_auth + ) + if is_privileged_user: + logger.debug(f"User {auth_context.user_id} is a privileged user.") + else: + logger.debug(f"User {auth_context.user_id} is not a privileged user.") + return is_privileged_user + + def get_semantic_context( + self, auth_context: Optional[AuthContext] + ) -> Optional[SemanticContext]: + """ + Generate semantic context based on the given auth context. + + Args: + auth_context (Optional[AuthContext]): Authentication context. + + Returns: + Optional[SemanticContext]: Semantic context. + """ + semantic_context = None + if not auth_context or not self.policy_cache: + return semantic_context + + user_semantic_guardrail = self.policy_cache.user_semantic_guardrail + if not user_semantic_guardrail: + return semantic_context + + _all_guardrails = [] + user_auth = auth_context.user_auth if auth_context.user_auth else [] + for identity in user_auth: + if guardrail := self.policy_cache.user_semantic_guardrail.get(identity): + _all_guardrails.append(guardrail) + + semantic_context = self._combine_all_semantic_filters(_all_guardrails) + return semantic_context + + def _start_policy_refresh_thread(self) -> None: + """Start a thread to fetch policy from the Pebblo cloud.""" + logger.info("Starting policy refresh thread.") + policy_thread = threading.Thread(target=self._fetch_policy, daemon=True) + policy_thread.start() + + def _fetch_policy(self) -> None: + """Fetch policy from the Pebblo cloud at regular intervals.""" + while True: + try: + # Fetch identity policy from the Pebblo cloud + resp_json = self._get_policy_from_cloud( + self.app_name, PolicyType.IDENTITY + ) + # Update the local cache with the fetched policy + if resp_json: + policy_type = resp_json.get("type") + if not policy_type: + logger.debug(f"Message: {resp_json.get('message')}") + self.policy_cache = None + elif policy_type == PolicyType.IDENTITY.value: + self.policy_cache = IdentityPolicy(**resp_json) + logger.debug(f"Policy cache updated: {self.policy_cache}") + else: + logger.warning(f"Policy type {policy_type} not supported.") + except Exception as e: + logger.warning(f"Failed to fetch policy: {e}") + # Sleep for the refresh interval + time.sleep(_POLICY_REFRESH_INTERVAL_SEC) + + def _get_policy_from_cloud(self, app_name: str, policy_type: PolicyType) -> Any: + """ + Get the policy for an app from the Pebblo Cloud. + + Args: + app_name (str): Name of the app. + policy_type (PolicyType): Type of policy to fetch. + + Returns: + Any: Json response from the Pebblo Cloud. + + """ + resp_json = None + policy_url = f"{self.cloud_url}{Routes.policy.value}" + headers = self._make_headers(cloud_request=True) + payload = {"app_name": app_name, "policy_type": policy_type.value} + response = self.make_request("POST", policy_url, headers, payload) + if response and response.status_code == HTTPStatus.OK: + resp_json = response.json() + else: + logger.warning(f"Failed to fetch policy for {app_name}") + return resp_json + + @staticmethod + def _combine_all_semantic_filters( + all_guardrails: Optional[list] = None, + ) -> Optional[SemanticContext]: + """ + Combine all provided guardrails to create a semantic context by finding the + intersection of all guardrails. + + Args: + all_guardrails (Optional[list]): List of guardrails. + + Returns: + Optional[SemanticContext]: The generated semantic context. + """ + if not all_guardrails: + return None + + # Find the intersection of all guardrails + entities_to_deny = set.intersection( + *[_guardrail.entities for _guardrail in all_guardrails] + ) + topics_to_deny = set.intersection( + *[_guardrail.topics for _guardrail in all_guardrails] + ) + + # Generate semantic context from the deny entities and topics + _semantic_context = dict() + _semantic_context["pebblo_semantic_entities"] = {"deny": entities_to_deny} + _semantic_context["pebblo_semantic_topics"] = {"deny": topics_to_deny} + semantic_context = SemanticContext(**_semantic_context) + return semantic_context + def _make_headers(self, cloud_request: bool = False) -> dict: """ Generate headers for the request. @@ -441,7 +605,7 @@ async def amake_request( timeout: int = 20, ) -> Any: """ - Make a async request to the Pebblo server/cloud API. + Make an async request to the Pebblo server/cloud API. Args: method (str): HTTP method (GET, POST, PUT, DELETE, etc.). @@ -539,4 +703,4 @@ def build_prompt_qa_payload( else [], classifier_location=self.classifier_location, ) - return qa.dict(exclude_unset=True) + return qa.model_dump(exclude_unset=True) diff --git a/libs/community/tests/unit_tests/chains/test_pebblo_retrieval.py b/libs/community/tests/unit_tests/chains/test_pebblo_retrieval.py index a2fb1dbd00920..30bf4efe8d43a 100644 --- a/libs/community/tests/unit_tests/chains/test_pebblo_retrieval.py +++ b/libs/community/tests/unit_tests/chains/test_pebblo_retrieval.py @@ -2,8 +2,8 @@ Unit tests for the PebbloRetrievalQA chain """ -from typing import List -from unittest.mock import Mock +from typing import Generator, List +from unittest.mock import Mock, patch import pytest from langchain_core.callbacks import ( @@ -70,10 +70,25 @@ def pebblo_retrieval_qa(retriever: FakeRetriever) -> PebbloRetrievalQA: description="description", app_name="app_name", ) - + pebblo_retrieval_qa.pb_client = Mock() + pebblo_retrieval_qa.pb_client.send_prompt = Mock() + pebblo_retrieval_qa.pb_client.enforce_identity_policy = Mock( + return_value=(None, None, False) + ) + pebblo_retrieval_qa.pb_client.check_prompt_validity = Mock( + return_value=(None, dict()) + ) return pebblo_retrieval_qa +@pytest.fixture +def mock_update_enforcement_filters() -> Generator[Mock, None, None]: + with patch( + "langchain_community.chains.pebblo_retrieval.base.update_enforcement_filters" + ) as mock: + yield mock + + def test_invoke(pebblo_retrieval_qa: PebbloRetrievalQA) -> None: """ Test that the invoke method returns a non-None result @@ -132,3 +147,47 @@ def test_validate_vectorstore(retriever: FakeRetriever) -> None: "Vectorstore must be an instance of one of the supported vectorstores" in str(exc_info.value) ) + + +@pytest.mark.parametrize( + "is_privileged_user, expected_count", + [ + (True, 0), # Privileged user + (False, 1), # Non-privileged user + ], +) +def test_policy_enforcement( + pebblo_retrieval_qa: PebbloRetrievalQA, + is_privileged_user: bool, + expected_count: int, +) -> None: + """ + Test policy enforcement for both Privileged user and Non-privileged user. + The get_semantic_context and _set_semantic_enforcement_filter methods should be + called based on the user's role. + Privileged user should not have any enforcement filters applied so these methods + should not be called. + """ + question = "Tell me the secret of the universe" + auth_context = AuthContext(user_id="user@email.com", user_auth=["group1", "group2"]) + semantic_ctx = SemanticContext( + **{ + "pebblo_semantic_topics": {"deny": ["harmful-advice"]}, + "pebblo_semantic_entities": {"deny": ["credit-card"]}, + } + ) + chain_input_obj = ChainInput(query=question, auth_context=auth_context) + + with patch.object( + pebblo_retrieval_qa.pb_client, + "is_privileged_user", + return_value=is_privileged_user, + ) as mock_is_privileged_user, patch.object( + pebblo_retrieval_qa.pb_client, "get_semantic_context", return_value=semantic_ctx + ) as mock_get_semantic_context, patch( + "langchain_community.chains.pebblo_retrieval.enforcement_filters._set_semantic_enforcement_filter" + ) as mock_set_semantic_enforcement_filter: + _ = pebblo_retrieval_qa.invoke(chain_input_obj.dict()) + assert mock_is_privileged_user.call_count == 1 + assert mock_get_semantic_context.call_count == expected_count + assert mock_set_semantic_enforcement_filter.call_count == expected_count