Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: introduce per client prompts and use the one in kodu #963

Merged
merged 5 commits into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions prompts/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,7 @@ red_team: "You are a red team member conducting a security assessment. Identify

# BlueTeam prompts
blue_team: "You are a blue team member conducting a security assessment. Identify security controls, misconfigurations, and potential vulnerabilities."

# Per-client prompts
client_prompts:
kodu: "If malicious packages or leaked secrets are found, please end the task, sending the problems found embedded in <attempt_completion><result> tags"
4 changes: 3 additions & 1 deletion src/codegate/pipeline/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ def create_input_pipeline(self, client_type: ClientType) -> SequentialPipelinePr
CodegateSecrets(),
CodegateCli(),
CodegateContextRetriever(),
SystemPrompt(Config.get_config().prompts.default_chat),
SystemPrompt(
Config.get_config().prompts.default_chat, Config.get_config().prompts.client_prompts
),
]
return SequentialPipelineProcessor(
input_steps,
Expand Down
14 changes: 12 additions & 2 deletions src/codegate/pipeline/system_prompt/codegate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Optional

from codegate.clients.clients import ClientType
from litellm import ChatCompletionRequest, ChatCompletionSystemMessage

from codegate.pipeline.base import (
Expand All @@ -16,8 +17,9 @@ class SystemPrompt(PipelineStep):
the word "codegate" in the user message.
"""

def __init__(self, system_prompt: str):
def __init__(self, system_prompt: str, client_prompts: dict[str, str]):
    """Store the default CodeGate system prompt and per-client additions.

    Args:
        system_prompt: The base system prompt injected for every client.
        client_prompts: Maps a client identifier (a ``ClientType.value``
            string, e.g. "kodu") to an extra prompt appended only for that
            client — looked up later via ``client.value in self.client_prompts``.
    """
    self.codegate_system_prompt = system_prompt
    self.client_prompts = client_prompts

@property
def name(self) -> str:
Expand All @@ -36,6 +38,7 @@ async def _get_workspace_custom_instructions(self) -> str:

async def _construct_system_prompt(
self,
client: ClientType,
wrksp_custom_instr: str,
req_sys_prompt: Optional[str],
should_add_codegate_sys_prompt: bool,
Expand All @@ -59,6 +62,10 @@ def _start_or_append(existing_prompt: str, new_prompt: str) -> str:
if req_sys_prompt and "codegate" not in req_sys_prompt.lower():
system_prompt = _start_or_append(system_prompt, req_sys_prompt)

# Add per client system prompt
if client and client.value in self.client_prompts:
system_prompt = _start_or_append(system_prompt, self.client_prompts[client.value])

return system_prompt

async def _should_add_codegate_system_prompt(self, context: PipelineContext) -> bool:
Expand Down Expand Up @@ -92,7 +99,10 @@ async def process(
req_sys_prompt = request_system_message.get("content")

system_prompt = await self._construct_system_prompt(
wrksp_custom_instructions, req_sys_prompt, should_add_codegate_sys_prompt
context.client,
wrksp_custom_instructions,
req_sys_prompt,
should_add_codegate_sys_prompt,
)
context.add_alert(self.name, trigger_string=system_prompt)
if not request_system_message:
Expand Down
18 changes: 13 additions & 5 deletions src/codegate/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,19 @@ def from_file(cls, prompt_path: Union[str, Path]) -> "PromptConfig":
if not isinstance(prompt_data, dict):
raise ConfigurationError("Prompts file must contain a YAML dictionary")

# Validate all values are strings
for key, value in prompt_data.items():
if not isinstance(value, str):
raise ConfigurationError(f"Prompt '{key}' must be a string, got {type(value)}")

def validate_prompts(data, parent_key=""):
"""Recursively validate prompt values."""
for key, value in data.items():
full_key = f"{parent_key}.{key}" if parent_key else key
if isinstance(value, dict):
validate_prompts(value, full_key) # Recurse into nested dictionaries
elif not isinstance(value, str):
raise ConfigurationError(
f"Prompt '{full_key}' must be a string, got {type(value)}"
)

# Validate the entire structure
validate_prompts(prompt_data)
return cls(prompts=prompt_data)
except yaml.YAMLError as e:
raise ConfigurationError(f"Failed to parse prompts file: {e}")
Expand Down
8 changes: 4 additions & 4 deletions tests/pipeline/system_prompt/test_system_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def test_init_with_system_message(self):
Test initialization with a system message
"""
test_message = "Test system prompt"
step = SystemPrompt(system_prompt=test_message)
step = SystemPrompt(system_prompt=test_message, client_prompts={})
assert step.codegate_system_prompt == test_message

@pytest.mark.asyncio
Expand All @@ -28,7 +28,7 @@ async def test_process_system_prompt_insertion(self):

# Create system prompt step
system_prompt = "Security analysis system prompt"
step = SystemPrompt(system_prompt=system_prompt)
step = SystemPrompt(system_prompt=system_prompt, client_prompts={})
step._get_workspace_custom_instructions = AsyncMock(return_value="")

# Mock the get_last_user_message method
Expand Down Expand Up @@ -62,7 +62,7 @@ async def test_process_system_prompt_update(self):

# Create system prompt step
system_prompt = "Security analysis system prompt"
step = SystemPrompt(system_prompt=system_prompt)
step = SystemPrompt(system_prompt=system_prompt, client_prompts={})
step._get_workspace_custom_instructions = AsyncMock(return_value="")

# Mock the get_last_user_message method
Expand Down Expand Up @@ -97,7 +97,7 @@ async def test_edge_cases(self, edge_case):
mock_context = Mock(spec=PipelineContext)

system_prompt = "Security edge case prompt"
step = SystemPrompt(system_prompt=system_prompt)
step = SystemPrompt(system_prompt=system_prompt, client_prompts={})
step._get_workspace_custom_instructions = AsyncMock(return_value="")

# Mock get_last_user_message to return None
Expand Down