From 0246f53cf2691c70a7c41b41aab522508f864803 Mon Sep 17 00:00:00 2001 From: Yolanda Robla Date: Thu, 6 Feb 2025 16:49:18 +0100 Subject: [PATCH 1/2] feat: introduce per client prompts and use the one in kodu Using the client detection functionality, expose the ability to send customized system prompts per client, that will add more control to the result we provide to our tools --- prompts/default.yaml | 4 ++++ src/codegate/pipeline/factory.py | 4 +++- .../pipeline/system_prompt/codegate.py | 14 ++++++++++++-- src/codegate/prompts.py | 18 +++++++++++++----- 4 files changed, 32 insertions(+), 8 deletions(-) diff --git a/prompts/default.yaml b/prompts/default.yaml index 1b20ca00..aa70cda0 100644 --- a/prompts/default.yaml +++ b/prompts/default.yaml @@ -50,3 +50,7 @@ red_team: "You are a red team member conducting a security assessment. Identify # BlueTeam prompts blue_team: "You are a blue team member conducting a security assessment. Identify security controls, misconfigurations, and potential vulnerabilities." + +# Per client prompts +client_prompts: + kodu: "If malicious packages or leaked secrets are found, please end the task, sending the problems found embedded in tags" diff --git a/src/codegate/pipeline/factory.py b/src/codegate/pipeline/factory.py index 0fdd66c4..3d0bb326 100644 --- a/src/codegate/pipeline/factory.py +++ b/src/codegate/pipeline/factory.py @@ -29,7 +29,9 @@ def create_input_pipeline(self, client_type: ClientType) -> SequentialPipelinePr CodegateSecrets(), CodegateCli(), CodegateContextRetriever(), - SystemPrompt(Config.get_config().prompts.default_chat), + SystemPrompt( + Config.get_config().prompts.default_chat, Config.get_config().prompts.client_prompts + ), ] return SequentialPipelineProcessor( input_steps, diff --git a/src/codegate/pipeline/system_prompt/codegate.py b/src/codegate/pipeline/system_prompt/codegate.py index 76bcf9d1..ce95d2e8 100644 --- a/src/codegate/pipeline/system_prompt/codegate.py +++ b/src/codegate/pipeline/system_prompt/codegate.py @@ -1,5 +1,6 @@ from typing import Optional +from codegate.clients.clients import ClientType from litellm import ChatCompletionRequest, ChatCompletionSystemMessage from codegate.pipeline.base import ( @@ -16,8 +17,9 @@ class SystemPrompt(PipelineStep): the word "codegate" in the user message. """ - def __init__(self, system_prompt: str): + def __init__(self, system_prompt: str, client_prompts: str): self.codegate_system_prompt = system_prompt + self.client_prompts = client_prompts @property def name(self) -> str: @@ -36,6 +38,7 @@ async def _get_workspace_custom_instructions(self) -> str: async def _construct_system_prompt( self, + client: ClientType, wrksp_custom_instr: str, req_sys_prompt: Optional[str], should_add_codegate_sys_prompt: bool, @@ -59,6 +62,10 @@ def _start_or_append(existing_prompt: str, new_prompt: str) -> str: if req_sys_prompt and "codegate" not in req_sys_prompt.lower(): system_prompt = _start_or_append(system_prompt, req_sys_prompt) + # Add per client system prompt + if client and client.value in self.client_prompts: + system_prompt = _start_or_append(system_prompt, self.client_prompts[client.value]) + return system_prompt async def _should_add_codegate_system_prompt(self, context: PipelineContext) -> bool: @@ -92,7 +99,10 @@ async def process( req_sys_prompt = request_system_message.get("content") system_prompt = await self._construct_system_prompt( - wrksp_custom_instructions, req_sys_prompt, should_add_codegate_sys_prompt + context.client, + wrksp_custom_instructions, + req_sys_prompt, + should_add_codegate_sys_prompt, ) context.add_alert(self.name, trigger_string=system_prompt) if not request_system_message: diff --git a/src/codegate/prompts.py b/src/codegate/prompts.py index 63405a08..6629382c 100644 --- a/src/codegate/prompts.py +++ b/src/codegate/prompts.py @@ -41,11 +41,19 @@ def from_file(cls, prompt_path: Union[str, Path]) -> "PromptConfig": if not isinstance(prompt_data, dict): raise ConfigurationError("Prompts file must contain a YAML dictionary") - # Validate all values are strings - for key, value in prompt_data.items(): - if not isinstance(value, str): - raise ConfigurationError(f"Prompt '{key}' must be a string, got {type(value)}") - + def validate_prompts(data, parent_key=""): + """Recursively validate prompt values.""" + for key, value in data.items(): + full_key = f"{parent_key}.{key}" if parent_key else key + if isinstance(value, dict): + validate_prompts(value, full_key) # Recurse into nested dictionaries + elif not isinstance(value, str): + raise ConfigurationError( + f"Prompt '{full_key}' must be a string, got {type(value)}" + ) + + # Validate the entire structure + validate_prompts(prompt_data) return cls(prompts=prompt_data) except yaml.YAMLError as e: raise ConfigurationError(f"Failed to parse prompts file: {e}") From 76b7990e32a7953ed297d5fa614ccfdc954bd2b4 Mon Sep 17 00:00:00 2001 From: Yolanda Robla Date: Fri, 7 Feb 2025 10:00:09 +0100 Subject: [PATCH 2/2] fix system prompt --- src/codegate/pipeline/system_prompt/codegate.py | 2 +- tests/pipeline/system_prompt/test_system_prompt.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/codegate/pipeline/system_prompt/codegate.py b/src/codegate/pipeline/system_prompt/codegate.py index ce95d2e8..0dbf39a8 100644 --- a/src/codegate/pipeline/system_prompt/codegate.py +++ b/src/codegate/pipeline/system_prompt/codegate.py @@ -17,7 +17,7 @@ class SystemPrompt(PipelineStep): the word "codegate" in the user message. """ - def __init__(self, system_prompt: str, client_prompts: str): + def __init__(self, system_prompt: str, client_prompts: dict[str]): self.codegate_system_prompt = system_prompt self.client_prompts = client_prompts diff --git a/tests/pipeline/system_prompt/test_system_prompt.py b/tests/pipeline/system_prompt/test_system_prompt.py index 6ea36a93..c9d1937d 100644 --- a/tests/pipeline/system_prompt/test_system_prompt.py +++ b/tests/pipeline/system_prompt/test_system_prompt.py @@ -13,7 +13,7 @@ def test_init_with_system_message(self): Test initialization with a system message """ test_message = "Test system prompt" - step = SystemPrompt(system_prompt=test_message) + step = SystemPrompt(system_prompt=test_message, client_prompts={}) assert step.codegate_system_prompt == test_message @pytest.mark.asyncio @@ -28,7 +28,7 @@ async def test_process_system_prompt_insertion(self): # Create system prompt step system_prompt = "Security analysis system prompt" - step = SystemPrompt(system_prompt=system_prompt) + step = SystemPrompt(system_prompt=system_prompt, client_prompts={}) step._get_workspace_custom_instructions = AsyncMock(return_value="") # Mock the get_last_user_message method @@ -62,7 +62,7 @@ async def test_process_system_prompt_update(self): # Create system prompt step system_prompt = "Security analysis system prompt" - step = SystemPrompt(system_prompt=system_prompt) + step = SystemPrompt(system_prompt=system_prompt, client_prompts={}) step._get_workspace_custom_instructions = AsyncMock(return_value="") # Mock the get_last_user_message method @@ -97,7 +97,7 @@ async def test_edge_cases(self, edge_case): mock_context = Mock(spec=PipelineContext) system_prompt = "Security edge case prompt" - step = SystemPrompt(system_prompt=system_prompt) + step = SystemPrompt(system_prompt=system_prompt, client_prompts={}) step._get_workspace_custom_instructions = AsyncMock(return_value="") # Mock get_last_user_message to return None