Skip to content

Commit 2c1e7c2

Browse files
Merge pull request #51 from DigitalProductInnovationAndDevelopment/feat/integrate-aggregated-solution
add aggregated solution
2 parents edb4f79 + f5dd20e commit 2c1e7c2

25 files changed

+797
-271
lines changed

.github/workflows/docs.yml

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
name: API Docs
2+
3+
on:
4+
push:
5+
branches:
6+
- main
7+
pull_request:
8+
branches:
9+
- main
10+
jobs:
11+
generate-docs:
12+
runs-on: ubuntu-latest
13+
steps:
14+
- name: Checkout repository
15+
uses: actions/checkout@v3
16+
17+
- name: Set up Python 3.11
18+
uses: actions/setup-python@v2
19+
with:
20+
python-version: 3.11
21+
22+
- name: Install dependencies
23+
run: |
24+
pip install pipenv
25+
pipenv install
26+
27+
- name: Generate docs
28+
run: |
29+
pipenv run python src/extract-docs.py
30+
31+
- name: Save docs as artifact
32+
uses: actions/upload-artifact@v2
33+
with:
34+
name: docs
35+
path: .docs
36+
if-no-files-found: error

README.md

+6-5
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ After starting the application, you can access the API at `http://localhost:8000
5454
### Docker
5555

5656
building the images
57+
5758
```bash
5859
docker compose build
5960
```
@@ -64,23 +65,23 @@ To run the code within Docker, use the following command in the root directory o
6465
docker compose up
6566
```
6667

67-
68-
6968
Add the `-d` flag to run the containers in the background: `docker compose up -d`.
7069

71-
72-
73-
7470
### Available Routes
7571

7672
Currently, these routes are generated by fastapi.
7773

7874
```
75+
HEAD, GET /openapi.json
76+
HEAD, GET /docs
77+
HEAD, GET /docs/oauth2-redirect
78+
HEAD, GET /redoc
7979
GET /api/v1/tasks/
8080
DELETE /api/v1/tasks/{task_id}
8181
DELETE /api/v1/tasks/
8282
GET /api/v1/tasks/{task_id}/status
8383
POST /api/v1/recommendations/
84+
POST /api/v1/recommendations/aggregated
8485
POST /api/v1/upload/
8586
GET /
8687
```

src/ai/Grouping/FindingGrouper.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99

1010

1111
class FindingGrouper:
12-
def __init__(self, vulnerability_report: VulnerabilityReport, llm_service: BaseLLMService):
12+
def __init__(
13+
self, vulnerability_report: VulnerabilityReport, llm_service: BaseLLMService
14+
):
1315
self.vulnerability_report = vulnerability_report
1416
self.llm_service = llm_service
1517
self.batcher = FindingBatcher(llm_service)
@@ -20,5 +22,7 @@ def generate_aggregated_solutions(self):
2022
for batch in tqdm(self.batches, desc="Generating Aggregated Solutions"):
2123
result_list = self.llm_service.generate_aggregated_solution(batch)
2224
for result in result_list:
23-
self.aggregated_solutions.append(AggregatedSolution(result[1], result[0], result[2])) # Solution, Findings, Metadata
25+
self.aggregated_solutions.append(
26+
AggregatedSolution().from_result(result[1], result[0], result[2])
27+
) # Solution, Findings, Metadata
2428
self.vulnerability_report.set_aggregated_solutions(self.aggregated_solutions)

src/app.py

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import ai.LLM.Strategies.OLLAMAService
88
from config import config
99

10+
1011
import routes
1112
import routes.v1.recommendations
1213
import routes.v1.task

src/config.py

+15-41
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,33 @@
11
from pydantic import (
2-
BaseModel,
32
Field,
4-
RedisDsn,
53
ValidationInfo,
6-
model_validator,
7-
root_validator,
84
)
95

106
from pydantic_settings import BaseSettings, SettingsConfigDict
11-
from pydantic import Field, RedisDsn, field_validator
7+
from pydantic import Field, field_validator
128

139
from typing import Optional
1410

1511

16-
class SubModel(BaseModel):
17-
foo: str = "bar"
18-
apple: int = 1
19-
20-
2112
class Config(BaseSettings):
2213
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
2314

24-
ollama_url: str = Field(
25-
json_schema_extra="OLLAMA_URL", default="http://localhost:11434"
26-
)
27-
ollama_model: str = Field(json_schema_extra="OLLAMA_MODEL", default="phi3:mini")
15+
ollama_url: str = Field(default="http://localhost:11434")
16+
ollama_model: str = Field(default="phi3:mini")
2817

29-
ai_strategy: Optional[str] = Field(
30-
json_schema_extra="AI_STRATEGY", default="OLLAMA"
31-
)
32-
anthropic_api_key: Optional[str] = Field(
33-
json_schema_extra="ANTHROPIC_API_KEY", default=None
34-
)
35-
openai_api_key: Optional[str] = Field(
36-
json_schema_extra="OPENAI_API_KEY", default=None
37-
)
18+
ai_strategy: Optional[str] = Field(default="OLLAMA")
19+
anthropic_api_key: Optional[str] = Field(default=None)
20+
openai_api_key: Optional[str] = Field(default=None)
3821

39-
postgres_server: str = Field(
40-
json_schema_extra="POSTGRES_SERVER", default="localhost"
41-
)
42-
postgres_port: int = Field(json_schema_extra="POSTGRES_PORT", default=5432)
43-
postgres_db: str = Field(json_schema_extra="POSTGRES_DB", default="app")
44-
postgres_user: str = Field(json_schema_extra="POSTGRES_USER", default="postgres")
45-
postgres_password: str = Field(
46-
json_schema_extra="POSTGRES_PASSWORD", default="postgres"
47-
)
48-
queue_processing_limit: int = Field(
49-
json_schema_extra="QUEUE_PROCESSING_LIMIT", default=10
50-
)
51-
redis_endpoint: str = Field(
52-
json_schema_extra="REDIS_ENDPOINT", default="redis://localhost:6379/0"
53-
)
54-
environment: str = Field(env="ENVIRONMENT", default="development")
55-
db_debug: bool = Field(env="DB_DEBUG", default=False)
22+
postgres_server: str = Field(default="localhost")
23+
postgres_port: int = Field(default=5432)
24+
postgres_db: str = Field(default="app")
25+
postgres_user: str = Field(default="postgres")
26+
postgres_password: str = Field(default="postgres")
27+
queue_processing_limit: int = Field(default=10)
28+
redis_endpoint: str = Field(default="redis://localhost:6379/0")
29+
environment: str = Field(default="development")
30+
db_debug: bool = Field(default=False)
5631

5732
@field_validator(
5833
"ai_strategy",
@@ -66,7 +41,6 @@ def check_ai_strategy(cls, ai_strategy, values):
6641

6742
@field_validator("openai_api_key")
6843
def check_api_key(cls, api_key, info: ValidationInfo):
69-
print(info)
7044
if info.data["ai_strategy"] == "OPENAI" and not api_key:
7145
raise ValueError("OPENAI_API_KEY is required when ai_strategy is OPENAI")
7246
return api_key

src/data/AggregatedSolution.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
11
from typing import List
22

33
from data.Finding import Finding
4-
from db.base import BaseModel
4+
from pydantic import BaseModel
55

66

7-
class AggregatedSolution:
7+
class AggregatedSolution(BaseModel):
88
findings: List[Finding] = None
99
solution: str = ""
1010
metadata: dict = {}
1111

12-
def __init__(self, findings: List[Finding], solution: str, metadata=None):
12+
def from_result(self, findings: List[Finding], solution: str, metadata=None):
1313
self.findings = findings
1414
self.solution = solution
1515
self.metadata = metadata
16+
return self
1617

1718
def __str__(self):
1819
return self.solution
@@ -21,8 +22,8 @@ def to_dict(self):
2122
return {
2223
"findings": [finding.to_dict() for finding in self.findings],
2324
"solution": self.solution,
24-
"metadata": self.metadata
25+
"metadata": self.metadata,
2526
}
2627

2728
def to_html(self):
28-
return f"<p>{self.solution}</p>"
29+
return f"<p>{self.solution}</p>"

src/data/Finding.py

+44-18
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,19 @@
11
from typing import List, Set, Optional, Any, get_args
22
from enum import Enum, auto
3+
import uuid
4+
35
from pydantic import BaseModel, Field, PrivateAttr
46
from data.Solution import Solution
5-
from data.Categories import Category, TechnologyStack, SecurityAspect, SeverityLevel, RemediationType, \
6-
AffectedComponent, Compliance, Environment
7+
from data.Categories import (
8+
Category,
9+
TechnologyStack,
10+
SecurityAspect,
11+
SeverityLevel,
12+
RemediationType,
13+
AffectedComponent,
14+
Compliance,
15+
Environment,
16+
)
717
import json
818

919
import logging
@@ -12,6 +22,7 @@
1222

1323

1424
class Finding(BaseModel):
25+
id: str = Field(default_factory=lambda: f"{str(uuid.uuid4())}")
1526
title: List[str] = Field(default_factory=list)
1627
source: Set[str] = Field(default_factory=set)
1728
descriptions: List[str] = Field(default_factory=list)
@@ -21,7 +32,7 @@ class Finding(BaseModel):
2132
severity: Optional[int] = None
2233
priority: Optional[int] = None
2334
location_list: List[str] = Field(default_factory=list)
24-
category: Category = None
35+
category: Optional[Category] = None
2536
unsupervised_cluster: Optional[int] = None
2637
solution: Optional["Solution"] = None
2738
_llm_service: Optional[Any] = PrivateAttr(default=None)
@@ -35,7 +46,9 @@ def combine_descriptions(self) -> "Finding":
3546
logger.error("LLM Service not set, cannot combine descriptions.")
3647
return self
3748

38-
self.description = self.llm_service.combine_descriptions(self.descriptions, self.cve_ids, self.cwe_ids)
49+
self.description = self.llm_service.combine_descriptions(
50+
self.descriptions, self.cve_ids, self.cwe_ids
51+
)
3952
return self
4053

4154
def add_category(self) -> "Finding":
@@ -47,34 +60,45 @@ def add_category(self) -> "Finding":
4760

4861
# Classify technology stack
4962
technology_stack_options = list(TechnologyStack)
50-
self.category.technology_stack = self.llm_service.classify_kind(self, "technology_stack",
51-
technology_stack_options)
63+
self.category.technology_stack = self.llm_service.classify_kind(
64+
self, "technology_stack", technology_stack_options
65+
)
5266

5367
# Classify security aspect
5468
security_aspect_options = list(SecurityAspect)
55-
self.category.security_aspect = self.llm_service.classify_kind(self, "security_aspect", security_aspect_options)
69+
self.category.security_aspect = self.llm_service.classify_kind(
70+
self, "security_aspect", security_aspect_options
71+
)
5672

5773
# Classify severity level
5874
severity_level_options = list(SeverityLevel)
59-
self.category.severity_level = self.llm_service.classify_kind(self, "severity_level", severity_level_options)
75+
self.category.severity_level = self.llm_service.classify_kind(
76+
self, "severity_level", severity_level_options
77+
)
6078

6179
# Classify remediation type
6280
remediation_type_options = list(RemediationType)
63-
self.category.remediation_type = self.llm_service.classify_kind(self, "remediation_type",
64-
remediation_type_options)
81+
self.category.remediation_type = self.llm_service.classify_kind(
82+
self, "remediation_type", remediation_type_options
83+
)
6584

6685
# Classify affected component
6786
affected_component_options = list(AffectedComponent)
68-
self.category.affected_component = self.llm_service.classify_kind(self, "affected_component",
69-
affected_component_options)
87+
self.category.affected_component = self.llm_service.classify_kind(
88+
self, "affected_component", affected_component_options
89+
)
7090

7191
# Classify compliance
7292
compliance_options = list(Compliance)
73-
self.category.compliance = self.llm_service.classify_kind(self, "compliance", compliance_options)
93+
self.category.compliance = self.llm_service.classify_kind(
94+
self, "compliance", compliance_options
95+
)
7496

7597
# Classify environment
7698
environment_options = list(Environment)
77-
self.category.environment = self.llm_service.classify_kind(self, "environment", environment_options)
99+
self.category.environment = self.llm_service.classify_kind(
100+
self, "environment", environment_options
101+
)
78102

79103
return self
80104

@@ -213,17 +237,19 @@ def to_html(self, table=False):
213237
result += "<tr><th>Name</th><th>Value</th></tr>"
214238
result += f"<tr><td>Title</td><td>{', '.join(self.title)}</td></tr>"
215239
result += f"<tr><td>Source</td><td>{', '.join(self.source)}</td></tr>"
216-
result += (
217-
f"<tr><td>Description</td><td>{self.description}</td></tr>"
218-
)
240+
result += f"<tr><td>Description</td><td>{self.description}</td></tr>"
219241
if len(self.location_list) > 0:
220242
result += f"<tr><td>Location List</td><td>{' & '.join(map(str, self.location_list))}</td></tr>"
221243
result += f"<tr><td>CWE IDs</td><td>{', '.join(self.cwe_ids)}</td></tr>"
222244
result += f"<tr><td>CVE IDs</td><td>{', '.join(self.cve_ids)}</td></tr>"
223245
result += f"<tr><td>Severity</td><td>{self.severity}</td></tr>"
224246
result += f"<tr><td>Priority</td><td>{self.priority}</td></tr>"
225247
if self.category is not None:
226-
result += '<tr><td>Category</td><td>' + str(self.category).replace("\n", "<br />") + '</td></tr>'
248+
result += (
249+
"<tr><td>Category</td><td>"
250+
+ str(self.category).replace("\n", "<br />")
251+
+ "</td></tr>"
252+
)
227253
if self.unsupervised_cluster is not None:
228254
result += f"<tr><td>Unsupervised Cluster</td><td>{self.unsupervised_cluster}</td></tr>"
229255
result += "</table>"

0 commit comments

Comments
 (0)