Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PATCHED_API_KEY integration #12

Merged
merged 8 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions patchwork/app.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import importlib
import json
import traceback
from pathlib import Path

import click
Expand Down Expand Up @@ -73,9 +74,13 @@ def cli(log: str, patchflow: str, opts: list[str], config: str | None, output: s
else:
# treat --key=value as a key-value pair
inputs[key] = value

patchflow_instance = patchflow_class(inputs)
patchflow_instance.run()
try:
patchflow_instance = patchflow_class(inputs)
patchflow_instance.run()
except Exception as e:
logger.debug(traceback.format_exc())
logger.error(f"Error running patchflow {patchflow}: {e}")
exit(1)

data_format_mapping = {
"yaml": yaml.dump,
Expand Down
23 changes: 16 additions & 7 deletions patchwork/common/client/scm.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,13 @@ def set_pr_description(self, body: str) -> None:
self._mr.save()

def create_comment(
self, path: str, body: str, start_line: int | None = None, end_line: int | None = None
self, body: str, path: str | None = None, start_line: int | None = None, end_line: int | None = None
) -> str | None:
final_body = f"{_COMMENT_MARKER} \n{PullRequestProtocol._apply_pr_template(self, body)}"
if path is None:
note = self._mr.notes.create({"body": final_body})
return f"#note_{note.get_id()}"

while True:
try:
commit = self._mr.commits().next()
Expand All @@ -161,7 +166,6 @@ def create_comment(
head_commit = diff.head_commit_sha

try:
final_body = f"{_COMMENT_MARKER} \n{PullRequestProtocol._apply_pr_template(self, body)}"
discussion = self._mr.discussions.create(
{
"body": final_body,
Expand All @@ -187,14 +191,19 @@ def create_comment(
return None

def reset_comments(self) -> None:
for discussion in self._mr.discussions.list():
for discussion in self._mr.discussions.list(iterator=True):
for note in discussion.attributes["notes"]:
if note["type"] == "DiffNote" and note["body"].startswith(_COMMENT_MARKER):
if note["body"].startswith(_COMMENT_MARKER):
discussion.notes.delete(note["id"])

def file_diffs(self) -> dict[str, str]:
files = self._mr.diffs.list()
return {file.attributes["new_path"]: file.attributes["diff"] for file in files}
diffs = self._mr.diffs.list()
latest_diff = max(diffs, key=lambda diff: diff.created_at, default=None)
if latest_diff is None:
return {}

files = self._mr.diffs.get(latest_diff.id).diffs
return {file["new_path"]: file["diff"] for file in files}


class GithubPullRequest(PullRequestProtocol):
Expand Down Expand Up @@ -336,7 +345,7 @@ def get_pr_by_url(self, url: str) -> PullRequestProtocol | None:
logger.error(f"Invalid PR URL: {url}")
return None

slug = "/".join(url_parts[-4:-2])
slug = "/".join(url_parts[-5:-3])

return self.find_pr_by_id(slug, int(pr_id))

Expand Down
29 changes: 27 additions & 2 deletions patchwork/steps/CallOpenAI/CallOpenAI.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@
from patchwork.logger import logger
from patchwork.step import Step

_TOKEN_URL = "https://app.patched.codes/signin"
_DEFAULT_PATCH_URL = "https://patchwork.patched.codes/v1"


class CallOpenAI(Step):
required_keys = {"openai_api_key", "prompt_file"}
required_keys = {"prompt_file"}

def __init__(self, inputs: dict):
logger.info(f"Run started {self.__class__.__name__}")
Expand All @@ -27,7 +30,29 @@ def __init__(self, inputs: dict):
self.model_args = {key[len("model_") :]: value for key, value in inputs.items() if key.startswith("model_")}
self.client_args = {key[len("client_") :]: value for key, value in inputs.items() if key.startswith("client_")}

self.openai_api_key = inputs["openai_api_key"]
openai_key = inputs.get("openai_api_key") or os.environ.get("OPENAI_API_KEY")
if openai_key is not None:
self.openai_api_key = openai_key

patched_key = inputs.get("patched_api_key")
if patched_key is not None:
self.openai_api_key = patched_key
self.client_args["base_url"] = _DEFAULT_PATCH_URL

if self.openai_api_key is None:
raise ValueError(
f"Model API key not found.\n"
f'Please login at: "{_TOKEN_URL}",\n'
"Please go to the Integration's tab and generate an API key.\n"
"Please copy the access token that is generated, "
"and add `--patched_api_key=<token>` to the command line.\n"
"\n"
"If you are using a OpenAI API Key, please set `--openai_api_key=<token>`.\n"
)

if not self.openai_api_key:
raise ValueError('Missing required data: "openai_api_key"')

self.prompt_file = Path(inputs["prompt_file"])
if not self.prompt_file.is_file():
raise ValueError(f'Unable to find Prompt file: "{self.prompt_file}"')
Expand Down
142 changes: 71 additions & 71 deletions patchwork/steps/ExtractCode/ExtractCode.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,76 @@ def resolve_artifact_location(
return None


def transform_sarif_results(
sarif_data: dict, base_path: Path, context_length: int, vulnerability_limit: int
) -> dict[tuple[str, int, int, int], list[str]]:
# Process each result in SARIF data
grouped_messages = defaultdict(list)
vulnerability_count = 0
for run_idx, run in enumerate(sarif_data.get("runs", [])):
artifact_locations = [
parse_sarif_location(base_path, artifact["location"]["uri"]) for artifact in run.get("artifacts", [])
]

for result_idx, result in enumerate(run.get("results", [])):
for location_idx, location in enumerate(result.get("locations", [])):
physical_location = location.get("physicalLocation", {})

artifact_location = physical_location.get("artifactLocation", {})
uri = resolve_artifact_location(base_path, artifact_location, artifact_locations)
if uri is None:
logger.warn(
f'Unable to find file for ".runs[{run_idx}].results[{result_idx}].locations[{location_idx}]"'
)
continue

region = physical_location.get("region", {})
start_line = region.get("startLine", 1)
end_line = region.get("endLine", start_line)
start_line = start_line - 1

# Generate file path assuming code is in the current working directory
file_path = str(uri.relative_to(base_path))

# Extract lines from the code file
logger.info(f"Extracting context for {file_path} at {start_line}:{end_line}")
try:
with open_with_chardet(file_path, "r") as file:
src = file.read()

source_lines = src.splitlines(keepends=True)
context_start, context_end = get_source_code_context(
file_path, source_lines, start_line, end_line, context_length
)

source_code_context = None
if context_start is not None and context_end is not None:
source_code_context = "".join(source_lines[context_start:context_end])

except FileNotFoundError:
context_start = None
context_end = None
source_code_context = None
logger.info(f"File not found in the current working directory: {file_path}")

if source_code_context is None:
logger.info(f"No context found for {file_path} at {start_line}:{end_line}")
continue

start = context_start if context_start is not None else start_line
end = context_end if context_end is not None else end_line

grouped_messages[(uri, start, end, source_code_context)].append(
result.get("message", {}).get("text", "")
)

vulnerability_count = vulnerability_count + 1
if 0 < vulnerability_limit <= vulnerability_count:
return grouped_messages

return grouped_messages


class ExtractCode(Step):
required_keys = {"sarif_file_path"}

Expand All @@ -112,7 +182,6 @@ def __init__(self, inputs: dict):
self.vulnerability_limit = inputs.get("vulnerability_limit", 10)

# Prepare for data extraction
self.extracted_data = []
self.extracted_code_contexts = []

def run(self) -> dict:
Expand All @@ -122,77 +191,8 @@ def run(self) -> dict:

vulnerability_count = 0
base_path = Path.cwd()
# Process each result in SARIF data
grouped_messages = defaultdict(list)
for run_idx, run in enumerate(sarif_data.get("runs", [])):
artifact_locations = [
parse_sarif_location(base_path, artifact["location"]["uri"]) for artifact in run.get("artifacts", [])
]

for result_idx, result in enumerate(run.get("results", [])):
for location_idx, location in enumerate(result.get("locations", [])):
physical_location = location.get("physicalLocation", {})

artifact_location = physical_location.get("artifactLocation", {})
uri = resolve_artifact_location(base_path, artifact_location, artifact_locations)
if uri is None:
logger.warn(
f'Unable to find file for ".runs[{run_idx}].results[{result_idx}].locations[{location_idx}]"'
)
continue

region = physical_location.get("region", {})
start_line = region.get("startLine", 1)
end_line = region.get("endLine", start_line)
start_line = start_line - 1

# Generate file path assuming code is in the current working directory
file_path = str(uri.relative_to(base_path))

# Extract lines from the code file
logger.info(f"Extracting context for {file_path} at {start_line}:{end_line}")
try:
with open_with_chardet(file_path, "r") as file:
src = file.read()

source_lines = src.splitlines(keepends=True)
context_start, context_end = get_source_code_context(
file_path, source_lines, start_line, end_line, self.context_length
)

source_code_context = None
if context_start is not None and context_end is not None:
source_code_context = "".join(source_lines[context_start:context_end])

except FileNotFoundError:
context_start = None
context_end = None
source_code_context = None
logger.info(f"File not found in the current working directory: {file_path}")

if source_code_context is None:
logger.info(f"No context found for {file_path} at {start_line}:{end_line}")
continue

start = context_start if context_start is not None else start_line
end = context_end if context_end is not None else end_line
self.extracted_data.append(
{
"affectedCode": source_code_context,
"startLine": start,
"endLine": end,
"uri": file_path,
"messageText": result.get("message", {}).get("text", ""),
}
)

grouped_messages[(uri, start, end, source_code_context)].append(
result.get("message", {}).get("text", "")
)

vulnerability_count = vulnerability_count + 1
if 0 < self.vulnerability_limit <= vulnerability_count:
break
grouped_messages = transform_sarif_results(sarif_data, base_path, self.context_length, self.vulnerability_limit)

self.extracted_code_contexts = [
{
Expand Down
4 changes: 3 additions & 1 deletion patchwork/steps/ModifyCode/ModifyCode.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ def run(self) -> dict:
code_snippets = load_json_file(self.code_snippets_path)

modified_code_files = []
sorted_list = sorted(zip(code_snippets, self.extracted_responses), key=lambda x: x[0]["startLine"], reverse=True)
sorted_list = sorted(
zip(code_snippets, self.extracted_responses), key=lambda x: x[0]["startLine"], reverse=True
)
for code_snippet, extracted_response in sorted_list:
uri = code_snippet["uri"]
start_line = code_snippet["startLine"]
Expand Down
22 changes: 22 additions & 0 deletions patchwork/steps/ReadPRDiffs/ReadPRDiffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,26 @@
from patchwork.logger import logger
from patchwork.step import Step

_IGNORED_EXTENSIONS = [
".png",
".jpg",
".jpeg",
".gif",
".svg",
".pdf",
".docx",
".xlsx",
".pptx",
".zip",
".tar",
".gz",
".lock",
]


def filter_by_extension(file, extensions):
return any(file.endswith(ext) for ext in extensions)


class ReadPRDiffs(Step):
required_keys = {"pr_url"}
Expand All @@ -30,6 +50,8 @@ def __init__(self, inputs: dict):
def run(self) -> dict:
prompt_values = []
for path, diffs in self.pr.file_diffs().items():
if filter_by_extension(path, _IGNORED_EXTENSIONS):
continue
prompt_values.append(dict(path=path, diff=diffs))

prompt_value_file = tempfile.mktemp(".json")
Expand Down
Loading
Loading