PATCHED_API_KEY integration (#12)
* patched-api-key integration 

* fix gitlab comments issues

* fix vulnerability limit

* Fix PRReview by skipping some extensions
CTY-git authored Apr 12, 2024
1 parent 5057603 commit 731cdad
Showing 7 changed files with 363 additions and 299 deletions.
11 changes: 8 additions & 3 deletions patchwork/app.py
@@ -1,5 +1,6 @@
import importlib
import json
import traceback
from pathlib import Path

import click
@@ -73,9 +74,13 @@ def cli(log: str, patchflow: str, opts: list[str], config: str | None, output: s
else:
# treat --key=value as a key-value pair
inputs[key] = value

patchflow_instance = patchflow_class(inputs)
patchflow_instance.run()
try:
patchflow_instance = patchflow_class(inputs)
patchflow_instance.run()
except Exception as e:
logger.debug(traceback.format_exc())
logger.error(f"Error running patchflow {patchflow}: {e}")
exit(1)

data_format_mapping = {
"yaml": yaml.dump,
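
Note on the change above: a failing patchflow now logs the full traceback at debug verbosity, surfaces a one-line error, and exits with status 1 instead of dumping a raw traceback. A minimal standalone sketch of the same pattern (stdlib logging in place of patchwork's logger; the helper name is hypothetical):

    import logging
    import sys
    import traceback

    logger = logging.getLogger(__name__)

    def run_patchflow(patchflow_class, inputs: dict, patchflow: str) -> None:
        try:
            instance = patchflow_class(inputs)
            instance.run()
        except Exception as e:
            # Full traceback only at debug verbosity; a concise error otherwise.
            logger.debug(traceback.format_exc())
            logger.error(f"Error running patchflow {patchflow}: {e}")
            sys.exit(1)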
23 changes: 16 additions & 7 deletions patchwork/common/client/scm.py
@@ -137,8 +137,13 @@ def set_pr_description(self, body: str) -> None:
self._mr.save()

def create_comment(
self, path: str, body: str, start_line: int | None = None, end_line: int | None = None
self, body: str, path: str | None = None, start_line: int | None = None, end_line: int | None = None
) -> str | None:
final_body = f"{_COMMENT_MARKER} \n{PullRequestProtocol._apply_pr_template(self, body)}"
if path is None:
note = self._mr.notes.create({"body": final_body})
return f"#note_{note.get_id()}"

while True:
try:
commit = self._mr.commits().next()
@@ -161,7 +166,6 @@ def create_comment(
head_commit = diff.head_commit_sha

try:
final_body = f"{_COMMENT_MARKER} \n{PullRequestProtocol._apply_pr_template(self, body)}"
discussion = self._mr.discussions.create(
{
"body": final_body,
@@ -187,14 +191,19 @@ def create_comment(
return None
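
With the now-optional path, a caller can post a plain top-level MR note instead of a positioned diff discussion. A hedged python-gitlab sketch of that branch (the helper name is hypothetical; mr is assumed to be an authenticated ProjectMergeRequest):

    def post_summary_note(mr, body: str) -> str:
        # With no file path, fall back to a plain merge-request note instead
        # of a positioned diff discussion.
        note = mr.notes.create({"body": body})
        return f"#note_{note.get_id()}"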

def reset_comments(self) -> None:
for discussion in self._mr.discussions.list():
for discussion in self._mr.discussions.list(iterator=True):
for note in discussion.attributes["notes"]:
if note["type"] == "DiffNote" and note["body"].startswith(_COMMENT_MARKER):
if note["body"].startswith(_COMMENT_MARKER):
discussion.notes.delete(note["id"])

def file_diffs(self) -> dict[str, str]:
files = self._mr.diffs.list()
return {file.attributes["new_path"]: file.attributes["diff"] for file in files}
diffs = self._mr.diffs.list()
latest_diff = max(diffs, key=lambda diff: diff.created_at, default=None)
if latest_diff is None:
return {}

files = self._mr.diffs.get(latest_diff.id).diffs
return {file["new_path"]: file["diff"] for file in files}
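
The rewritten file_diffs reflects that GitLab merge-request diffs are versioned: listing yields one entry per version, so the code picks the newest one and fetches its per-file patches. The same idea in isolation (again assuming a python-gitlab ProjectMergeRequest):

    def latest_file_diffs(mr) -> dict[str, str]:
        # MR diffs are versioned; take the newest by created_at, then fetch
        # it to read the per-file patch list.
        diffs = mr.diffs.list()
        latest = max(diffs, key=lambda d: d.created_at, default=None)
        if latest is None:
            return {}
        return {f["new_path"]: f["diff"] for f in mr.diffs.get(latest.id).diffs}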


class GithubPullRequest(PullRequestProtocol):
@@ -336,7 +345,7 @@ def get_pr_by_url(self, url: str) -> PullRequestProtocol | None:
logger.error(f"Invalid PR URL: {url}")
return None

slug = "/".join(url_parts[-4:-2])
slug = "/".join(url_parts[-5:-3])

return self.find_pr_by_id(slug, int(pr_id))

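The slug fix lines up with GitLab's /-/ URL separator: splitting https://gitlab.com/group/project/-/merge_requests/123 on "/" puts the project slug at [-5:-3], whereas the old [-4:-2] slice captured "project/-". A quick illustration (hypothetical helper; this hunk appears to sit in the GitLab client):

    def parse_mr_url(url: str) -> tuple[str, int]:
        # .../group/project/-/merge_requests/123 -> ("group/project", 123)
        url_parts = url.rstrip("/").split("/")
        slug = "/".join(url_parts[-5:-3])
        return slug, int(url_parts[-1])

    assert parse_mr_url("https://gitlab.com/group/project/-/merge_requests/12") == ("group/project", 12)
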
29 changes: 27 additions & 2 deletions patchwork/steps/CallOpenAI/CallOpenAI.py
@@ -10,9 +10,12 @@
from patchwork.logger import logger
from patchwork.step import Step

_TOKEN_URL = "https://app.patched.codes/signin"
_DEFAULT_PATCH_URL = "https://patchwork.patched.codes/v1"


class CallOpenAI(Step):
required_keys = {"openai_api_key", "prompt_file"}
required_keys = {"prompt_file"}

def __init__(self, inputs: dict):
logger.info(f"Run started {self.__class__.__name__}")
@@ -27,7 +30,29 @@ def __init__(self, inputs: dict):
self.model_args = {key[len("model_") :]: value for key, value in inputs.items() if key.startswith("model_")}
self.client_args = {key[len("client_") :]: value for key, value in inputs.items() if key.startswith("client_")}

self.openai_api_key = inputs["openai_api_key"]
self.openai_api_key = inputs.get("openai_api_key") or os.environ.get("OPENAI_API_KEY")

patched_key = inputs.get("patched_api_key")
if patched_key is not None:
self.openai_api_key = patched_key
self.client_args["base_url"] = _DEFAULT_PATCH_URL

if self.openai_api_key is None:
raise ValueError(
f"Model API key not found.\n"
f'Please login at: "{_TOKEN_URL}",\n'
"Please go to the Integration's tab and generate an API key.\n"
"Please copy the access token that is generated, "
"and add `--patched_api_key=<token>` to the command line.\n"
"\n"
"If you are using a OpenAI API Key, please set `--openai_api_key=<token>`.\n"
)

if not self.openai_api_key:
raise ValueError('Missing required data: "openai_api_key"')

self.prompt_file = Path(inputs["prompt_file"])
if not self.prompt_file.is_file():
raise ValueError(f'Unable to find Prompt file: "{self.prompt_file}"')
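
The constructor now resolves credentials in this order: an explicit openai_api_key, then the OPENAI_API_KEY environment variable, with patched_api_key overriding both and pointing the client at the Patched-hosted endpoint. A condensed sketch of that precedence (the helper name is hypothetical):

    import os

    _DEFAULT_PATCH_URL = "https://patchwork.patched.codes/v1"

    def resolve_api_key(inputs: dict) -> tuple[str | None, dict]:
        client_args = {}
        api_key = inputs.get("openai_api_key") or os.environ.get("OPENAI_API_KEY")
        patched_key = inputs.get("patched_api_key")
        if patched_key is not None:
            # The Patched key wins and routes requests to the hosted endpoint.
            api_key = patched_key
            client_args["base_url"] = _DEFAULT_PATCH_URL
        return api_key, client_args

    key, extra_args = resolve_api_key({"patched_api_key": "pk-example"})
    assert extra_args["base_url"] == _DEFAULT_PATCH_URL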
142 changes: 71 additions & 71 deletions patchwork/steps/ExtractCode/ExtractCode.py
@@ -93,6 +93,76 @@ def resolve_artifact_location(
return None


def transform_sarif_results(
sarif_data: dict, base_path: Path, context_length: int, vulnerability_limit: int
) -> dict[tuple[str, int, int, int], list[str]]:
# Process each result in SARIF data
grouped_messages = defaultdict(list)
vulnerability_count = 0
for run_idx, run in enumerate(sarif_data.get("runs", [])):
artifact_locations = [
parse_sarif_location(base_path, artifact["location"]["uri"]) for artifact in run.get("artifacts", [])
]

for result_idx, result in enumerate(run.get("results", [])):
for location_idx, location in enumerate(result.get("locations", [])):
physical_location = location.get("physicalLocation", {})

artifact_location = physical_location.get("artifactLocation", {})
uri = resolve_artifact_location(base_path, artifact_location, artifact_locations)
if uri is None:
logger.warn(
f'Unable to find file for ".runs[{run_idx}].results[{result_idx}].locations[{location_idx}]"'
)
continue

region = physical_location.get("region", {})
start_line = region.get("startLine", 1)
end_line = region.get("endLine", start_line)
start_line = start_line - 1

# Generate file path assuming code is in the current working directory
file_path = str(uri.relative_to(base_path))

# Extract lines from the code file
logger.info(f"Extracting context for {file_path} at {start_line}:{end_line}")
try:
with open_with_chardet(file_path, "r") as file:
src = file.read()

source_lines = src.splitlines(keepends=True)
context_start, context_end = get_source_code_context(
file_path, source_lines, start_line, end_line, context_length
)

source_code_context = None
if context_start is not None and context_end is not None:
source_code_context = "".join(source_lines[context_start:context_end])

except FileNotFoundError:
context_start = None
context_end = None
source_code_context = None
logger.info(f"File not found in the current working directory: {file_path}")

if source_code_context is None:
logger.info(f"No context found for {file_path} at {start_line}:{end_line}")
continue

start = context_start if context_start is not None else start_line
end = context_end if context_end is not None else end_line

grouped_messages[(uri, start, end, source_code_context)].append(
result.get("message", {}).get("text", "")
)

vulnerability_count = vulnerability_count + 1
if 0 < vulnerability_limit <= vulnerability_count:
return grouped_messages

return grouped_messages


class ExtractCode(Step):
required_keys = {"sarif_file_path"}

@@ -112,7 +182,6 @@ def __init__(self, inputs: dict):
self.vulnerability_limit = inputs.get("vulnerability_limit", 10)

# Prepare for data extraction
self.extracted_data = []
self.extracted_code_contexts = []

def run(self) -> dict:
@@ -122,77 +191,8 @@ def run(self) -> dict:

vulnerability_count = 0
base_path = Path.cwd()
# Process each result in SARIF data
grouped_messages = defaultdict(list)
for run_idx, run in enumerate(sarif_data.get("runs", [])):
artifact_locations = [
parse_sarif_location(base_path, artifact["location"]["uri"]) for artifact in run.get("artifacts", [])
]

for result_idx, result in enumerate(run.get("results", [])):
for location_idx, location in enumerate(result.get("locations", [])):
physical_location = location.get("physicalLocation", {})

artifact_location = physical_location.get("artifactLocation", {})
uri = resolve_artifact_location(base_path, artifact_location, artifact_locations)
if uri is None:
logger.warn(
f'Unable to find file for ".runs[{run_idx}].results[{result_idx}].locations[{location_idx}]"'
)
continue

region = physical_location.get("region", {})
start_line = region.get("startLine", 1)
end_line = region.get("endLine", start_line)
start_line = start_line - 1

# Generate file path assuming code is in the current working directory
file_path = str(uri.relative_to(base_path))

# Extract lines from the code file
logger.info(f"Extracting context for {file_path} at {start_line}:{end_line}")
try:
with open_with_chardet(file_path, "r") as file:
src = file.read()

source_lines = src.splitlines(keepends=True)
context_start, context_end = get_source_code_context(
file_path, source_lines, start_line, end_line, self.context_length
)

source_code_context = None
if context_start is not None and context_end is not None:
source_code_context = "".join(source_lines[context_start:context_end])

except FileNotFoundError:
context_start = None
context_end = None
source_code_context = None
logger.info(f"File not found in the current working directory: {file_path}")

if source_code_context is None:
logger.info(f"No context found for {file_path} at {start_line}:{end_line}")
continue

start = context_start if context_start is not None else start_line
end = context_end if context_end is not None else end_line
self.extracted_data.append(
{
"affectedCode": source_code_context,
"startLine": start,
"endLine": end,
"uri": file_path,
"messageText": result.get("message", {}).get("text", ""),
}
)

grouped_messages[(uri, start, end, source_code_context)].append(
result.get("message", {}).get("text", "")
)

vulnerability_count = vulnerability_count + 1
if 0 < self.vulnerability_limit <= vulnerability_count:
break
grouped_messages = transform_sarif_results(sarif_data, base_path, self.context_length, self.vulnerability_limit)

self.extracted_code_contexts = [
{
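
run() now delegates to transform_sarif_results, which groups finding messages by (uri, start, end, context) and returns early once the vulnerability limit is reached. A small usage sketch (assumes a report.sarif in the working directory; context_length=1000 is an arbitrary example value, not the step's default):

    import json
    from pathlib import Path

    sarif_data = json.loads(Path("report.sarif").read_text())
    grouped = transform_sarif_results(sarif_data, Path.cwd(), context_length=1000, vulnerability_limit=10)
    for (uri, start, end, _context), messages in grouped.items():
        print(f"{uri}:{start}-{end}: {len(messages)} finding(s)")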
4 changes: 3 additions & 1 deletion patchwork/steps/ModifyCode/ModifyCode.py
@@ -74,7 +74,9 @@ def run(self) -> dict:
code_snippets = load_json_file(self.code_snippets_path)

modified_code_files = []
sorted_list = sorted(zip(code_snippets, self.extracted_responses), key=lambda x: x[0]["startLine"], reverse=True)
sorted_list = sorted(
zip(code_snippets, self.extracted_responses), key=lambda x: x[0]["startLine"], reverse=True
)
for code_snippet, extracted_response in sorted_list:
uri = code_snippet["uri"]
start_line = code_snippet["startLine"]
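
Worth spelling out why the sort is reversed: patches are applied bottom-up, so rewriting a later region of a file never shifts the start lines of regions still waiting to be patched. A toy illustration:

    code_snippets = [
        {"uri": "a.py", "startLine": 3, "endLine": 5},
        {"uri": "a.py", "startLine": 10, "endLine": 12},
    ]
    extracted_responses = ["fixed block A", "fixed block B"]

    ordered = sorted(
        zip(code_snippets, extracted_responses),
        key=lambda x: x[0]["startLine"],
        reverse=True,
    )
    # The startLine=10 snippet is patched first; the startLine=3 offsets stay valid.
    assert ordered[0][0]["startLine"] == 10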
22 changes: 22 additions & 0 deletions patchwork/steps/ReadPRDiffs/ReadPRDiffs.py
@@ -5,6 +5,26 @@
from patchwork.logger import logger
from patchwork.step import Step

_IGNORED_EXTENSIONS = [
".png",
".jpg",
".jpeg",
".gif",
".svg",
".pdf",
".docx",
".xlsx",
".pptx",
".zip",
".tar",
".gz",
".lock",
]


def filter_by_extension(file, extensions):
return any(file.endswith(ext) for ext in extensions)


class ReadPRDiffs(Step):
required_keys = {"pr_url"}
@@ -30,6 +50,8 @@ def __init__(self, inputs: dict):
def run(self) -> dict:
prompt_values = []
for path, diffs in self.pr.file_diffs().items():
if filter_by_extension(path, _IGNORED_EXTENSIONS):
continue
prompt_values.append(dict(path=path, diff=diffs))

prompt_value_file = tempfile.mktemp(".json")
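
Net effect of the filter: binary and generated files never reach the prompt. A self-contained sketch of the same check:

    _IGNORED_EXTENSIONS = [".png", ".zip", ".lock"]

    def filter_by_extension(file: str, extensions: list[str]) -> bool:
        return any(file.endswith(ext) for ext in extensions)

    file_diffs = {"patchwork/app.py": "...", "poetry.lock": "...", "docs/logo.png": "..."}
    kept = {path: diff for path, diff in file_diffs.items()
            if not filter_by_extension(path, _IGNORED_EXTENSIONS)}
    assert list(kept) == ["patchwork/app.py"]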