Skip to content

Commit

Permalink
Add Command and SQL steps (#1181)
Browse files Browse the repository at this point in the history
* Add new steps

* update init file

* change command to shell

* bugfixes and changes

* lock update

* update callsql

* make username optional since some dbs take username in connect args instead

* fix test and filter url kwargs by none

* auto parse dict inputs

* bump version

* lint and update deps
  • Loading branch information
CTY-git authored Jan 21, 2025
1 parent 5a7c523 commit a74d35c
Show file tree
Hide file tree
Showing 18 changed files with 1,041 additions and 614 deletions.
21 changes: 21 additions & 0 deletions patchwork/common/utils/input_parsing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import json
from collections.abc import Iterable, Mapping

from typing_extensions import AnyStr, Union
Expand Down Expand Up @@ -69,3 +70,23 @@ def parse_to_list(
continue
rv.append(stripped_value)
return rv


def parse_to_dict(possible_dict, limit=-1):
if possible_dict is None and limit == 0:
return None

if isinstance(possible_dict, dict):
new_dict = dict()
for k, v in possible_dict.items():
new_dict[k] = parse_to_dict(v, limit - 1)
return new_dict
elif isinstance(possible_dict, str):
try:
new_dict = json.loads(possible_dict, strict=False)
except json.JSONDecodeError:
return possible_dict

return parse_to_dict(new_dict, limit - 1)
else:
return possible_dict
18 changes: 18 additions & 0 deletions patchwork/common/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@

import atexit
import dataclasses
import random
import signal
import string
import tempfile
from collections.abc import Mapping
from pathlib import Path

import chevron
import tiktoken
from chardet.universaldetector import UniversalDetector
from git import Head, Repo
Expand All @@ -19,6 +23,20 @@
_NEWLINES = {"\n", "\r\n", "\r"}


def mustache_render(template: str, data: Mapping) -> str:
if len(data.keys()) < 1:
return template

chevron.render.__globals__["_html_escape"] = lambda x: x
return chevron.render(
template=template,
data=data,
partials_path=None,
partials_ext="".join(random.choices(string.ascii_uppercase + string.digits, k=32)),
partials_dict=dict(),
)


def detect_newline(path: str | Path) -> str | None:
with open(path, "r", newline="") as f:
lines = f.read().splitlines(keepends=True)
Expand Down
46 changes: 30 additions & 16 deletions patchwork/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,16 @@

from enum import Enum

from typing_extensions import Any, Dict, List, Optional, Union, is_typeddict
from typing_extensions import (
Any,
Collection,
Dict,
List,
Optional,
Type,
Union,
is_typeddict,
)

from patchwork.logger import logger

Expand Down Expand Up @@ -45,10 +54,9 @@ def __init__(self, inputs: DataPoint):
"""

# check if the inputs have the required keys
if self.__input_class is not None:
missing_keys = self.__input_class.__required_keys__.difference(inputs.keys())
if len(missing_keys) > 0:
raise ValueError(f"Missing required data: {list(missing_keys)}")
missing_keys = self.find_missing_inputs(inputs)
if len(missing_keys) > 0:
raise ValueError(f"Missing required data: {list(missing_keys)}")

# store the inputs
self.inputs = inputs
Expand All @@ -64,19 +72,25 @@ def __init__(self, inputs: DataPoint):
self.original_run = self.run
self.run = self.__managed_run

def __init_subclass__(cls, **kwargs):
input_class = kwargs.get("input_class", None) or getattr(cls, "input_class", None)
output_class = kwargs.get("output_class", None) or getattr(cls, "output_class", None)
def __init_subclass__(cls, input_class: Optional[Type] = None, output_class: Optional[Type] = None, **kwargs):
if cls.__name__ == "PreparePR":
print(1)
input_class = input_class or getattr(cls, "input_class", None)
if input_class is not None and not is_typeddict(input_class):
input_class = None

if input_class is not None and is_typeddict(input_class):
cls.__input_class = input_class
else:
cls.__input_class = None
output_class = output_class or getattr(cls, "output_class", None)
if output_class is not None and not is_typeddict(output_class):
output_class = None

if output_class is not None and is_typeddict(output_class):
cls.__output_class = output_class
else:
cls.__output_class = None
cls._input_class = input_class
cls._output_class = output_class

@classmethod
def find_missing_inputs(cls, inputs: DataPoint) -> Collection:
if getattr(cls, "_input_class", None) is None:
return []
return cls._input_class.__required_keys__.difference(inputs.keys())

def __managed_run(self, *args, **kwargs) -> Any:
self.debug(self.inputs)
Expand Down
57 changes: 57 additions & 0 deletions patchwork/steps/CallSQL/CallSQL.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from __future__ import annotations

from sqlalchemy import URL, create_engine, exc, text

from patchwork.common.utils.input_parsing import parse_to_dict
from patchwork.common.utils.utils import mustache_render
from patchwork.logger import logger
from patchwork.step import Step, StepStatus
from patchwork.steps.CallSQL.typed import CallSQLInputs, CallSQLOutputs


class CallSQL(Step, input_class=CallSQLInputs, output_class=CallSQLOutputs):
def __init__(self, inputs: dict):
super().__init__(inputs)
query_template_data = inputs.get("db_query_template_values", {})
self.query = mustache_render(inputs["db_query"], query_template_data)
self.__build_engine(inputs)

def __build_engine(self, inputs: dict):
dialect = inputs["db_dialect"]
driver = inputs.get("db_driver")
dialect_plus_driver = f"{dialect}+{driver}" if driver is not None else dialect
kwargs = dict(
username=inputs.get("db_username"),
host=inputs.get("db_host", "localhost"),
port=inputs.get("db_port", 5432),
password=inputs.get("db_password"),
database=inputs.get("db_database"),
query=parse_to_dict(inputs.get("db_params")),
)
connection_url = URL.create(
dialect_plus_driver,
**{k: v for k, v in kwargs.items() if v is not None},
)

connect_args = None
if inputs.get("db_driver_args") is not None:
connect_args = parse_to_dict(inputs.get("db_driver_args"))

self.engine = create_engine(connection_url, connect_args=connect_args)
with self.engine.connect() as conn:
conn.execute(text("SELECT 1"))
return self.engine

def run(self) -> dict:
try:
rv = []
with self.engine.begin() as conn:
cursor = conn.execute(text(self.query))
for row in cursor:
result = row._asdict()
rv.append(result)
logger.info(f"Retrieved {len(rv)} rows!")
return dict(results=rv)
except exc.InvalidRequestError as e:
self.set_status(StepStatus.FAILED, f"`{self.query}` failed with message:\n{e}")
return dict(results=[])
Empty file.
24 changes: 24 additions & 0 deletions patchwork/steps/CallSQL/typed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from __future__ import annotations

from typing_extensions import Any, TypedDict


class __RequiredCallSQLInputs(TypedDict):
db_dialect: str
db_query: str


class CallSQLInputs(__RequiredCallSQLInputs, total=False):
db_driver: str
db_username: str
db_password: str
db_host: str
db_port: int
db_name: str
db_params: dict[str, Any]
db_driver_args: dict[str, Any]
db_query_template_values: dict[str, Any]


class CallSQLOutputs(TypedDict):
results: list[dict[str, Any]]
58 changes: 58 additions & 0 deletions patchwork/steps/CallShell/CallShell.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from __future__ import annotations

import shlex
import subprocess
from pathlib import Path

from patchwork.common.utils.utils import mustache_render
from patchwork.logger import logger
from patchwork.step import Step, StepStatus
from patchwork.steps.CallShell.typed import CallShellInputs, CallShellOutputs


class CallShell(Step, input_class=CallShellInputs, output_class=CallShellOutputs):
def __init__(self, inputs: dict):
super().__init__(inputs)
script_template_values = inputs.get("script_template_values", {})
self.script = mustache_render(inputs["script"], script_template_values)
self.working_dir = inputs.get("working_dir", Path.cwd())
self.env = self.__parse_env_text(inputs.get("env", ""))

@staticmethod
def __parse_env_text(env_text: str) -> dict[str, str]:
env_spliter = shlex.shlex(env_text, posix=True)
env_spliter.whitespace_split = True
env_spliter.whitespace += ";"

env: dict[str, str] = dict()
for env_assign in env_spliter:
env_assign_spliter = shlex.shlex(env_assign, posix=True)
env_assign_spliter.whitespace_split = True
env_assign_spliter.whitespace += "="
env_parts = list(env_assign_spliter)
if len(env_parts) < 1:
continue

env_assign_target = env_parts[0]
if len(env_parts) < 2:
logger.error(f"{env_assign_target} is not assigned anything, skipping...")
continue
if len(env_parts) > 2:
logger.error(f"{env_assign_target} has more than 1 assignment, skipping...")
continue
env[env_assign_target] = env_parts[1]

return env

def run(self) -> dict:
p = subprocess.run(self.script, shell=True, capture_output=True, text=True, cwd=self.working_dir, env=self.env)
try:
p.check_returncode()
except subprocess.CalledProcessError as e:
self.set_status(
StepStatus.FAILED,
f"Script failed.",
)
logger.info(f"stdout: \n{p.stdout}")
logger.info(f"stderr:\n{p.stderr}")
return dict(stdout_output=p.stdout, stderr_output=p.stderr)
Empty file.
19 changes: 19 additions & 0 deletions patchwork/steps/CallShell/typed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from __future__ import annotations

from typing_extensions import Annotated, Any, TypedDict

from patchwork.common.utils.step_typing import StepTypeConfig


class __RequiredCallShellInputs(TypedDict):
script: str


class CallShellInputs(__RequiredCallShellInputs, total=False):
working_dir: Annotated[str, StepTypeConfig(is_path=True)]
env: str
script_template_values: dict[str, Any]


class CallShellOutputs(TypedDict):
stdout_output: str
16 changes: 7 additions & 9 deletions patchwork/steps/FixIssue/FixIssue.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import difflib
import re
from pathlib import Path
from typing import Any, Optional

from git import Repo, InvalidGitRepositoryError
from patchwork.logger import logger
from git import InvalidGitRepositoryError, Repo
from openai.types.chat import ChatCompletionMessageParam

from patchwork.common.client.llm.aio import AioLlmClient
Expand All @@ -15,6 +13,7 @@
AnalyzeImplementStrategy,
)
from patchwork.common.tools import CodeEditTool, Tool
from patchwork.logger import logger
from patchwork.step import Step
from patchwork.steps.FixIssue.typed import FixIssueInputs, FixIssueOutputs

Expand Down Expand Up @@ -100,7 +99,7 @@ def is_stop(self, messages: list[ChatCompletionMessageParam]) -> bool:
class FixIssue(Step, input_class=FixIssueInputs, output_class=FixIssueOutputs):
def __init__(self, inputs):
"""Initialize the FixIssue step.
Args:
inputs: Dictionary containing input parameters including:
- base_path: Optional path to the repository root
Expand Down Expand Up @@ -145,12 +144,12 @@ def __init__(self, inputs):

def run(self):
"""Execute the FixIssue step.
This method:
1. Executes the multi-turn LLM conversation to analyze and fix the issue
2. Tracks file modifications made by the CodeEditTool
3. Generates in-memory diffs for all modified files
Returns:
dict: Dictionary containing list of modified files with their diffs
"""
Expand All @@ -162,8 +161,7 @@ def run(self):
if not isinstance(tool, CodeEditTool):
continue
tool_modified_files = [
dict(path=str(file_path.relative_to(cwd)), diff="")
for file_path in tool.tool_records["modified_files"]
dict(path=str(file_path.relative_to(cwd)), diff="") for file_path in tool.tool_records["modified_files"]
]
modified_files.extend(tool_modified_files)

Expand All @@ -174,7 +172,7 @@ def run(self):
file = modified_file["path"]
try:
# Try to get the diff using git
diff = self.repo.git.diff('HEAD', file)
diff = self.repo.git.diff("HEAD", file)
modified_file["diff"] = diff or ""
except Exception as e:
# Git-specific errors (untracked files, etc) - keep empty diff
Expand Down
8 changes: 5 additions & 3 deletions patchwork/steps/FixIssue/typed.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing_extensions import Annotated, Dict, List, TypedDict
from typing_extensions import Annotated, List, TypedDict

from patchwork.common.constants import TOKEN_URL
from patchwork.common.utils.step_typing import StepTypeConfig
Expand Down Expand Up @@ -37,19 +37,21 @@ class FixIssueInputs(__FixIssueRequiredInputs, total=False):

class ModifiedFile(TypedDict):
"""Represents a file that has been modified by the FixIssue step.
Attributes:
path: The relative path to the modified file from the repository root
diff: A unified diff string showing the changes made to the file.
Generated using Python's difflib to compare the original and
modified file contents in memory.
Note:
The diff is generated by comparing file contents before and after
modifications, without relying on version control systems.
"""

path: str
diff: str


class FixIssueOutputs(TypedDict):
modified_files: List[ModifiedFile]
Loading

0 comments on commit a74d35c

Please sign in to comment.