-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(security): add security agent (#1266)
* feat(security): add security agent * fix ruff errors * check for io and os after code generation * fix lint errors
- Loading branch information
1 parent
f895e5f
commit f98d3a9
Showing
18 changed files
with
929 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
--- | ||
title: "Advanced Security Agent" | ||
description: "Enhance the PandasAI library with the Security Agent to secure applications from malicious code generation" | ||
--- | ||
|
||
## Introduction to the Advanced Security Agent | ||
|
||
The `AdvancedSecurityAgent` (currently in beta) extends the capabilities of the PandasAI library by adding a Security layer to identify if query can generate malicious code. | ||
|
||
> **Note:** Usage of the Security Agent may be subject to a license. For more details, refer to the [license documentation](https://github.com/Sinaptik-AI/pandas-ai/blob/master/pandasai/ee/LICENSE). | ||
## Instantiating the Security Agent | ||
|
||
Creating an instance of the `AdvancedSecurityAgent` is similar to creating an instance of an `Agent`. | ||
|
||
```python | ||
import os | ||
|
||
from pandasai.agent.agent import Agent | ||
from pandasai.ee.agents.advanced_security_agent import AdvancedSecurityAgent | ||
|
||
os.environ["PANDASAI_API_KEY"] = "$2a****************************" | ||
|
||
security = AdvancedSecurityAgent() | ||
agent = Agent("github-stars.csv", security=security) | ||
|
||
print(agent.chat("return total stars count")) | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import os | ||
|
||
from pandasai.agent.agent import Agent | ||
from pandasai.ee.agents.security_agent import SecurityAgent | ||
from pandasai.llm.openai import OpenAI | ||
|
||
os.environ["PANDASAI_API_KEY"] = "$2a****************************" | ||
|
||
security = SecurityAgent() | ||
agent = Agent("github-stars.csv", security=security) | ||
|
||
print(agent.chat("return total stars count")) | ||
|
||
|
||
# Using Security standalone | ||
llm = OpenAI("openai_key") | ||
security = SecurityAgent(config={"llm": llm}) | ||
security.evaluate("return total github star count for year 2023") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from pandasai.helpers.logger import Logger | ||
from pandasai.pipelines.pipeline import Pipeline | ||
from pandasai.pipelines.pipeline_context import PipelineContext | ||
|
||
|
||
class BaseSecurity: | ||
context: PipelineContext | ||
pipeline: Pipeline | ||
logger: Logger | ||
|
||
def __init__( | ||
self, | ||
pipeline: Pipeline, | ||
) -> None: | ||
self.pipeline = pipeline | ||
|
||
def evaluate(self, query: str) -> bool: | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from typing import Optional, Union | ||
|
||
from pandasai.agent.base_security import BaseSecurity | ||
from pandasai.config import load_config_from_json | ||
from pandasai.ee.agents.advanced_security_agent.pipeline.advanced_security_pipeline import ( | ||
AdvancedSecurityPipeline, | ||
) | ||
from pandasai.pipelines.abstract_pipeline import AbstractPipeline | ||
from pandasai.pipelines.pipeline_context import PipelineContext | ||
from pandasai.schemas.df_config import Config | ||
|
||
|
||
class AdvancedSecurityAgent(BaseSecurity): | ||
def __init__( | ||
self, | ||
config: Optional[Union[Config, dict]] = None, | ||
pipeline: AbstractPipeline = None, | ||
) -> None: | ||
context = None | ||
|
||
if isinstance(config, dict): | ||
config = Config(**load_config_from_json(config)) | ||
elif config is None: | ||
config = Config() | ||
|
||
context = PipelineContext(None, config) | ||
|
||
pipeline = pipeline or AdvancedSecurityPipeline(context=context) | ||
super().__init__(pipeline) | ||
|
||
def evaluate(self, query: str) -> bool: | ||
return self.pipeline.run(query) |
33 changes: 33 additions & 0 deletions
33
pandasai/ee/agents/advanced_security_agent/pipeline/advanced_security_pipeline.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from typing import Optional | ||
|
||
from pandasai.ee.agents.advanced_security_agent.pipeline.advanced_security_prompt_generation import ( | ||
AdvancedSecurityPromptGeneration, | ||
) | ||
from pandasai.ee.agents.judge_agent.pipeline.llm_call import LLMCall | ||
from pandasai.helpers.logger import Logger | ||
from pandasai.helpers.query_exec_tracker import QueryExecTracker | ||
from pandasai.pipelines.pipeline import Pipeline | ||
from pandasai.pipelines.pipeline_context import PipelineContext | ||
|
||
|
||
class AdvancedSecurityPipeline: | ||
def __init__( | ||
self, | ||
context: Optional[PipelineContext] = None, | ||
logger: Optional[Logger] = None, | ||
query_exec_tracker: QueryExecTracker = None, | ||
): | ||
self.query_exec_tracker = query_exec_tracker | ||
|
||
self.pipeline = Pipeline( | ||
context=context, | ||
logger=logger, | ||
query_exec_tracker=self.query_exec_tracker, | ||
steps=[ | ||
AdvancedSecurityPromptGeneration(), | ||
LLMCall(), | ||
], | ||
) | ||
|
||
def run(self, input: str): | ||
return self.pipeline.run(input) |
42 changes: 42 additions & 0 deletions
42
pandasai/ee/agents/advanced_security_agent/pipeline/advanced_security_prompt_generation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from typing import Any | ||
|
||
from pandasai.ee.agents.advanced_security_agent.prompts.advanced_security_agent_prompt import ( | ||
AdvancedSecurityAgentPrompt, | ||
) | ||
from pandasai.helpers.logger import Logger | ||
from pandasai.pipelines.base_logic_unit import BaseLogicUnit | ||
from pandasai.pipelines.logic_unit_output import LogicUnitOutput | ||
|
||
|
||
class AdvancedSecurityPromptGeneration(BaseLogicUnit): | ||
""" | ||
Code Prompt Generation Stage | ||
""" | ||
|
||
pass | ||
|
||
def execute(self, input_query: str, **kwargs) -> Any: | ||
""" | ||
This method will return output according to | ||
Implementation. | ||
:param input: Last logic unit output | ||
:param kwargs: A dictionary of keyword arguments. | ||
- 'logger' (any): The logger for logging. | ||
- 'config' (Config): Global configurations for the test | ||
- 'context' (any): The execution context. | ||
:return: LogicUnitOutput(prompt) | ||
""" | ||
self.context = kwargs.get("context") | ||
self.logger: Logger = kwargs.get("logger") | ||
|
||
prompt = AdvancedSecurityAgentPrompt(query=input_query, context=self.context) | ||
self.logger.log(f"Using prompt: {prompt}") | ||
|
||
return LogicUnitOutput( | ||
prompt, | ||
True, | ||
"Prompt Generated Successfully", | ||
{"content_type": "prompt", "value": prompt.to_string()}, | ||
) |
64 changes: 64 additions & 0 deletions
64
pandasai/ee/agents/advanced_security_agent/pipeline/llm_call.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
from typing import Any | ||
|
||
from pandasai.exceptions import InvalidOutputValueMismatch | ||
from pandasai.helpers.logger import Logger | ||
from pandasai.pipelines.base_logic_unit import BaseLogicUnit | ||
from pandasai.pipelines.logic_unit_output import LogicUnitOutput | ||
from pandasai.pipelines.pipeline_context import PipelineContext | ||
|
||
|
||
class LLMCall(BaseLogicUnit): | ||
""" | ||
LLM Code Generation Stage | ||
""" | ||
|
||
def __init__(self, **kwargs): | ||
super().__init__(**kwargs) | ||
|
||
def execute(self, input: Any, **kwargs) -> Any: | ||
""" | ||
This method will return output according to | ||
Implementation. | ||
:param input: Your input data. | ||
:param kwargs: A dictionary of keyword arguments. | ||
- 'logger' (any): The logger for logging. | ||
- 'config' (Config): Global configurations for the test | ||
- 'context' (any): The execution context. | ||
:return: The result of the execution. | ||
""" | ||
pipeline_context: PipelineContext = kwargs.get("context") | ||
logger: Logger = kwargs.get("logger") | ||
|
||
retry_count = 0 | ||
while retry_count <= pipeline_context.config.max_retries: | ||
response = pipeline_context.config.llm.call(input, pipeline_context) | ||
|
||
logger.log( | ||
f"""LLM response: | ||
{response} | ||
""" | ||
) | ||
try: | ||
result = False | ||
if "<Yes>" in response: | ||
result = True | ||
elif "<No>" in response: | ||
result = False | ||
else: | ||
raise InvalidOutputValueMismatch("Invalid response of LLM Call") | ||
|
||
pipeline_context.add("llm_call", response) | ||
|
||
return LogicUnitOutput( | ||
result, | ||
True, | ||
"Code Generated Successfully", | ||
{"content_type": "string", "value": response}, | ||
) | ||
except Exception: | ||
if retry_count == pipeline_context.config.max_retries: | ||
raise | ||
|
||
retry_count += 1 |
39 changes: 39 additions & 0 deletions
39
pandasai/ee/agents/advanced_security_agent/prompts/advanced_security_agent_prompt.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from pathlib import Path | ||
|
||
from jinja2 import Environment, FileSystemLoader | ||
|
||
from pandasai.prompts.base import BasePrompt | ||
|
||
|
||
class AdvancedSecurityAgentPrompt(BasePrompt): | ||
"""Prompt to generate Python code from a dataframe.""" | ||
|
||
template_path = "advanced_security_agent_prompt.tmpl" | ||
|
||
def __init__(self, **kwargs): | ||
"""Initialize the prompt.""" | ||
self.props = kwargs | ||
|
||
if self.template: | ||
env = Environment() | ||
self.prompt = env.from_string(self.template) | ||
elif self.template_path: | ||
# find path to template file | ||
current_dir_path = Path(__file__).parent | ||
|
||
path_to_template = current_dir_path / "templates" | ||
env = Environment(loader=FileSystemLoader(path_to_template)) | ||
self.prompt = env.get_template(self.template_path) | ||
|
||
self._resolved_prompt = None | ||
|
||
def to_json(self): | ||
context = self.props["context"] | ||
memory = context.memory | ||
conversations = memory.to_json() | ||
system_prompt = memory.get_system_prompt() | ||
return { | ||
"conversation": conversations, | ||
"system_prompt": system_prompt, | ||
"prompt": self.to_string(), | ||
} |
Oops, something went wrong.