From 25abf959390591991e20f0cb64c4685d458ed987 Mon Sep 17 00:00:00 2001 From: Dylan Pulver <35541198+dylanpulver@users.noreply.github.com> Date: Mon, 9 Dec 2024 17:48:23 -0500 Subject: [PATCH] feature/cve-data-filter-flag (#643) --- safety/scan/command.py | 121 ++++++++++++++++++++++++++++++++++++- tests/scan/test_command.py | 4 +- 2 files changed, 120 insertions(+), 5 deletions(-) diff --git a/safety/scan/command.py b/safety/scan/command.py index 2c74bafd..86392069 100644 --- a/safety/scan/command.py +++ b/safety/scan/command.py @@ -2,6 +2,7 @@ import logging from pathlib import Path +import json import sys from typing import Any, Dict, List, Optional, Set, Tuple from typing_extensions import Annotated @@ -49,7 +50,9 @@ class ScannableEcosystems(Enum): def process_report( obj: Any, console: Console, report: ReportModel, output: str, - save_as: Optional[Tuple[str, Path]], **kwargs + save_as: Optional[Tuple[str, Path]], detailed_output: bool = False, + filter_keys: Optional[List[str]] = None, + **kwargs ) -> Optional[str]: """ Processes and outputs the report based on the given parameters. @@ -60,6 +63,8 @@ def process_report( report (ReportModel): The report model. output (str): The output format. save_as (Optional[Tuple[str, Path]]): The save-as format and path. + detailed_output (bool): Whether detailed output is enabled. + filter_keys (Optional[List[str]]): Keys to filter from the JSON output. kwargs: Additional keyword arguments. 
def filter_json_keys(json_string: str, keys: List[str]) -> str:
    """
    Filters the given JSON string by the specified top-level keys.

    Used by process_report on the JSON output when filter keys are supplied.
    Requested keys that are absent from the report are skipped, and the
    result keeps the order in which the keys were requested.

    Args:
        json_string (str): The JSON string to filter.
        keys (List[str]): List of top-level keys to include in the output.

    Returns:
        str: A JSON string (indented by 4 spaces) containing only the
        specified keys. An empty JSON object is returned when the document
        root is not a JSON object.

    Raises:
        json.JSONDecodeError: If json_string is not valid JSON.
    """
    report_dict = json.loads(json_string)

    if not isinstance(report_dict, dict):
        # Only a JSON object has top-level keys. Without this guard a string
        # root would be matched by substring and then fail on indexing.
        return json.dumps({}, indent=4)

    filtered_data = {key: report_dict[key] for key in keys if key in report_dict}
    return json.dumps(filtered_data, indent=4)


def filter_valid_cves(vulnerabilities: List[Any]) -> List[Any]:
    """
    Filters and returns valid CVE details from a list of vulnerabilities.

    Args:
        vulnerabilities (List[Any]): An iterable of CVE details, which may
            include entries of unsupported types (for example None).

    Returns:
        List[Any]: The entries that are either strings or dictionaries, in
        their original order.
    """
    return [cve for cve in vulnerabilities if isinstance(cve, (str, dict))]


def sort_cve_data(cve_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Sorts CVE details by severity in descending order.

    Each entry is ranked by the position of its upper-cased 'severity' label
    among the member names of VulnerabilitySeverityLabels; members declared
    later sort first. Entries whose severity is missing, not a string, or not
    a known label get rank 0. Entries with equal rank keep their input order.

    Args:
        cve_data (List[Dict[str, Any]]): A list of CVE details dictionaries,
            each normally containing a 'severity' key.

    Returns:
        List[Dict[str, Any]]: A new list of CVE details sorted by severity.
    """
    if not cve_data:
        return []

    severity_order = {
        label.name: rank for rank, label in enumerate(VulnerabilitySeverityLabels)
    }

    def severity_rank(entry: Dict[str, Any]) -> int:
        severity = entry.get("severity")
        if not isinstance(severity, str):
            # A null or missing severity must not abort the whole report.
            return 0
        return severity_order.get(severity.upper(), 0)

    return sorted(cve_data, key=severity_rank, reverse=True)


def generate_cve_details(files: "List[FileModel]") -> List[Dict[str, Any]]:
    """
    Generate CVE details from the scanned files.

    One entry is produced for every vulnerability with a truthy CVE found on
    the affected specifications of each file. Every entry carries the keys
    'package', 'affected_version', 'safety_vulnerability_id', 'CVE',
    'more_info', 'advisory' and 'severity'.

    Args:
        files (List[FileModel]): List of scanned file models.

    Returns:
        List[Dict[str, Any]]: List of CVE details sorted by severity.
    """
    cve_data = []
    for file in files:
        for spec in file.results.get_affected_specifications():
            for vuln in spec.vulnerabilities:
                if not vuln.CVE:
                    continue

                severity = "Unknown"
                if vuln.severity and vuln.severity.cvssv3:
                    # `or` also covers a key that is present but null, which
                    # dict.get's default argument would let through as None.
                    severity = vuln.severity.cvssv3.get("base_severity") or "Unknown"

                cve_data.append({
                    "package": spec.name,
                    "affected_version": str(spec.specifier),
                    "safety_vulnerability_id": vuln.vulnerability_id,
                    "CVE": filter_valid_cves(vuln.CVE),
                    "more_info": vuln.more_info_url,
                    "advisory": vuln.advisory,
                    "severity": severity,
                })
    return sort_cve_data(cve_data)


def add_cve_details_to_report(report_to_output: str, files: "List[FileModel]") -> str:
    """
    Add CVE details to the JSON report output.

    Used by process_report on the JSON output when detailed output is enabled.
    The details are stored under the top-level 'cve_details' key, replacing
    any value already present under that key.

    Args:
        report_to_output (str): The current JSON string of the report.
        files (List[FileModel]): List of scanned files containing vulnerability data.

    Returns:
        str: The updated JSON string with CVE details added.

    Raises:
        json.JSONDecodeError: If report_to_output is not valid JSON.
    """
    cve_details = generate_cve_details(files)
    report_dict = json.loads(report_to_output)
    report_dict["cve_details"] = cve_details
    return json.dumps(report_dict)
@@ -250,7 +350,11 @@ def scan(ctx: typer.Context, typer.Option("--apply-fixes", help=SCAN_APPLY_FIXES, show_default=False) - ] = False + ] = False, + filter_keys: Annotated[ + Optional[List[str]], + typer.Option("--filter", help="Filter output by specific top-level JSON keys.") + ] = None, ): """ Scans a project (defaulted to the current directory) for supply-chain security and configuration issues @@ -465,7 +569,18 @@ def sort_vulns_by_score(vuln: Vulnerability) -> int: ignored_vulns_data=ignored_vulns_data ) - report_url = process_report(ctx.obj, console, report, **{**ctx.params}) + report_url = process_report( + obj=ctx.obj, + console=console, + report=report, + output=output, + save_as=save_as if save_as and all(save_as) else None, + detailed_output=detailed_output, + filter_keys=filter_keys, + **{k: v for k, v in ctx.params.items() if k not in {"detailed_output", "output", "save_as", "filter_keys"}} +) + + project_url = f"{SAFETY_PLATFORM_URL}{ctx.obj.project.url_path}" if apply_updates: diff --git a/tests/scan/test_command.py b/tests/scan/test_command.py index 48df61c8..e8d9a9ef 100644 --- a/tests/scan/test_command.py +++ b/tests/scan/test_command.py @@ -1,5 +1,6 @@ import os import unittest + from unittest.mock import patch, Mock from click.testing import CliRunner from safety.cli import cli @@ -13,7 +14,7 @@ def setUp(self): self.runner = CliRunner(mix_stderr=False) self.dirname = os.path.dirname(__file__) - def test_scan(self): + def test_scan(self): result = self.runner.invoke(cli, ["--stage", "cicd", "scan", "--target", self.dirname, "--output", "json"]) self.assertEqual(result.exit_code, 1) @@ -22,4 +23,3 @@ def test_scan(self): result = self.runner.invoke(cli, ["--stage", "cicd", "scan", "--target", self.dirname, "--output", "screen"]) self.assertEqual(result.exit_code, 1) -