
Commit 730d3dd (parent: da61429)

feat: added structure for aggregated solution

8 files changed: +242 −14 lines

src/ai/Grouping/FindingBatcher.py (new file, +90)

@@ -0,0 +1,90 @@
+from typing import List
+from collections import defaultdict
+from data.Finding import Finding
+from utils.token_utils import fits_in_context
+
+
+class FindingBatcher:
+    def __init__(self, llm_service):
+        self.llm_service = llm_service
+        self.category_attributes = [
+            'security_aspect', 'affected_component', 'technology_stack',
+            'remediation_type', 'severity_level', 'compliance', 'environment'
+        ]
+
+    def create_batches(self, findings: List[Finding]) -> List[List[Finding]]:
+        """Create batches of findings that fit within the LLM's context."""
+        if self._fits_in_context(findings, include_solution=True):
+            return [findings]
+        elif self._fits_in_context(findings, include_solution=False):
+            return [self._strip_solutions(findings)]
+        return self._recursive_batch(findings)
+
+    def _recursive_batch(self, findings: List[Finding], depth: int = 0) -> List[List[Finding]]:
+        """Recursively batch findings based on category attributes."""
+        if depth >= len(self.category_attributes):
+            return self._final_split(findings)
+
+        grouped = self._group_by_attribute(findings, self.category_attributes[depth])
+        batches = []
+
+        for group in grouped.values():
+            if len(group) == 1:
+                # batches.append(group)
+                pass  # Remove single findings, as they are not useful for *aggregated* solutions
+            elif self._fits_in_context(group, include_solution=True):
+                batches.append(group)
+            elif self._fits_in_context(group, include_solution=False):
+                batches.append(self._strip_solutions(group))
+            else:
+                batches.extend(self._recursive_batch(group, depth + 1))
+
+        return batches
+
+    def _group_by_attribute(self, findings: List[Finding], attribute: str) -> dict:
+        """Group findings by a specific category attribute."""
+        grouped = defaultdict(list)
+        for finding in findings:
+            if finding.category and getattr(finding.category, attribute):
+                key = getattr(finding.category, attribute).value
+            else:
+                key = 'unknown'
+            grouped[key].append(finding)
+        return grouped
+
+    def _final_split(self, findings: List[Finding]) -> List[List[Finding]]:
+        """Split findings when all category attributes have been exhausted."""
+        batches = []
+        current_batch = []
+
+        for finding in findings:
+            current_batch.append(finding)
+            if self._fits_in_context(current_batch, include_solution=True):
+                continue
+            elif self._fits_in_context(current_batch, include_solution=False):
+                current_batch = self._strip_solutions(current_batch)
+            else:
+                # If adding this finding exceeds the context, start a new batch
+                batches.append(current_batch[:-1])
+                current_batch = [finding]
+
+        if current_batch:
+            batches.append(current_batch)
+
+        return batches
+
+    def _fits_in_context(self, findings: List[Finding], include_solution: bool) -> bool:
+        """Check if a list of findings fits within the LLM's context."""
+        content = "\n".join(self._finding_to_string(f, include_solution) for f in findings)
+        return fits_in_context(content, self.llm_service)
+
+    def _finding_to_string(self, finding: Finding, include_solution: bool) -> str:
+        """Convert a finding to a string representation."""
+        content = f"Description: {finding.description}"
+        if include_solution and finding.solution:
+            content += f"\nSolution: {finding.solution.short_description}"
+        return content
+
+    def _strip_solutions(self, findings: List[Finding]) -> List[Finding]:
+        """Remove solutions from a list of findings."""
+        return [Finding(**{**f.dict(), 'solution': None}) for f in findings]
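
For intuition: the batcher keeps whole groups together and only recurses to the next category attribute when a group overflows the context. A minimal, self-contained sketch of that group-then-recurse strategy follows, with plain dicts standing in for Finding objects, a character budget standing in for the token check, and the solution-stripping fallback omitted; all names and data are illustrative, not code from this commit:

from collections import defaultdict

ATTRS = ['security_aspect', 'affected_component']  # shortened attribute ladder
BUDGET = 60  # stand-in for the LLM context limit, in characters

def fits(group):
    return sum(len(f['description']) for f in group) <= BUDGET

def batch(findings, depth=0):
    if fits(findings):
        return [findings]                # everything fits: one batch
    if depth >= len(ATTRS):
        return [[f] for f in findings]   # attributes exhausted: crude final split
    grouped = defaultdict(list)
    for f in findings:
        grouped[f.get(ATTRS[depth], 'unknown')].append(f)
    batches = []
    for group in grouped.values():
        if len(group) == 1:
            continue                     # singles carry no aggregation value
        batches.extend(batch(group, depth + 1))
    return batches

findings = [
    {'description': 'weak TLS config on gateway', 'security_aspect': 'crypto'},
    {'description': 'expired certificate on mail server', 'security_aspect': 'crypto'},
    {'description': 'missing rate limit on login', 'security_aspect': 'authn'},
    {'description': 'default admin password in use', 'security_aspect': 'authn'},
]
print(batch(findings))  # the crypto pair and the authn pair land in separate batches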

src/ai/Grouping/FindingGrouper.py (new file, +24)

@@ -0,0 +1,24 @@
+from typing import List
+
+from tqdm import tqdm
+
+from ai.Grouping.FindingBatcher import FindingBatcher
+from ai.LLM.BaseLLMService import BaseLLMService
+from data.AggregatedSolution import AggregatedSolution
+from data.VulnerabilityReport import VulnerabilityReport
+
+
+class FindingGrouper:
+    def __init__(self, vulnerability_report: VulnerabilityReport, llm_service: BaseLLMService):
+        self.vulnerability_report = vulnerability_report
+        self.llm_service = llm_service
+        self.batcher = FindingBatcher(llm_service)
+        self.batches = self.batcher.create_batches(vulnerability_report.get_findings())
+        self.aggregated_solutions: List[AggregatedSolution] = []
+
+    def generate_aggregated_solutions(self):
+        for batch in tqdm(self.batches, desc="Generating Aggregated Solutions"):
+            result_list = self.llm_service.generate_aggregated_solution(batch)
+            for result in result_list:
+                self.aggregated_solutions.append(AggregatedSolution(result[1], result[0], result[2]))
+        self.vulnerability_report.set_aggregated_solutions(self.aggregated_solutions)
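
How this is meant to be driven, as far as the commit shows; the report population and the concrete service are assumed here, since both come from elsewhere in the codebase:

# Hypothetical wiring, for illustration only.
report = VulnerabilityReport()                 # ...findings loaded elsewhere
llm_service = MyLLMService()                   # any concrete BaseLLMService subclass
grouper = FindingGrouper(report, llm_service)  # batching happens in __init__
grouper.generate_aggregated_solutions()        # fills report.aggregated_solutions
for agg in report.get_aggregated_solutions():
    print(agg)                                 # __str__ returns the solution text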

src/ai/LLM/BaseLLMService.py (+49 −1)

@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from enum import Enum
-from typing import Dict, Optional, List, Union
+from typing import Dict, Optional, List, Union, Tuple
 import logging

 from data.Finding import Finding
@@ -94,6 +94,54 @@ def _get_search_terms_prompt(self, finding: Finding) -> str:
     def _process_search_terms_response(self, response: Dict[str, str], finding: Finding) -> str:
         pass

+    def generate_aggregated_solution(self, findings: List[Finding]) -> List[Tuple[str, List[Finding], Dict]]:
+        """
+        Generate an aggregated solution for a group of findings.
+
+        Args:
+            findings (List[Finding]): The findings to generate a solution for.
+
+        Returns:
+            List[Tuple[str, List[Finding], Dict]]: The generated solution, the findings it applies to, and any additional metadata
+        """
+        finding_groups = self._subdivide_finding_group(findings)
+        if len(finding_groups) < 1:
+            return []  # No suitable groups found
+
+        results = []
+
+        for group, meta_info in finding_groups:
+            prompt = self._get_aggregated_solution_prompt(group, meta_info)
+            response = self.generate(prompt)
+            solution = self._process_aggregated_solution_response(response)
+
+            if solution:
+                results.append((solution, group, meta_info))
+
+        return results
+
+
+    def _subdivide_finding_group(self, findings: List[Finding]) -> List[Tuple[List[Finding], Dict]]:
+        prompt = self._get_subdivision_prompt(findings)
+        response = self.generate(prompt)
+        return self._process_subdivision_response(response, findings)
+
+    @abstractmethod
+    def _get_subdivision_prompt(self, findings: List[Finding]) -> str:
+        pass
+
+    @abstractmethod
+    def _process_subdivision_response(self, response: Dict, findings: List[Finding]) -> List[Tuple[List[Finding], Dict]]:
+        pass
+
+    @abstractmethod
+    def _get_aggregated_solution_prompt(self, findings: List[Finding], meta_info: Dict) -> str:
+        pass
+
+    @abstractmethod
+    def _process_aggregated_solution_response(self, response: Dict) -> str:
+        pass
+
     @abstractmethod
     def convert_dict_to_str(self, data) -> str:
         pass
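
generate_aggregated_solution is a template method: it first asks the model to subdivide the batch into coherent groups, then prompts once per group and keeps only non-empty solutions. A sketch of the four hooks a concrete service might supply; the prompt wording and response keys ("groups", "ids", "reason", "solution") are assumptions, not anything this commit prescribes, and in the codebase they would live on a BaseLLMService subclass alongside generate() and the other abstract methods:

class SketchHooks:
    """Illustrative hook implementations; not part of the commit."""

    def _get_subdivision_prompt(self, findings):
        listing = "\n".join(f"{i}: {f.description}" for i, f in enumerate(findings))
        return f"Group related findings and answer as JSON:\n{listing}"

    def _process_subdivision_response(self, response, findings):
        # Assumed response shape: {"groups": [{"ids": [0, 2], "reason": "both TLS"}]}
        return [([findings[i] for i in g["ids"]], {"reason": g.get("reason", "")})
                for g in response.get("groups", [])]

    def _get_aggregated_solution_prompt(self, findings, meta_info):
        listing = "\n".join(f.description for f in findings)
        return f"Propose one shared remediation:\n{listing}"

    def _process_aggregated_solution_response(self, response):
        return response.get("solution", "")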

src/ai/LLM/LLMServiceStrategy.py (+17 −1)

@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Dict, Optional, List, Union
+from typing import Dict, Optional, List, Union, Tuple

 from ai.LLM.BaseLLMService import BaseLLMService
 from data.Finding import Finding
@@ -21,6 +21,10 @@ def get_model_name(self) -> str:
         """Get the name of the current LLM model."""
         return self.llm_service.get_model_name()

+    def get_context_size(self) -> int:
+        """Get the context size of the current LLM service."""
+        return self.llm_service.get_context_size()
+
     def get_url(self) -> str:
         """Get the URL associated with the current LLM service."""
         return self.llm_service.get_url()
@@ -88,6 +92,18 @@ def get_search_terms(self, finding: Finding) -> str:
         """
         return self.llm_service.get_search_terms(finding)

+    def generate_aggregated_solution(self, findings: List[Finding]) -> List[Tuple[str, List[Finding], Dict]]:
+        """
+        Generate an aggregated solution for a group of findings.
+
+        Args:
+            findings (List[Finding]): The findings to generate a solution for.
+
+        Returns:
+            List[Tuple[str, List[Finding], Dict]]: The generated solution, the findings it applies to, and any additional metadata
+        """
+        return self.llm_service.generate_aggregated_solution(findings)
+
     def convert_dict_to_str(self, data: Dict) -> str:
         """
         Convert a dictionary to a string representation.

src/data/AggregatedSolution.py (new file, +28)

@@ -0,0 +1,28 @@
+from typing import List
+
+from data.Finding import Finding
+from db.base import BaseModel
+
+
+class AggregatedSolution:
+    findings: List[Finding] = None
+    solution: str = ""
+    metadata: dict = {}
+
+    def __init__(self, findings: List[Finding], solution: str, metadata=None):
+        self.findings = findings
+        self.solution = solution
+        self.metadata = metadata
+
+    def __str__(self):
+        return self.solution
+
+    def to_dict(self):
+        return {
+            "findings": [finding.to_dict() for finding in self.findings],
+            "solution": self.solution,
+            "metadata": self.metadata
+        }
+
+    def to_html(self):
+        return f"<p>{self.solution}</p>"

src/data/Categories.py (+6 −6)

@@ -84,13 +84,13 @@ class Environment(Enum):


 class Category(BaseModel):
-    technology_stack: Optional[List[TechnologyStack]] = None
-    security_aspect: Optional[List[SecurityAspect]] = None
+    technology_stack: Optional[TechnologyStack] = None
+    security_aspect: Optional[SecurityAspect] = None
     severity_level: Optional[SeverityLevel] = None
-    remediation_type: Optional[List[RemediationType]] = None
-    affected_component: Optional[List[AffectedComponent]] = None
-    compliance: Optional[List[Compliance]] = None
-    environment: Optional[List[Environment]] = None
+    remediation_type: Optional[RemediationType] = None
+    affected_component: Optional[AffectedComponent] = None
+    compliance: Optional[Compliance] = None
+    environment: Optional[Environment] = None

     def __str__(self):
         my_str = ""
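
The switch from list-valued to single-valued attributes is what FindingBatcher._group_by_attribute depends on: it calls getattr(finding.category, attribute).value to build a grouping key, which only works on a single Enum member. A minimal illustration (the enum value is a placeholder):

from enum import Enum

class SecurityAspect(Enum):
    CRYPTO = "crypto"

aspect = SecurityAspect.CRYPTO
print(aspect.value)  # 'crypto' -- a hashable grouping key
# Under the old Optional[List[SecurityAspect]] typing the attribute would be a
# list, which has no .value and so could not produce a grouping key this way.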

src/data/VulnerabilityReport.py (+23 −3)

@@ -1,8 +1,10 @@
 import json
+from typing import List

 from tqdm import tqdm
 from random import shuffle

+from data.AggregatedSolution import AggregatedSolution
 from data.Finding import Finding
 from ai.LLM.LLMServiceStrategy import LLMServiceStrategy
 from ai.Clustering.AgglomerativeClusterer import AgglomerativeClusterer
@@ -17,6 +19,7 @@
 class VulnerabilityReport:
     def __init__(self):
         self.findings: list[Finding] = []
+        self.aggregated_solutions: List[AggregatedSolution] = []

     def set_llm_service(self, llm_service: "LLMServiceStrategy"):
         """
@@ -68,6 +71,12 @@ def add_solution(self, long=True, short=True, search_term=True):
             finding.generate_solution(long, short, search_term)
         return self

+    def set_aggregated_solutions(self, aggregated_solutions: List[AggregatedSolution]):
+        self.aggregated_solutions = aggregated_solutions
+
+    def get_aggregated_solutions(self) -> List[AggregatedSolution]:
+        return self.aggregated_solutions
+
     def sort(self, by: str = "severity", reverse: bool = True):
         """
         This function sorts the findings by severity or priority.
@@ -87,13 +96,24 @@ def sort(self, by: str = "severity", reverse: bool = True):
         return self

     def to_dict(self):
-        return [f.to_dict() for f in self.findings]
+        findings = [f.to_dict() for f in self.findings]
+        if len(self.get_aggregated_solutions()) > 0:
+            aggregated_solutions = [f.to_dict() for f in self.get_aggregated_solutions()]
+            return {"findings": findings, "aggregated_solutions": aggregated_solutions}
+        return {"findings": findings}

     def __str__(self):
-        return "\n\n".join([str(f) for f in self.findings])
+        findings_str = "\n".join([str(f) for f in self.findings])
+        if len(self.get_aggregated_solutions()) > 0:
+            aggregated_solutions_str = "\n".join([str(f) for f in self.get_aggregated_solutions()])
+            return findings_str + "\n\n" + aggregated_solutions_str
+        return findings_str

     def to_html(self, table=False):
-        return "".join([f.to_html(table) for f in self.findings])
+        my_str = "<br/>".join([f.to_html(table) for f in self.findings])
+        if len(self.get_aggregated_solutions()) > 0:
+            my_str += "<br/><br/>" + "<br/>".join([f.to_html() for f in self.get_aggregated_solutions()])
+        return my_str

     def export_to_json(self, filename="VulnerabilityReport.json"):
         """

src/utils/token_utils.py (+5 −3)

@@ -1,4 +1,6 @@
 import re
+
+from ai.LLM.BaseLLMService import BaseLLMService
 from config import config


@@ -50,7 +52,7 @@ def estimate_tokens(text):
     return total_tokens


-def fits_in_context(text):
+def fits_in_context(text, llm_service: BaseLLMService = None):
     """
     Check if the given text fits within the maximum context size.

@@ -64,7 +66,7 @@ def fits_in_context(text):
         bool: True if the text fits within the context, False otherwise.
     """
     estimated_tokens = estimate_tokens(text)
-    return estimated_tokens <= config.max_context_length
+    return bool(llm_service) and estimated_tokens <= llm_service.get_context_size()


 # Example usage
@@ -76,4 +78,4 @@ def fits_in_context(text):
 if fits_in_context(sample_text):
     print("The text fits within the maximum context.")
 else:
-    print("The text exceeds the maximum context size.")
\ No newline at end of file
+    print("The text exceeds the maximum context size.")
