From b8f600f1db3e0455ef8d821587cadb846abfd179 Mon Sep 17 00:00:00 2001 From: Laura Schauer Date: Wed, 17 Jul 2024 10:44:27 +0200 Subject: [PATCH] Adds commit classification rule (#397) This PR adds a new rule using the `LLMService`. It sends the diff of a commit to the LLM and asks if this commit is security relevant or not. Relevance of the rule is set to 32 for now, but this value can be adjusted after evaluation. Thanks to @tommasoaiello --- prospector/llm/llm_service.py | 54 ++++++++++++++++++- prospector/llm/prompts/classify_commit.py | 16 ++++++ .../get_repository_url.py} | 0 prospector/rules/rules.py | 16 +++++- prospector/rules/rules_test.py | 24 ++++++--- 5 files changed, 102 insertions(+), 8 deletions(-) create mode 100644 prospector/llm/prompts/classify_commit.py rename prospector/llm/{prompts.py => prompts/get_repository_url.py} (100%) diff --git a/prospector/llm/llm_service.py b/prospector/llm/llm_service.py index cbc4e69e4..685bf79a0 100644 --- a/prospector/llm/llm_service.py +++ b/prospector/llm/llm_service.py @@ -3,9 +3,11 @@ import validators from langchain_core.language_models.llms import LLM from langchain_core.output_parsers import StrOutputParser +from requests import HTTPError from llm.instantiation import create_model_instance -from llm.prompts import prompt_best_guess +from llm.prompts.classify_commit import zero_shot as cc_zero_shot +from llm.prompts.get_repository_url import prompt_best_guess from log.logger import logger from util.config_parser import LLMServiceConfig from util.singleton import Singleton @@ -74,3 +76,53 @@ def get_repository_url(self, advisory_description, advisory_references) -> str: raise RuntimeError(f"Prompt-model chain could not be invoked: {e}") return url + + def classify_commit( + self, diff: str, repository_name: str, commit_message: str + ) -> bool: + """Ask an LLM whether a commit is security relevant or not. The response will be either True or False. 
+ + Args: + diff (str), repository_name (str), commit_message (str): The commit's diff, the name of its repository and its commit message, all of which are inserted into the prompt + + Returns: + True if the commit is deemed security relevant, False if not (also False when the request fails with an HTTP 400 error). + + Raises: + RuntimeError if there is any other error in the model invocation or the response was not valid. + """ + try: + chain = cc_zero_shot | self.model | StrOutputParser() + + is_relevant = chain.invoke( + { + "diff": diff, + "repository_name": repository_name, + "commit_message": commit_message, + } + ) + logger.info(f"LLM returned is_relevant={is_relevant}") + + except HTTPError as e: + # if the diff is too big, a 400 error is returned -> silently ignore by returning False for this commit + status_code = e.response.status_code + if status_code == 400: + return False + raise RuntimeError(f"Prompt-model chain could not be invoked: {e}") + except Exception as e: + raise RuntimeError(f"Prompt-model chain could not be invoked: {e}") + + if is_relevant in [ + "True", + "ANSWER:True", + "```ANSWER:True```", + ]: + return True + elif is_relevant in [ + "False", + "ANSWER:False", + "```ANSWER:False```", + ]: + return False + else: + raise RuntimeError(f"The model returned an invalid response: {is_relevant}") diff --git a/prospector/llm/prompts/classify_commit.py b/prospector/llm/prompts/classify_commit.py new file mode 100644 index 000000000..80a99afe9 --- /dev/null +++ b/prospector/llm/prompts/classify_commit.py @@ -0,0 +1,16 @@ +from langchain.prompts import PromptTemplate + +zero_shot = PromptTemplate.from_template( + """Is the following commit security relevant or not? +Please provide the output as a boolean value, either True or False. +If it is security relevant just answer True otherwise answer False. Do not return anything else. + +To provide you with some context, the name of the repository is: {repository_name}, and the +commit message is: {commit_message}. 
+ +Finally, here is the diff of the commit: +{diff}\n + + +Your answer:\n""" +) diff --git a/prospector/llm/prompts.py b/prospector/llm/prompts/get_repository_url.py similarity index 100% rename from prospector/llm/prompts.py rename to prospector/llm/prompts/get_repository_url.py diff --git a/prospector/rules/rules.py b/prospector/rules/rules.py index 80496c812..2ba5a16e9 100644 --- a/prospector/rules/rules.py +++ b/prospector/rules/rules.py @@ -413,6 +413,18 @@ def apply(self, candidate: Commit, advisory_record: AdvisoryRecord): return False +class CommitIsSecurityRelevant(Rule): + """Matches commits that are deemed security relevant by the commit classification service.""" + + def apply( + self, + candidate: Commit, + ) -> bool: + return LLMService().classify_commit( + candidate.diff, candidate.repository, candidate.message + ) + + RULES_PHASE_1: List[Rule] = [ VulnIdInMessage("VULN_ID_IN_MESSAGE", 64), # CommitMentionedInAdv("COMMIT_IN_ADVISORY", 64), @@ -433,4 +445,6 @@ def apply(self, candidate: Commit, advisory_record: AdvisoryRecord): CommitHasTwins("COMMIT_HAS_TWINS", 2), ] -RULES_PHASE_2: List[Rule] = [] +RULES_PHASE_2: List[Rule] = [ + CommitIsSecurityRelevant("COMMIT_IS_SECURITY_RELEVANT", 32) +] diff --git a/prospector/rules/rules_test.py b/prospector/rules/rules_test.py index 230c351e0..93c246ef4 100644 --- a/prospector/rules/rules_test.py +++ b/prospector/rules/rules_test.py @@ -89,7 +89,9 @@ def candidates(): changed_files={ "core/src/main/java/org/apache/cxf/workqueue/AutomaticWorkQueueImpl.java" }, - minhash=get_encoded_minhash(get_msg("Insecure deserialization", 50)), + minhash=get_encoded_minhash( + get_msg("Insecure deserialization", 50) + ), ), # TODO: Not matched by existing tests: GHSecurityAdvInMessage, ReferencesBug, ChangesRelevantCode, TwinMentionedInAdv, VulnIdInLinkedIssue, SecurityKeywordInLinkedGhIssue, SecurityKeywordInLinkedBug, CrossReferencedBug, CrossReferencedGh, CommitHasTwins, ChangesRelevantFiles, CommitMentionedInAdv, 
RelevantWordsInMessage ] @@ -109,7 +111,9 @@ def advisory_record(): ) -def test_apply_phase_1_rules(candidates: List[Commit], advisory_record: AdvisoryRecord): +def test_apply_phase_1_rules( + candidates: List[Commit], advisory_record: AdvisoryRecord +): annotated_candidates = apply_rules( candidates, advisory_record, enabled_rules=enabled_rules_from_config ) @@ -117,7 +121,9 @@ def test_apply_phase_1_rules(candidates: List[Commit], advisory_record: Advisory # Repo 5: Should match: AdvKeywordsInFiles, SecurityKeywordsInMsg, CommitMentionedInReference assert len(annotated_candidates[0].matched_rules) == 3 - matched_rules_names = [item["id"] for item in annotated_candidates[0].matched_rules] + matched_rules_names = [ + item["id"] for item in annotated_candidates[0].matched_rules + ] assert "ADV_KEYWORDS_IN_FILES" in matched_rules_names assert "COMMIT_IN_REFERENCE" in matched_rules_names assert "SEC_KEYWORDS_IN_MESSAGE" in matched_rules_names @@ -125,21 +131,27 @@ def test_apply_phase_1_rules(candidates: List[Commit], advisory_record: Advisory # Repo 1: Should match: VulnIdInMessage, ReferencesGhIssue assert len(annotated_candidates[1].matched_rules) == 2 - matched_rules_names = [item["id"] for item in annotated_candidates[1].matched_rules] + matched_rules_names = [ + item["id"] for item in annotated_candidates[1].matched_rules + ] assert "VULN_ID_IN_MESSAGE" in matched_rules_names assert "GITHUB_ISSUE_IN_MESSAGE" in matched_rules_names # Repo 3: Should match: VulnIdInMessage, ReferencesGhIssue assert len(annotated_candidates[2].matched_rules) == 2 - matched_rules_names = [item["id"] for item in annotated_candidates[2].matched_rules] + matched_rules_names = [ + item["id"] for item in annotated_candidates[2].matched_rules + ] assert "VULN_ID_IN_MESSAGE" in matched_rules_names assert "GITHUB_ISSUE_IN_MESSAGE" in matched_rules_names # Repo 4: Should match: SecurityKeywordsInMsg assert len(annotated_candidates[3].matched_rules) == 1 - matched_rules_names = [item["id"] for 
item in annotated_candidates[3].matched_rules] + matched_rules_names = [ + item["id"] for item in annotated_candidates[3].matched_rules + ] assert "SEC_KEYWORDS_IN_MESSAGE" in matched_rules_names # Repo 2: Matches nothing