feat(security): add security agent (#1266)
* feat(security): add security agent

* fix ruff errors

* check for io and os after code generation

* fix lint errors
ArslanSaleem authored Jul 3, 2024
1 parent f895e5f commit f98d3a9
Showing 18 changed files with 929 additions and 4 deletions.
28 changes: 28 additions & 0 deletions docs/advanced-security-agent.mdx
@@ -0,0 +1,28 @@
---
title: "Advanced Security Agent"
description: "Enhance the PandasAI library with the Security Agent to secure applications from malicious code generation"
---

## Introduction to the Advanced Security Agent

The `AdvancedSecurityAgent` (currently in beta) extends the capabilities of the PandasAI library by adding a security layer that identifies whether a query could lead to malicious code generation.

> **Note:** Usage of the Security Agent may be subject to a license. For more details, refer to the [license documentation](https://github.com/Sinaptik-AI/pandas-ai/blob/master/pandasai/ee/LICENSE).
## Instantiating the Security Agent

Creating an instance of the `AdvancedSecurityAgent` is similar to creating an instance of an `Agent`.

```python
import os

from pandasai.agent.agent import Agent
from pandasai.ee.agents.advanced_security_agent import AdvancedSecurityAgent

os.environ["PANDASAI_API_KEY"] = "$2a****************************"

security = AdvancedSecurityAgent()
agent = Agent("github-stars.csv", security=security)

print(agent.chat("return total stars count"))
```
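The same instance can also pre-screen a query on its own before it ever reaches the agent. A minimal sketch, continuing from the snippet above and assuming (as in this commit's pipeline code) that `evaluate` returns `True` when a query is flagged:

```python
query = "return total stars count"

# evaluate() runs the security pipeline directly, without generating any code
if security.evaluate(query):
    print("Query rejected: it may lead to malicious code generation")
else:
    print(agent.chat(query))
```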
2 changes: 1 addition & 1 deletion docs/mint.json
@@ -33,7 +33,7 @@
},
{
"group": "Advanced agents",
"pages": ["semantic-agent", "judge-agent"]
"pages": ["semantic-agent", "judge-agent", "advanced-security-agent"]
},
{
"group": "Advanced usage",
18 changes: 18 additions & 0 deletions examples/security_agent.py
@@ -0,0 +1,18 @@
import os

from pandasai.agent.agent import Agent
from pandasai.ee.agents.security_agent import SecurityAgent
from pandasai.llm.openai import OpenAI

os.environ["PANDASAI_API_KEY"] = "$2a****************************"

security = SecurityAgent()
agent = Agent("github-stars.csv", security=security)

print(agent.chat("return total stars count"))


# Using Security standalone
llm = OpenAI("openai_key")
security = SecurityAgent(config={"llm": llm})
security.evaluate("return total github star count for year 2023")
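When a query trips either check inside `Agent.chat` (the keyword scan added in this commit or the security agent), a `MaliciousQueryError` is raised; `chat` catches it and returns an error string instead of running the code-generation pipeline. A rough sketch of what a caller would see, using a hypothetical query:

```python
# " os" trips BaseAgent.check_malicious_keywords_in_query before any code is generated
response = agent.chat("use the os module to delete the source file")

# Expect the fallback message from BaseAgent.chat rather than generated code,
# e.g. "Unfortunately, I was not able to get your answers, ..."
print(response)
```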
6 changes: 5 additions & 1 deletion pandasai/agent/agent.py
@@ -4,6 +4,7 @@

from pandasai.agent.base import BaseAgent
from pandasai.agent.base_judge import BaseJudge
from pandasai.agent.base_security import BaseSecurity
from pandasai.connectors.base import BaseConnector
from pandasai.pipelines.chat.generate_chat_pipeline import GenerateChatPipeline
from pandasai.schemas.df_config import Config
@@ -22,8 +23,11 @@ def __init__(
vectorstore: Optional[VectorStore] = None,
description: str = None,
judge: BaseJudge = None,
security: BaseSecurity = None,
):
super().__init__(dfs, config, memory_size, vectorstore, description)
super().__init__(
dfs, config, memory_size, vectorstore, description, security=security
)

self.pipeline = (
pipeline(
22 changes: 21 additions & 1 deletion pandasai/agent/base.py
@@ -4,6 +4,7 @@
from typing import List, Optional, Union

import pandasai.pandas as pd
from pandasai.agent.base_security import BaseSecurity
from pandasai.llm.bamboo_llm import BambooLLM
from pandasai.pipelines.chat.chat_pipeline_input import ChatPipelineInput
from pandasai.pipelines.chat.code_execution_pipeline_input import (
@@ -14,7 +15,11 @@
from ..config import load_config_from_json
from ..connectors import BaseConnector, PandasConnector
from ..constants import DEFAULT_CACHE_DIRECTORY, DEFAULT_CHART_DIRECTORY
from ..exceptions import InvalidLLMOutputType, MissingVectorStoreError
from ..exceptions import (
InvalidLLMOutputType,
MaliciousQueryError,
MissingVectorStoreError,
)
from ..helpers.df_info import df_type
from ..helpers.folder import Folder
from ..helpers.logger import Logger
@@ -45,6 +50,7 @@ def __init__(
memory_size: Optional[int] = 10,
vectorstore: Optional[VectorStore] = None,
description: str = None,
security: BaseSecurity = None,
):
"""
Args:
@@ -97,6 +103,7 @@ def __init__(
self.configure()

self.pipeline = None
self.security = security

def configure(self):
# Add project root path if save_charts_path is default
@@ -226,6 +233,10 @@ def call_llm_with_prompt(self, prompt: BasePrompt):
raise
retry_count += 1

def check_malicious_keywords_in_query(self, query):
dangerous_modules = [" os", " io", ".os", ".io"]
return any(module in query for module in dangerous_modules)

def chat(self, query: str, output_type: Optional[str] = None):
"""
Simulate a chat interaction with the assistant on Dataframe.
@@ -244,11 +255,20 @@ def chat(self, query: str, output_type: Optional[str] = None):

self.assign_prompt_id()

if self.check_malicious_keywords_in_query(query):
raise MaliciousQueryError(
"Query can result in a malicious code, query contain io and os which can lead to malicious code"
)

if self.security and self.security.evaluate(query):
raise MaliciousQueryError("Query can result in a malicious code")

pipeline_input = ChatPipelineInput(
query, output_type, self.conversation_id, self.last_prompt_id
)

return self.pipeline.run(pipeline_input)

except Exception as exception:
return (
"Unfortunately, I was not able to get your answers, "
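For context, the `check_malicious_keywords_in_query` method added above is a plain substring scan, so it is deliberately coarse and can over-match. A quick standalone illustration of how the heuristic behaves (the sample queries are hypothetical):

```python
def check_malicious_keywords_in_query(query):
    # Same logic as the new BaseAgent method: flag any occurrence of these substrings
    dangerous_modules = [" os", " io", ".os", ".io"]
    return any(module in query for module in dangerous_modules)


print(check_malicious_keywords_in_query("import os and list files"))   # True: contains " os"
print(check_malicious_keywords_in_query("plot sales per region"))      # False
print(check_malicious_keywords_in_query("describe the osmosis data"))  # True: false positive on " os"
```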
18 changes: 18 additions & 0 deletions pandasai/agent/base_security.py
@@ -0,0 +1,18 @@
from pandasai.helpers.logger import Logger
from pandasai.pipelines.pipeline import Pipeline
from pandasai.pipelines.pipeline_context import PipelineContext


class BaseSecurity:
context: PipelineContext
pipeline: Pipeline
logger: Logger

def __init__(
self,
pipeline: Pipeline,
) -> None:
self.pipeline = pipeline

def evaluate(self, query: str) -> bool:
raise NotImplementedError
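`BaseSecurity` is the extension point that `Agent` accepts through its new `security` parameter: anything with an `evaluate(query) -> bool` method that returns `True` for queries that should be blocked. A minimal sketch of a custom implementation (the class and its keyword list are hypothetical, not part of this commit):

```python
from pandasai.agent.base_security import BaseSecurity


class KeywordSecurity(BaseSecurity):
    """Hypothetical security check backed by a simple blocklist instead of an LLM."""

    def __init__(self, blocked_keywords=None):
        # No pipeline is needed for a pure keyword check
        super().__init__(pipeline=None)
        self.blocked_keywords = blocked_keywords or ["subprocess", "eval(", "exec("]

    def evaluate(self, query: str) -> bool:
        # Returning True makes Agent.chat raise MaliciousQueryError for this query
        return any(keyword in query.lower() for keyword in self.blocked_keywords)
```

Such an object could then be passed as `Agent(dfs, security=KeywordSecurity())`, just like the built-in agents.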
32 changes: 32 additions & 0 deletions pandasai/ee/agents/advanced_security_agent/__init__.py
@@ -0,0 +1,32 @@
from typing import Optional, Union

from pandasai.agent.base_security import BaseSecurity
from pandasai.config import load_config_from_json
from pandasai.ee.agents.advanced_security_agent.pipeline.advanced_security_pipeline import (
AdvancedSecurityPipeline,
)
from pandasai.pipelines.abstract_pipeline import AbstractPipeline
from pandasai.pipelines.pipeline_context import PipelineContext
from pandasai.schemas.df_config import Config


class AdvancedSecurityAgent(BaseSecurity):
def __init__(
self,
config: Optional[Union[Config, dict]] = None,
pipeline: AbstractPipeline = None,
) -> None:
context = None

if isinstance(config, dict):
config = Config(**load_config_from_json(config))
elif config is None:
config = Config()

context = PipelineContext(None, config)

pipeline = pipeline or AdvancedSecurityPipeline(context=context)
super().__init__(pipeline)

def evaluate(self, query: str) -> bool:
return self.pipeline.run(query)
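Because the constructor normalizes a dict config through `load_config_from_json`, the agent can also be driven by an explicit LLM instead of the PandasAI platform key, mirroring the `SecurityAgent` example earlier in this commit. A small sketch under that assumption:

```python
from pandasai.ee.agents.advanced_security_agent import AdvancedSecurityAgent
from pandasai.llm.openai import OpenAI

# A dict config is converted into a Config object via load_config_from_json
llm = OpenAI("openai_key")
security = AdvancedSecurityAgent(config={"llm": llm})

# evaluate() simply delegates to AdvancedSecurityPipeline.run(query)
print(security.evaluate("drop all rows and delete the source file"))
```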
33 changes: 33 additions & 0 deletions pandasai/ee/agents/advanced_security_agent/pipeline/advanced_security_pipeline.py
@@ -0,0 +1,33 @@
from typing import Optional

from pandasai.ee.agents.advanced_security_agent.pipeline.advanced_security_prompt_generation import (
AdvancedSecurityPromptGeneration,
)
from pandasai.ee.agents.judge_agent.pipeline.llm_call import LLMCall
from pandasai.helpers.logger import Logger
from pandasai.helpers.query_exec_tracker import QueryExecTracker
from pandasai.pipelines.pipeline import Pipeline
from pandasai.pipelines.pipeline_context import PipelineContext


class AdvancedSecurityPipeline:
def __init__(
self,
context: Optional[PipelineContext] = None,
logger: Optional[Logger] = None,
query_exec_tracker: QueryExecTracker = None,
):
self.query_exec_tracker = query_exec_tracker

self.pipeline = Pipeline(
context=context,
logger=logger,
query_exec_tracker=self.query_exec_tracker,
steps=[
AdvancedSecurityPromptGeneration(),
LLMCall(),
],
)

def run(self, input: str):
return self.pipeline.run(input)
42 changes: 42 additions & 0 deletions pandasai/ee/agents/advanced_security_agent/pipeline/advanced_security_prompt_generation.py
@@ -0,0 +1,42 @@
from typing import Any

from pandasai.ee.agents.advanced_security_agent.prompts.advanced_security_agent_prompt import (
AdvancedSecurityAgentPrompt,
)
from pandasai.helpers.logger import Logger
from pandasai.pipelines.base_logic_unit import BaseLogicUnit
from pandasai.pipelines.logic_unit_output import LogicUnitOutput


class AdvancedSecurityPromptGeneration(BaseLogicUnit):
"""
Advanced Security Prompt Generation Stage
"""

def execute(self, input_query: str, **kwargs) -> Any:
"""
This method returns output according to the implementation.
:param input_query: The user query to evaluate.
:param kwargs: A dictionary of keyword arguments.
- 'logger' (any): The logger for logging.
- 'config' (Config): Global configurations.
- 'context' (any): The execution context.
:return: LogicUnitOutput(prompt)
"""
self.context = kwargs.get("context")
self.logger: Logger = kwargs.get("logger")

prompt = AdvancedSecurityAgentPrompt(query=input_query, context=self.context)
self.logger.log(f"Using prompt: {prompt}")

return LogicUnitOutput(
prompt,
True,
"Prompt Generated Successfully",
{"content_type": "prompt", "value": prompt.to_string()},
)
64 changes: 64 additions & 0 deletions pandasai/ee/agents/advanced_security_agent/pipeline/llm_call.py
@@ -0,0 +1,64 @@
from typing import Any

from pandasai.exceptions import InvalidOutputValueMismatch
from pandasai.helpers.logger import Logger
from pandasai.pipelines.base_logic_unit import BaseLogicUnit
from pandasai.pipelines.logic_unit_output import LogicUnitOutput
from pandasai.pipelines.pipeline_context import PipelineContext


class LLMCall(BaseLogicUnit):
"""
LLM Call Stage
"""

def __init__(self, **kwargs):
super().__init__(**kwargs)

def execute(self, input: Any, **kwargs) -> Any:
"""
This method returns output according to the implementation.
:param input: Your input data.
:param kwargs: A dictionary of keyword arguments.
- 'logger' (any): The logger for logging.
- 'config' (Config): Global configurations.
- 'context' (any): The execution context.
:return: The result of the execution.
"""
pipeline_context: PipelineContext = kwargs.get("context")
logger: Logger = kwargs.get("logger")

retry_count = 0
while retry_count <= pipeline_context.config.max_retries:
response = pipeline_context.config.llm.call(input, pipeline_context)

logger.log(
f"""LLM response:
{response}
"""
)
try:
result = False
if "<Yes>" in response:
result = True
elif "<No>" in response:
result = False
else:
raise InvalidOutputValueMismatch("Invalid response of LLM Call")

pipeline_context.add("llm_call", response)

return LogicUnitOutput(
result,
True,
"Code Generated Successfully",
{"content_type": "string", "value": response},
)
except Exception:
if retry_count == pipeline_context.config.max_retries:
raise

retry_count += 1
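The step above expects the model to answer with an explicit `<Yes>` or `<No>` tag; any other response raises `InvalidOutputValueMismatch` and the call is retried up to `config.max_retries` times. A condensed sketch of just the parsing contract, outside the retry loop:

```python
from pandasai.exceptions import InvalidOutputValueMismatch


def parse_security_verdict(response: str) -> bool:
    # Mirrors the branching in LLMCall.execute: <Yes> means the query is flagged
    if "<Yes>" in response:
        return True
    if "<No>" in response:
        return False
    raise InvalidOutputValueMismatch("Invalid response of LLM Call")


print(parse_security_verdict("<Yes>, this query could write to the filesystem"))  # True
print(parse_security_verdict("<No>, this is a plain aggregation"))                # False
```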
39 changes: 39 additions & 0 deletions pandasai/ee/agents/advanced_security_agent/prompts/advanced_security_agent_prompt.py
@@ -0,0 +1,39 @@
from pathlib import Path

from jinja2 import Environment, FileSystemLoader

from pandasai.prompts.base import BasePrompt


class AdvancedSecurityAgentPrompt(BasePrompt):
"""Prompt to generate Python code from a dataframe."""

template_path = "advanced_security_agent_prompt.tmpl"

def __init__(self, **kwargs):
"""Initialize the prompt."""
self.props = kwargs

if self.template:
env = Environment()
self.prompt = env.from_string(self.template)
elif self.template_path:
# find path to template file
current_dir_path = Path(__file__).parent

path_to_template = current_dir_path / "templates"
env = Environment(loader=FileSystemLoader(path_to_template))
self.prompt = env.get_template(self.template_path)

self._resolved_prompt = None

def to_json(self):
context = self.props["context"]
memory = context.memory
conversations = memory.to_json()
system_prompt = memory.get_system_prompt()
return {
"conversation": conversations,
"system_prompt": system_prompt,
"prompt": self.to_string(),
}
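The class resolves its template in two steps: an inline `template` string, if present, takes precedence; otherwise `advanced_security_agent_prompt.tmpl` is loaded from a `templates/` directory next to the module. The `.tmpl` file itself is not shown in this excerpt, so the sketch below uses a made-up inline template purely to illustrate the Jinja2 mechanics; the `query` variable is an assumption based on how `AdvancedSecurityPromptGeneration` constructs the prompt:

```python
from jinja2 import Environment

# Hypothetical inline template; the real .tmpl shipped with this commit is not shown above
inline_template = "Decide whether the following query could produce malicious code: {{ query }}"

prompt = Environment().from_string(inline_template)
print(prompt.render(query="return total stars count"))
```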