Skip to content

Commit

Permalink
Pebblo: Policy enforcement in Safe Retriever
Browse files Browse the repository at this point in the history
  • Loading branch information
Raj725 committed Oct 29, 2024
1 parent c3021e9 commit 4f43181
Show file tree
Hide file tree
Showing 5 changed files with 351 additions and 23 deletions.
68 changes: 59 additions & 9 deletions libs/community/langchain_community/chains/pebblo_retrieval/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from langchain_community.chains.pebblo_retrieval.enforcement_filters import (
SUPPORTED_VECTORSTORES,
set_enforcement_filters,
update_enforcement_filters,
)
from langchain_community.chains.pebblo_retrieval.models import (
App,
Expand Down Expand Up @@ -102,18 +102,27 @@ def _call(
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key]
auth_context = inputs.get(self.auth_context_key)
semantic_context = inputs.get(self.semantic_context_key)
is_privileged_user = self.pb_client.is_privileged_user(auth_context)
semantic_context = self.determine_semantic_context(
is_privileged_user, auth_context, inputs
)
_, prompt_entities = self.pb_client.check_prompt_validity(question)

accepts_run_manager = (
"run_manager" in inspect.signature(self._get_docs).parameters
)
if accepts_run_manager:
docs = self._get_docs(
question, auth_context, semantic_context, run_manager=_run_manager
question,
auth_context,
semantic_context,
is_privileged_user,
run_manager=_run_manager,
)
else:
docs = self._get_docs(question, auth_context, semantic_context) # type: ignore[call-arg]
docs = self._get_docs(
question, auth_context, semantic_context, is_privileged_user
) # type: ignore[call-arg]
answer = self.combine_documents_chain.run(
input_documents=docs, question=question, callbacks=_run_manager.get_child()
)
Expand Down Expand Up @@ -155,7 +164,10 @@ async def _acall(
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key]
auth_context = inputs.get(self.auth_context_key)
semantic_context = inputs.get(self.semantic_context_key)
is_privileged_user = self.pb_client.is_privileged_user(auth_context)
semantic_context = self.determine_semantic_context(
is_privileged_user, auth_context, inputs
)
accepts_run_manager = (
"run_manager" in inspect.signature(self._aget_docs).parameters
)
Expand All @@ -164,10 +176,16 @@ async def _acall(

if accepts_run_manager:
docs = await self._aget_docs(
question, auth_context, semantic_context, run_manager=_run_manager
question,
auth_context,
semantic_context,
is_privileged_user,
run_manager=_run_manager,
)
else:
docs = await self._aget_docs(question, auth_context, semantic_context) # type: ignore[call-arg]
docs = await self._aget_docs(
question, auth_context, semantic_context, is_privileged_user
) # type: ignore[call-arg]
answer = await self.combine_documents_chain.arun(
input_documents=docs, question=question, callbacks=_run_manager.get_child()
)
Expand Down Expand Up @@ -254,6 +272,7 @@ def from_chain_type(
api_key=api_key,
classifier_location=classifier_location,
classifier_url=classifier_url,
app_name=app_name,
)
# send app discovery request
pb_client.send_app_discover(app)
Expand Down Expand Up @@ -289,11 +308,14 @@ def _get_docs(
question: str,
auth_context: Optional[AuthContext],
semantic_context: Optional[SemanticContext],
is_privileged_user: bool = False,
*,
run_manager: CallbackManagerForChainRun,
) -> List[Document]:
"""Get docs."""
set_enforcement_filters(self.retriever, auth_context, semantic_context)
update_enforcement_filters(
self.retriever, auth_context, semantic_context, is_privileged_user
)
return self.retriever.get_relevant_documents(
question, callbacks=run_manager.get_child()
)
Expand All @@ -303,15 +325,43 @@ async def _aget_docs(
question: str,
auth_context: Optional[AuthContext],
semantic_context: Optional[SemanticContext],
is_privileged_user: bool = False,
*,
run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]:
"""Get docs."""
set_enforcement_filters(self.retriever, auth_context, semantic_context)
update_enforcement_filters(
self.retriever, auth_context, semantic_context, is_privileged_user
)
return await self.retriever.aget_relevant_documents(
question, callbacks=run_manager.get_child()
)

def determine_semantic_context(
self,
is_privileged_user: bool,
auth_context: Optional[AuthContext],
inputs: Dict[str, Any],
) -> Optional[SemanticContext]:
"""
Determine semantic context based on the auth_context or inputs.
Args:
is_privileged_user (bool): If the user is a privileged user.
auth_context (Optional[AuthContext]): Authentication context.
inputs (Dict[str, Any]): Input dictionary containing various parameters.
Returns:
Optional[SemanticContext]: Resolved semantic context.
"""
semantic_context = None
if not is_privileged_user:
# Get semantic context from policy if present otherwise from inputs
semantic_context = self.pb_client.get_semantic_context(
auth_context
) or inputs.get(self.semantic_context_key)
return semantic_context

@staticmethod
def _get_app_details( # type: ignore
app_name: str, owner: str, description: str, llm: BaseLanguageModel, **kwargs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,23 @@ def clear_enforcement_filters(retriever: VectorStoreRetriever) -> None:
)


def set_enforcement_filters(
def update_enforcement_filters(
retriever: VectorStoreRetriever,
auth_context: Optional[AuthContext],
semantic_context: Optional[SemanticContext],
is_privileged_user: bool = False,
) -> None:
"""
Set identity and semantic enforcement filters in the retriever.
Update identity and semantic enforcement filters in the retriever.
"""
# Clear existing enforcement filters
clear_enforcement_filters(retriever)
if auth_context is not None:
_set_identity_enforcement_filter(retriever, auth_context)
if semantic_context is not None:
_set_semantic_enforcement_filter(retriever, semantic_context)
# Set new enforcement filters if not a privileged user
if not is_privileged_user:
if auth_context is not None:
_set_identity_enforcement_filter(retriever, auth_context)
if semantic_context is not None:
_set_semantic_enforcement_filter(retriever, semantic_context)


def _apply_qdrant_semantic_filter(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Models for the PebbloRetrievalQA chain."""

from typing import Any, List, Optional, Union
from enum import Enum
from typing import Any, List, Optional, Set, Union

from pydantic import BaseModel

Expand All @@ -10,7 +11,7 @@ class AuthContext(BaseModel):

name: Optional[str] = None
user_id: str
user_auth: List[str]
user_auth: List[str] = []
"""List of user authorizations, which may include their User ID and
the groups they are part of"""

Expand Down Expand Up @@ -149,3 +150,54 @@ class Qa(BaseModel):
user: str
user_identities: Optional[List[str]]
classifier_location: str


class PolicyType(Enum):
"""Enums for policy types"""

IDENTITY = "identity"
APPLICATION = "application"
COST = "cost"


class SemanticGuardrail(BaseModel):
"""
Semantic Guardrail for Entities and Topics (Restrictions).
Attributes:
entities (Optional[Set[str]]): A set of entity restrictions.
topics (Optional[Set[str]]): A set of topic restrictions.
"""

entities: Set[str] = set()
topics: Set[str] = set()


class Policy(BaseModel):
"""
Policy base class.
Attributes:
schema_version (int): The schema version of the policy.
type (PolicyType): The type of policy.
"""

schema_version: int = 1
type: PolicyType

class Config:
extra = "ignore"


class IdentityPolicy(Policy):
"""
Policy for access control.
Attributes:
privileged_identities (Set[str]): List of identities with privileged access.
user_semantic_guardrail (dict[str, SemanticGuardrail]): Mapping of identities to
semantic guardrail restrictions.
"""

privileged_identities: Set[str] = set()
user_semantic_guardrail: dict[str, SemanticGuardrail] = {}
Loading

0 comments on commit 4f43181

Please sign in to comment.