Skip to content

Commit 8e76c7a

Browse files
committed
feat: finished aggregated solution implementation for openai and anthropic
1 parent 1ce18e9 commit 8e76c7a

6 files changed

+226
-54
lines changed

src/ai/LLM/BaseLLMService.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from abc import ABC, abstractmethod
22
from enum import Enum
33
from typing import Dict, Optional, List, Union, Tuple
4+
from tqdm import tqdm
45
import logging
56

67
from data.Finding import Finding
@@ -22,12 +23,12 @@ def get_url(self) -> str:
2223
pass
2324

2425
@abstractmethod
25-
def _generate(self, prompt: str) -> Dict[str, str]:
26+
def _generate(self, prompt: str, json=False) -> Dict[str, str]:
2627
pass
2728

28-
def generate(self, prompt: str) -> Dict[str, str]:
29+
def generate(self, prompt: str, json=False) -> Dict[str, str]:
2930
try:
30-
return self._generate(prompt)
31+
return self._generate(prompt, json)
3132
except Exception as e:
3233
logger.error(f"Error generating response: {str(e)}")
3334
return {"error": str(e)}
@@ -110,7 +111,7 @@ def generate_aggregated_solution(self, findings: List[Finding]) -> List[Tuple[st
110111

111112
results = []
112113

113-
for group, meta_info in finding_groups:
114+
for group, meta_info in tqdm(finding_groups, desc="Generating aggregated solutions for group", unit="group"):
114115
prompt = self._get_aggregated_solution_prompt(group, meta_info)
115116
response = self.generate(prompt)
116117
solution = self._process_aggregated_solution_response(response)
@@ -134,7 +135,7 @@ def _get_findings_str_for_aggregation(self, findings, details=False) -> str:
134135

135136
def _subdivide_finding_group(self, findings: List[Finding]) -> List[Tuple[List[Finding], Dict]]:
136137
prompt = self._get_subdivision_prompt(findings)
137-
response = self.generate(prompt)
138+
response = self.generate(prompt, json=True)
138139
return self._process_subdivision_response(response, findings)
139140

140141
@abstractmethod

src/ai/LLM/Strategies/AnthropicService.py

+76-15
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
import re
23
from typing import Dict, List, Optional, Tuple, Union
34
from enum import Enum
45

@@ -8,12 +9,14 @@
89
from ai.LLM.LLMServiceMixin import LLMServiceMixin
910
from data.Finding import Finding
1011
from ai.LLM.Strategies.openai_prompts import (
11-
CLASSIFY_KIND_TEMPLATE,
12-
SHORT_RECOMMENDATION_TEMPLATE,
13-
GENERIC_LONG_RECOMMENDATION_TEMPLATE,
14-
SEARCH_TERMS_TEMPLATE,
15-
META_PROMPT_GENERATOR_TEMPLATE,
16-
LONG_RECOMMENDATION_TEMPLATE, COMBINE_DESCRIPTIONS_TEMPLATE,
12+
OPENAI_CLASSIFY_KIND_TEMPLATE,
13+
OPENAI_SHORT_RECOMMENDATION_TEMPLATE,
14+
OPENAI_GENERIC_LONG_RECOMMENDATION_TEMPLATE,
15+
OPENAI_SEARCH_TERMS_TEMPLATE,
16+
OPENAI_META_PROMPT_GENERATOR_TEMPLATE,
17+
OPENAI_LONG_RECOMMENDATION_TEMPLATE,
18+
OPENAI_COMBINE_DESCRIPTIONS_TEMPLATE,
19+
OPENAI_AGGREGATED_SOLUTION_TEMPLATE, OPENAI_SUBDIVISION_PROMPT_TEMPLATE,
1720
)
1821
from utils.text_tools import clean
1922
from config import config
@@ -70,7 +73,7 @@ def get_url(self) -> str:
7073
"""Get the URL for the Anthropic API (placeholder method)."""
7174
return "-"
7275

73-
def _generate(self, prompt: str) -> Dict[str, str]:
76+
def _generate(self, prompt: str, json=False) -> Dict[str, str]:
7477
"""
7578
Generate a response using the Anthropic API.
7679
@@ -81,29 +84,34 @@ def _generate(self, prompt: str) -> Dict[str, str]:
8184
Dict[str, str]: A dictionary containing the generated response.
8285
"""
8386
try:
87+
messages = [{"role": "user", "content": prompt}]
88+
if json:
89+
messages.append({"role": "assistant", "content": "Here is the JSON requested:\n{"})
8490
message = self.client.messages.create(
8591
max_tokens=1024,
86-
messages=[{"role": "user", "content": prompt}],
92+
messages=messages,
8793
model=self.model,
8894
)
8995
content = message.content[0].text
96+
if json:
97+
content = "{" + content
9098
return {"response": content}
9199
except Exception as e:
92100
return self.handle_api_error(e)
93101

94102
def _get_classification_prompt(self, options: str, field_name: str, finding_str: str) -> str:
95103
"""Generate the classification prompt for Anthropic."""
96-
return CLASSIFY_KIND_TEMPLATE.format(options=options, field_name=field_name, data=finding_str)
104+
return OPENAI_CLASSIFY_KIND_TEMPLATE.format(options=options, field_name=field_name, data=finding_str)
97105

98106
def _get_recommendation_prompt(self, finding: Finding, short: bool) -> str:
99107
"""Generate the recommendation prompt for Anthropic."""
100108
if short:
101-
return SHORT_RECOMMENDATION_TEMPLATE.format(data=str(finding))
109+
return OPENAI_SHORT_RECOMMENDATION_TEMPLATE.format(data=str(finding))
102110
elif finding.solution and finding.solution.short_description:
103111
finding.solution.add_to_metadata("used_meta_prompt", True)
104112
return self._generate_prompt_with_meta_prompts(finding)
105113
else:
106-
return GENERIC_LONG_RECOMMENDATION_TEMPLATE
114+
return OPENAI_GENERIC_LONG_RECOMMENDATION_TEMPLATE
107115

108116
def _process_recommendation_response(self, response: Dict[str, str], finding: Finding, short: bool) -> Union[
109117
str, List[str]]:
@@ -117,11 +125,11 @@ def _process_recommendation_response(self, response: Dict[str, str], finding: Fi
117125
def _generate_prompt_with_meta_prompts(self, finding: Finding) -> str:
118126
"""Generate a prompt with meta-prompts for long recommendations."""
119127
short_recommendation = finding.solution.short_description
120-
meta_prompt_generator = META_PROMPT_GENERATOR_TEMPLATE.format(finding=str(finding))
128+
meta_prompt_generator = OPENAI_META_PROMPT_GENERATOR_TEMPLATE.format(finding=str(finding))
121129
meta_prompt_response = self.generate(meta_prompt_generator)
122130
meta_prompts = clean(meta_prompt_response.get("response", ""), llm_service=self)
123131

124-
long_prompt = LONG_RECOMMENDATION_TEMPLATE.format(meta_prompts=meta_prompts)
132+
long_prompt = OPENAI_LONG_RECOMMENDATION_TEMPLATE.format(meta_prompts=meta_prompts)
125133

126134
finding.solution.add_to_metadata(
127135
"prompt_long_breakdown",
@@ -135,7 +143,7 @@ def _generate_prompt_with_meta_prompts(self, finding: Finding) -> str:
135143

136144
def _get_search_terms_prompt(self, finding: Finding) -> str:
137145
"""Generate the search terms prompt for Anthropic."""
138-
return SEARCH_TERMS_TEMPLATE.format(data=str(finding))
146+
return OPENAI_SEARCH_TERMS_TEMPLATE.format(data=str(finding))
139147

140148
def _process_search_terms_response(self, response: Dict[str, str], finding: Finding) -> str:
141149
"""Process the search terms response from Anthropic."""
@@ -144,6 +152,59 @@ def _process_search_terms_response(self, response: Dict[str, str], finding: Find
144152
return ""
145153
return clean(response["response"], llm_service=self)
146154

155+
def _get_subdivision_prompt(self, findings: List[Finding]) -> str:
    """Generate the prompt that asks the model to split the findings into groups."""
    return OPENAI_SUBDIVISION_PROMPT_TEMPLATE.format(
        data=self._get_findings_str_for_aggregation(findings)
    )
158+
159+
def _process_subdivision_response(self, response: Dict[str, str], findings: List[Finding]) -> List[Tuple[List[Finding], Dict]]:
160+
if "response" not in response:
161+
logger.warning("Failed to subdivide findings")
162+
return [(findings, {})] # Return all findings as a single group if subdivision fails
163+
164+
try:
165+
response = response["response"]
166+
# remove prefix ```json and suffix ```
167+
response = re.sub(r'^```json', '', response)
168+
response = re.sub(r'```$', '', response)
169+
subdivisions = json.loads(response)["subdivisions"]
170+
except json.JSONDecodeError:
171+
logger.error("Failed to parse JSON response")
172+
return [(findings, {})]
173+
except KeyError:
174+
logger.error("Unexpected JSON structure in response")
175+
return [(findings, {})]
176+
177+
result = []
178+
for subdivision in subdivisions:
179+
try:
180+
group_indices = [int(i.strip()) - 1 for i in subdivision["group"].split(',')]
181+
group = [findings[i] for i in group_indices if i < len(findings)]
182+
meta_info = {"reason": subdivision.get("reason", "")}
183+
if len(group) == 1:
184+
continue # Skip single-element groups for *aggregated* solutions
185+
result.append((group, meta_info))
186+
except ValueError:
187+
logger.error(f"Failed to parse group indices: {subdivision['group']}")
188+
continue
189+
except KeyError:
190+
logger.error("Unexpected subdivision structure")
191+
continue
192+
193+
return result
194+
195+
def _get_aggregated_solution_prompt(self, findings: List[Finding], meta_info: Dict) -> str:
    """
    Generate the prompt asking for one solution that covers a whole group of findings.

    Args:
        findings: The findings belonging to the group.
        meta_info: Group metadata; only its "reason" entry is used (empty string when absent).

    Returns:
        str: The aggregated-solution template filled with the detailed findings string and the reason.
    """
    detailed_findings = self._get_findings_str_for_aggregation(findings, details=True)
    grouping_reason = meta_info.get("reason", "")
    return OPENAI_AGGREGATED_SOLUTION_TEMPLATE.format(data=detailed_findings, meta_info=grouping_reason)
201+
202+
def _process_aggregated_solution_response(self, response: Dict[str, str]) -> str:
    """
    Extract the aggregated solution text from a generation result.

    Returns:
        str: The cleaned "response" text, or an empty string (after logging a warning)
        when the result carries no "response" key.
    """
    if "response" in response:
        return clean(response["response"], llm_service=self)

    logger.warning("Failed to generate an aggregated solution")
    return ""
207+
147208
def convert_dict_to_str(self, data: Dict) -> str:
148209
"""
149210
Convert a dictionary to a string representation.
@@ -171,7 +232,7 @@ def combine_descriptions(self, descriptions: List[str]) -> str:
171232
if len(descriptions) <= 1:
172233
return descriptions[0] if descriptions else ""
173234

174-
prompt = COMBINE_DESCRIPTIONS_TEMPLATE.format(data=descriptions)
235+
prompt = OPENAI_COMBINE_DESCRIPTIONS_TEMPLATE.format(data=descriptions)
175236

176237
response = self.generate(prompt)
177238
if "response" not in response:

src/ai/LLM/Strategies/OLLAMAService.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ def get_url(self) -> str:
7979
@retry(
8080
stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=60)
8181
)
82-
def _generate(self, prompt: str) -> Dict[str, str]:
82+
def _generate(self, prompt: str, json=True) -> Dict[str, str]:
83+
# The json parameter is ignored: the OLLAMA server always returns JSON.
8384
payload = {"prompt": prompt, **self.generate_payload}
8485
try:
8586
timeout = httpx.Timeout(timeout=300.0)

src/ai/LLM/Strategies/OpenAIService.py

+81-19
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,22 @@
11
import json
2+
import re
23
from enum import Enum
3-
from typing import Dict, List, Optional, Union
4+
from typing import Dict, List, Optional, Union, Tuple
45

56
import openai
67

78
from ai.LLM.BaseLLMService import BaseLLMService
89
from ai.LLM.LLMServiceMixin import LLMServiceMixin
910
from data.Finding import Finding
1011
from ai.LLM.Strategies.openai_prompts import (
11-
CLASSIFY_KIND_TEMPLATE,
12-
SHORT_RECOMMENDATION_TEMPLATE,
13-
GENERIC_LONG_RECOMMENDATION_TEMPLATE,
14-
SEARCH_TERMS_TEMPLATE,
15-
META_PROMPT_GENERATOR_TEMPLATE,
16-
LONG_RECOMMENDATION_TEMPLATE, COMBINE_DESCRIPTIONS_TEMPLATE,
12+
OPENAI_CLASSIFY_KIND_TEMPLATE,
13+
OPENAI_SHORT_RECOMMENDATION_TEMPLATE,
14+
OPENAI_GENERIC_LONG_RECOMMENDATION_TEMPLATE,
15+
OPENAI_SEARCH_TERMS_TEMPLATE,
16+
OPENAI_META_PROMPT_GENERATOR_TEMPLATE,
17+
OPENAI_LONG_RECOMMENDATION_TEMPLATE,
18+
OPENAI_COMBINE_DESCRIPTIONS_TEMPLATE,
19+
OPENAI_AGGREGATED_SOLUTION_TEMPLATE, OPENAI_SUBDIVISION_PROMPT_TEMPLATE,
1720
)
1821
from utils.text_tools import clean
1922

@@ -25,7 +28,7 @@
2528

2629

2730
class OpenAIService(BaseLLMService, LLMServiceMixin):
28-
def __init__(self, api_key: Optional[str] = None, model: str = "gpt-4"):
31+
def __init__(self, api_key: Optional[str] = None, model: str = "gpt-4o"):
2932
"""
3033
Initialize the OpenAIService.
3134
@@ -57,27 +60,33 @@ def get_context_size(self) -> int:
5760
def get_url(self) -> str:
5861
return "-"
5962

60-
def _generate(self, prompt: str) -> Dict[str, str]:
63+
def _generate(self, prompt: str, json=False) -> Dict[str, str]:
6164
try:
62-
response = openai.chat.completions.create(
63-
model=self.model, messages=[{"role": "user", "content": prompt}]
64-
)
65+
params = {
66+
"model": self.model,
67+
"messages": [{"role": "user", "content": prompt}]
68+
}
69+
70+
if json:
71+
params["response_format"] = {"type": "json_object"}
72+
73+
response = openai.chat.completions.create(**params)
6574
content = response.choices[0].message.content
6675
return {"response": content}
6776
except Exception as e:
6877
return self.handle_api_error(e)
6978

7079
def _get_classification_prompt(self, options: str, field_name: str, finding_str: str) -> str:
71-
return CLASSIFY_KIND_TEMPLATE.format(options=options, field_name=field_name, data=finding_str)
80+
return OPENAI_CLASSIFY_KIND_TEMPLATE.format(options=options, field_name=field_name, data=finding_str)
7281

7382
def _get_recommendation_prompt(self, finding: Finding, short: bool) -> str:
7483
if short:
75-
return SHORT_RECOMMENDATION_TEMPLATE.format(data=str(finding))
84+
return OPENAI_SHORT_RECOMMENDATION_TEMPLATE.format(data=str(finding))
7685
elif finding.solution and finding.solution.short_description:
7786
finding.solution.add_to_metadata("used_meta_prompt", True)
7887
return self._generate_prompt_with_meta_prompts(finding)
7988
else:
80-
return GENERIC_LONG_RECOMMENDATION_TEMPLATE
89+
return OPENAI_GENERIC_LONG_RECOMMENDATION_TEMPLATE
8190

8291
def _process_recommendation_response(self, response: Dict[str, str], finding: Finding, short: bool) -> Union[
8392
str, List[str]]:
@@ -89,11 +98,11 @@ def _process_recommendation_response(self, response: Dict[str, str], finding: Fi
8998

9099
def _generate_prompt_with_meta_prompts(self, finding: Finding) -> str:
91100
short_recommendation = finding.solution.short_description
92-
meta_prompt_generator = META_PROMPT_GENERATOR_TEMPLATE.format(finding=str(finding))
101+
meta_prompt_generator = OPENAI_META_PROMPT_GENERATOR_TEMPLATE.format(finding=str(finding))
93102
meta_prompt_response = self.generate(meta_prompt_generator)
94103
meta_prompts = clean(meta_prompt_response.get("response", ""), llm_service=self)
95104

96-
long_prompt = LONG_RECOMMENDATION_TEMPLATE.format(meta_prompts=meta_prompts)
105+
long_prompt = OPENAI_LONG_RECOMMENDATION_TEMPLATE.format(meta_prompts=meta_prompts)
97106

98107
finding.solution.add_to_metadata(
99108
"prompt_long_breakdown",
@@ -106,14 +115,67 @@ def _generate_prompt_with_meta_prompts(self, finding: Finding) -> str:
106115
return long_prompt
107116

108117
def _get_search_terms_prompt(self, finding: Finding) -> str:
109-
return SEARCH_TERMS_TEMPLATE.format(data=str(finding))
118+
return OPENAI_SEARCH_TERMS_TEMPLATE.format(data=str(finding))
110119

111120
def _process_search_terms_response(self, response: Dict[str, str], finding: Finding) -> str:
112121
if "response" not in response:
113122
logger.warning(f"Failed to generate search terms for the finding: {finding.title}")
114123
return ""
115124
return clean(response["response"], llm_service=self)
116125

126+
def _get_subdivision_prompt(self, findings: List[Finding]) -> str:
    """Generate the prompt that asks the model to split the findings into groups."""
    return OPENAI_SUBDIVISION_PROMPT_TEMPLATE.format(
        data=self._get_findings_str_for_aggregation(findings)
    )
129+
130+
def _process_subdivision_response(self, response: Dict[str, str], findings: List[Finding]) -> List[Tuple[List[Finding], Dict]]:
131+
if "response" not in response:
132+
logger.warning("Failed to subdivide findings")
133+
return [(findings, {})] # Return all findings as a single group if subdivision fails
134+
135+
try:
136+
response = response["response"]
137+
# remove prefix ```json and suffix ```
138+
response = re.sub(r'^```json', '', response)
139+
response = re.sub(r'```$', '', response)
140+
subdivisions = json.loads(response)["subdivisions"]
141+
except json.JSONDecodeError:
142+
logger.error("Failed to parse JSON response")
143+
return [(findings, {})]
144+
except KeyError:
145+
logger.error("Unexpected JSON structure in response")
146+
return [(findings, {})]
147+
148+
result = []
149+
for subdivision in subdivisions:
150+
try:
151+
group_indices = [int(i.strip()) - 1 for i in subdivision["group"].split(',')]
152+
group = [findings[i] for i in group_indices if i < len(findings)]
153+
meta_info = {"reason": subdivision.get("reason", "")}
154+
if len(group) == 1:
155+
continue # Skip single-element groups for *aggregated* solutions
156+
result.append((group, meta_info))
157+
except ValueError:
158+
logger.error(f"Failed to parse group indices: {subdivision['group']}")
159+
continue
160+
except KeyError:
161+
logger.error("Unexpected subdivision structure")
162+
continue
163+
164+
return result
165+
166+
def _get_aggregated_solution_prompt(self, findings: List[Finding], meta_info: Dict) -> str:
    """
    Generate the prompt asking for one solution that covers a whole group of findings.

    Args:
        findings: The findings belonging to the group.
        meta_info: Group metadata; only its "reason" entry is used (empty string when absent).

    Returns:
        str: The aggregated-solution template filled with the detailed findings string and the reason.
    """
    detailed_findings = self._get_findings_str_for_aggregation(findings, details=True)
    grouping_reason = meta_info.get("reason", "")
    return OPENAI_AGGREGATED_SOLUTION_TEMPLATE.format(data=detailed_findings, meta_info=grouping_reason)
172+
173+
def _process_aggregated_solution_response(self, response: Dict[str, str]) -> str:
    """
    Extract the aggregated solution text from a generation result.

    Returns:
        str: The cleaned "response" text, or an empty string (after logging a warning)
        when the result carries no "response" key.
    """
    if "response" in response:
        return clean(response["response"], llm_service=self)

    logger.warning("Failed to generate an aggregated solution")
    return ""
178+
117179
def convert_dict_to_str(self, data: Dict) -> str:
118180
"""
119181
Convert a dictionary to a string representation.
@@ -141,7 +203,7 @@ def combine_descriptions(self, descriptions: List[str]) -> str:
141203
if len(descriptions) <= 1:
142204
return descriptions[0] if descriptions else ""
143205

144-
prompt = COMBINE_DESCRIPTIONS_TEMPLATE.format(data=descriptions)
206+
prompt = OPENAI_COMBINE_DESCRIPTIONS_TEMPLATE.format(data=descriptions)
145207

146208
response = self.generate(prompt)
147209
if "response" not in response:

src/ai/LLM/Strategies/ollama_prompts.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def answer_in_json_prompt(key: str) -> str:
9595
"1. Summary: A brief overview of the core security challenges (1-2 sentences)\n"
9696
"2. Strategic Solution: A high-level approach to address the underlying issues (3-5 key points)\n"
9797
"3. Implementation Guidance: General steps for putting the strategy into action\n"
98-
"4. Long-term Considerations: Suggestions for ongoing improvement and risk mitigation\n\n"
98+
"4. Long-term Considerations: Suggestions for ongoing improvement and risk mitigation. Give first steps or initial research that could lay a foundation.\n\n"
9999
"You may use Markdown formatting in your response to improve readability.\n"
100100
f"{answer_in_json_prompt('aggregated_solution')}"
101101
"Findings:\n{data}"

0 commit comments

Comments
 (0)