Skip to content

Commit a74d35c

Browse files
authored
Add Command and SQL steps (#1181)
* Add new steps * update init file * change command to shell * bugfixes and changes * lock update * update callsql * make username optional since some dbs take username in connect args instead * fix test and filter url kwargs by none * auto parse dict inputs * bump version * lint and update deps
1 parent 5a7c523 commit a74d35c

File tree

18 files changed

+1041
-614
lines changed

18 files changed

+1041
-614
lines changed

patchwork/common/utils/input_parsing.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import json
34
from collections.abc import Iterable, Mapping
45

56
from typing_extensions import AnyStr, Union
@@ -69,3 +70,23 @@ def parse_to_list(
6970
continue
7071
rv.append(stripped_value)
7172
return rv
73+
74+
75+
def parse_to_dict(possible_dict, limit=-1):
    """Recursively decode JSON-encoded strings found in ``possible_dict``.

    Dictionaries are copied with every value processed recursively. Strings are
    passed to ``json.loads`` and, when they decode, the decoded object is
    processed again. Anything else (and strings that are not valid JSON) is
    returned unchanged.

    Args:
        possible_dict: A dict, a string that may contain JSON, or any other value.
        limit: Maximum number of recursion levels to process. Each dict
            traversal and each successful JSON decode consumes one level; once
            the budget reaches ``0`` the value is returned as-is. The default
            of ``-1`` never reaches ``0`` and therefore means "no limit".

    Returns:
        The decoded structure, or the input value when nothing could be decoded.
    """
    # Nothing to decode, or the depth budget is exhausted: hand the value back
    # untouched. (The previous `and` only stopped for None, so `limit` was
    # silently ignored for every real value.)
    if possible_dict is None or limit == 0:
        return possible_dict

    if isinstance(possible_dict, dict):
        return {key: parse_to_dict(value, limit - 1) for key, value in possible_dict.items()}

    if isinstance(possible_dict, str):
        try:
            # strict=False lets control characters (e.g. raw newlines) appear inside strings.
            decoded = json.loads(possible_dict, strict=False)
        except json.JSONDecodeError:
            return possible_dict
        return parse_to_dict(decoded, limit - 1)

    return possible_dict

patchwork/common/utils/utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,14 @@
22

33
import atexit
44
import dataclasses
5+
import random
56
import signal
7+
import string
68
import tempfile
9+
from collections.abc import Mapping
710
from pathlib import Path
811

12+
import chevron
913
import tiktoken
1014
from chardet.universaldetector import UniversalDetector
1115
from git import Head, Repo
@@ -19,6 +23,20 @@
1923
_NEWLINES = {"\n", "\r\n", "\r"}
2024

2125

26+
def mustache_render(template: str, data: Mapping) -> str:
    """Render a mustache ``template`` with ``data`` using chevron.

    Args:
        template: The mustache template text.
        data: Values made available to the template.

    Returns:
        ``template`` unchanged when ``data`` has no keys, otherwise the
        rendered string.
    """
    if len(data.keys()) == 0:
        return template

    def _identity(value):
        return value

    # Replace chevron's module-level `_html_escape` helper with a no-op,
    # presumably so substituted values are emitted verbatim rather than
    # HTML-escaped. NOTE(review): this mutates chevron's globals for the
    # whole process, not just for this call.
    chevron.render.__globals__["_html_escape"] = _identity

    # A random 32-character partials extension with no partials path or dict;
    # presumably this keeps chevron from resolving partials from disk - confirm.
    random_partials_ext = "".join(random.choices(string.ascii_uppercase + string.digits, k=32))
    return chevron.render(
        template=template,
        data=data,
        partials_path=None,
        partials_ext=random_partials_ext,
        partials_dict=dict(),
    )
38+
39+
2240
def detect_newline(path: str | Path) -> str | None:
2341
with open(path, "r", newline="") as f:
2442
lines = f.read().splitlines(keepends=True)

patchwork/step.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,16 @@
1111

1212
from enum import Enum
1313

14-
from typing_extensions import Any, Dict, List, Optional, Union, is_typeddict
14+
from typing_extensions import (
15+
Any,
16+
Collection,
17+
Dict,
18+
List,
19+
Optional,
20+
Type,
21+
Union,
22+
is_typeddict,
23+
)
1524

1625
from patchwork.logger import logger
1726

@@ -45,10 +54,9 @@ def __init__(self, inputs: DataPoint):
4554
"""
4655

4756
# check if the inputs have the required keys
48-
if self.__input_class is not None:
49-
missing_keys = self.__input_class.__required_keys__.difference(inputs.keys())
50-
if len(missing_keys) > 0:
51-
raise ValueError(f"Missing required data: {list(missing_keys)}")
57+
missing_keys = self.find_missing_inputs(inputs)
58+
if len(missing_keys) > 0:
59+
raise ValueError(f"Missing required data: {list(missing_keys)}")
5260

5361
# store the inputs
5462
self.inputs = inputs
@@ -64,19 +72,25 @@ def __init__(self, inputs: DataPoint):
6472
self.original_run = self.run
6573
self.run = self.__managed_run
6674

67-
def __init_subclass__(cls, **kwargs):
68-
input_class = kwargs.get("input_class", None) or getattr(cls, "input_class", None)
69-
output_class = kwargs.get("output_class", None) or getattr(cls, "output_class", None)
75+
def __init_subclass__(cls, input_class: Optional[Type] = None, output_class: Optional[Type] = None, **kwargs):
76+
if cls.__name__ == "PreparePR":
77+
print(1)
78+
input_class = input_class or getattr(cls, "input_class", None)
79+
if input_class is not None and not is_typeddict(input_class):
80+
input_class = None
7081

71-
if input_class is not None and is_typeddict(input_class):
72-
cls.__input_class = input_class
73-
else:
74-
cls.__input_class = None
82+
output_class = output_class or getattr(cls, "output_class", None)
83+
if output_class is not None and not is_typeddict(output_class):
84+
output_class = None
7585

76-
if output_class is not None and is_typeddict(output_class):
77-
cls.__output_class = output_class
78-
else:
79-
cls.__output_class = None
86+
cls._input_class = input_class
87+
cls._output_class = output_class
88+
89+
@classmethod
def find_missing_inputs(cls, inputs: DataPoint) -> Collection:
    """Return the required input keys that ``inputs`` does not provide.

    Args:
        inputs: The data supplied to the step; only its keys are inspected.

    Returns:
        An empty list when the class has no ``_input_class`` recorded,
        otherwise the set difference between the input class's
        ``__required_keys__`` and the keys of ``inputs``.
    """
    declared_inputs = getattr(cls, "_input_class", None)
    if declared_inputs is not None:
        return declared_inputs.__required_keys__.difference(inputs.keys())
    return []
8094

8195
def __managed_run(self, *args, **kwargs) -> Any:
8296
self.debug(self.inputs)

patchwork/steps/CallSQL/CallSQL.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from __future__ import annotations
2+
3+
from sqlalchemy import URL, create_engine, exc, text
4+
5+
from patchwork.common.utils.input_parsing import parse_to_dict
6+
from patchwork.common.utils.utils import mustache_render
7+
from patchwork.logger import logger
8+
from patchwork.step import Step, StepStatus
9+
from patchwork.steps.CallSQL.typed import CallSQLInputs, CallSQLOutputs
10+
11+
12+
class CallSQL(Step, input_class=CallSQLInputs, output_class=CallSQLOutputs):
    """Step that renders a SQL query template and executes it through SQLAlchemy."""

    def __init__(self, inputs: dict):
        """Render the query and build the database engine.

        Args:
            inputs: Step inputs. ``db_query`` is rendered with
                ``db_query_template_values`` via ``mustache_render``; the
                ``db_*`` connection keys are consumed by ``__build_engine``.
        """
        super().__init__(inputs)
        query_template_data = inputs.get("db_query_template_values", {})
        self.query = mustache_render(inputs["db_query"], query_template_data)
        self.__build_engine(inputs)

    def __build_engine(self, inputs: dict):
        """Create ``self.engine`` from the ``db_*`` inputs and probe it with ``SELECT 1``.

        Args:
            inputs: Step inputs holding the dialect, optional driver and
                connection parameters.

        Returns:
            The created SQLAlchemy engine (also stored on ``self.engine``).
        """
        dialect = inputs["db_dialect"]
        driver = inputs.get("db_driver")
        dialect_plus_driver = f"{dialect}+{driver}" if driver is not None else dialect

        # `db_database` is the key this step has always read, while the
        # CallSQLInputs TypedDict declares `db_name`; accept either so inputs
        # written against the declared schema are not silently ignored.
        database = inputs.get("db_database")
        if database is None:
            database = inputs.get("db_name")

        kwargs = dict(
            username=inputs.get("db_username"),
            host=inputs.get("db_host", "localhost"),
            port=inputs.get("db_port", 5432),
            password=inputs.get("db_password"),
            database=database,
            query=parse_to_dict(inputs.get("db_params")),
        )
        # Only forward the components that were actually provided.
        connection_url = URL.create(
            dialect_plus_driver,
            **{k: v for k, v in kwargs.items() if v is not None},
        )

        # create_engine expects a dictionary for connect_args; passing None
        # (the previous default when db_driver_args was absent) is not valid.
        connect_args = {}
        if inputs.get("db_driver_args") is not None:
            connect_args = parse_to_dict(inputs.get("db_driver_args"))

        self.engine = create_engine(connection_url, connect_args=connect_args)
        # Open a connection right away, presumably as a connectivity check so
        # misconfiguration surfaces at construction time.
        with self.engine.connect() as conn:
            conn.execute(text("SELECT 1"))
        return self.engine

    def run(self) -> dict:
        """Execute the rendered query and collect the resulting rows.

        Returns:
            ``dict(results=[...])`` where each row is converted with
            ``Row._asdict()``. When SQLAlchemy raises ``InvalidRequestError``
            the step status is set to FAILED and ``results`` is empty.
        """
        try:
            with self.engine.begin() as conn:
                cursor = conn.execute(text(self.query))
                rv = [row._asdict() for row in cursor]
            logger.info(f"Retrieved {len(rv)} rows!")
            return dict(results=rv)
        except exc.InvalidRequestError as e:
            self.set_status(StepStatus.FAILED, f"`{self.query}` failed with message:\n{e}")
            return dict(results=[])

patchwork/steps/CallSQL/__init__.py

Whitespace-only changes.

patchwork/steps/CallSQL/typed.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from __future__ import annotations
2+
3+
from typing_extensions import Any, TypedDict
4+
5+
6+
class __RequiredCallSQLInputs(TypedDict):
    """Keys that must always be present in the CallSQL step inputs."""

    # SQLAlchemy dialect name; combined with the optional driver as "dialect+driver".
    db_dialect: str
    # Query text; rendered as a mustache template before execution.
    db_query: str
9+
10+
11+
class CallSQLInputs(__RequiredCallSQLInputs, total=False):
    """Optional CallSQL inputs, layered on top of the required ones."""

    # Driver appended to the dialect as "dialect+driver" when given.
    db_driver: str
    db_username: str
    db_password: str
    # CallSQL falls back to "localhost" when omitted.
    db_host: str
    # CallSQL falls back to 5432 when omitted.
    db_port: int
    db_name: str
    # Key that CallSQL reads the database name from; declared here so the
    # schema matches what the step actually consumes.
    db_database: str
    # Passed as the `query` component of the connection URL.
    db_params: dict[str, Any]
    # Passed to the engine as `connect_args`.
    db_driver_args: dict[str, Any]
    # Values substituted into `db_query` when it is rendered.
    db_query_template_values: dict[str, Any]
21+
22+
23+
class CallSQLOutputs(TypedDict):
    """Outputs produced by the CallSQL step."""

    # One dictionary per returned row; empty when the query failed.
    results: list[dict[str, Any]]
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from __future__ import annotations
2+
3+
import shlex
4+
import subprocess
5+
from pathlib import Path
6+
7+
from patchwork.common.utils.utils import mustache_render
8+
from patchwork.logger import logger
9+
from patchwork.step import Step, StepStatus
10+
from patchwork.steps.CallShell.typed import CallShellInputs, CallShellOutputs
11+
12+
13+
class CallShell(Step, input_class=CallShellInputs, output_class=CallShellOutputs):
    """Step that renders a shell script template and runs it in a subprocess."""

    def __init__(self, inputs: dict):
        """Render the script and resolve the working directory and environment.

        Args:
            inputs: Step inputs. ``script`` is rendered with
                ``script_template_values`` via ``mustache_render``;
                ``working_dir`` defaults to the current working directory and
                ``env`` is parsed by ``__parse_env_text``.
        """
        super().__init__(inputs)
        script_template_values = inputs.get("script_template_values", {})
        self.script = mustache_render(inputs["script"], script_template_values)
        self.working_dir = inputs.get("working_dir", Path.cwd())
        self.env = self.__parse_env_text(inputs.get("env", ""))

    @staticmethod
    def __parse_env_text(env_text: str) -> dict[str, str]:
        """Parse ``NAME=value`` assignments separated by whitespace or ``;``.

        Shell-style quoting is honoured, so ``A="b c"`` yields ``{"A": "b c"}``.
        Only the first ``=`` separates name from value, so values may contain
        ``=`` themselves.

        Args:
            env_text: The raw assignment text.

        Returns:
            Mapping of variable names to values. Entries without a name or
            without a value are logged and skipped.
        """
        env_spliter = shlex.shlex(env_text, posix=True)
        env_spliter.whitespace_split = True
        env_spliter.whitespace += ";"

        env: dict[str, str] = dict()
        for env_assign in env_spliter:
            # Quotes were already consumed by the tokenizer above; re-splitting
            # the token on whitespace (as was done before) dropped every quoted
            # value that contained a space, so split on the first "=" only.
            env_assign_target, separator, env_value = env_assign.partition("=")
            if not env_assign_target:
                logger.error(f"{env_assign} has no variable name, skipping...")
                continue
            if not separator or not env_value:
                logger.error(f"{env_assign_target} is not assigned anything, skipping...")
                continue
            env[env_assign_target] = env_value

        return env

    def run(self) -> dict:
        """Run the rendered script through the shell and capture its output.

        Returns:
            ``dict(stdout_output=..., stderr_output=...)`` with the captured
            text streams. A non-zero exit code sets the step status to FAILED;
            the captured output is still returned.
        """
        # NOTE(review): the script is executed with shell=True and template
        # values are interpolated without escaping - only use trusted inputs.
        # An empty mapping would start the script with no environment at all,
        # so inherit the parent environment when no variables were configured.
        p = subprocess.run(
            self.script,
            shell=True,
            capture_output=True,
            text=True,
            cwd=self.working_dir,
            env=self.env or None,
        )
        if p.returncode != 0:
            self.set_status(StepStatus.FAILED, "Script failed.")
        logger.info(f"stdout: \n{p.stdout}")
        logger.info(f"stderr:\n{p.stderr}")
        return dict(stdout_output=p.stdout, stderr_output=p.stderr)

patchwork/steps/CallShell/__init__.py

Whitespace-only changes.

patchwork/steps/CallShell/typed.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from __future__ import annotations
2+
3+
from typing_extensions import Annotated, Any, TypedDict
4+
5+
from patchwork.common.utils.step_typing import StepTypeConfig
6+
7+
8+
class __RequiredCallShellInputs(TypedDict):
    """Keys that must always be present in the CallShell step inputs."""

    # Script text; rendered as a mustache template before it is executed.
    script: str
10+
11+
12+
class CallShellInputs(__RequiredCallShellInputs, total=False):
    """Optional CallShell inputs, layered on top of the required ones."""

    # Directory the script runs in; CallShell uses the current working directory when omitted.
    working_dir: Annotated[str, StepTypeConfig(is_path=True)]
    # NAME=value assignments separated by whitespace or ";".
    env: str
    # Values substituted into `script` when it is rendered.
    script_template_values: dict[str, Any]
16+
17+
18+
class CallShellOutputs(TypedDict):
    """Outputs produced by the CallShell step."""

    # Captured standard output of the script.
    stdout_output: str
    # Captured standard error of the script; CallShell.run returns it
    # alongside stdout_output, so it is declared here as well.
    stderr_output: str

patchwork/steps/FixIssue/FixIssue.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
1-
import difflib
21
import re
32
from pathlib import Path
43
from typing import Any, Optional
54

6-
from git import Repo, InvalidGitRepositoryError
7-
from patchwork.logger import logger
5+
from git import InvalidGitRepositoryError, Repo
86
from openai.types.chat import ChatCompletionMessageParam
97

108
from patchwork.common.client.llm.aio import AioLlmClient
@@ -15,6 +13,7 @@
1513
AnalyzeImplementStrategy,
1614
)
1715
from patchwork.common.tools import CodeEditTool, Tool
16+
from patchwork.logger import logger
1817
from patchwork.step import Step
1918
from patchwork.steps.FixIssue.typed import FixIssueInputs, FixIssueOutputs
2019

@@ -100,7 +99,7 @@ def is_stop(self, messages: list[ChatCompletionMessageParam]) -> bool:
10099
class FixIssue(Step, input_class=FixIssueInputs, output_class=FixIssueOutputs):
101100
def __init__(self, inputs):
102101
"""Initialize the FixIssue step.
103-
102+
104103
Args:
105104
inputs: Dictionary containing input parameters including:
106105
- base_path: Optional path to the repository root
@@ -145,12 +144,12 @@ def __init__(self, inputs):
145144

146145
def run(self):
147146
"""Execute the FixIssue step.
148-
147+
149148
This method:
150149
1. Executes the multi-turn LLM conversation to analyze and fix the issue
151150
2. Tracks file modifications made by the CodeEditTool
152151
3. Generates in-memory diffs for all modified files
153-
152+
154153
Returns:
155154
dict: Dictionary containing list of modified files with their diffs
156155
"""
@@ -162,8 +161,7 @@ def run(self):
162161
if not isinstance(tool, CodeEditTool):
163162
continue
164163
tool_modified_files = [
165-
dict(path=str(file_path.relative_to(cwd)), diff="")
166-
for file_path in tool.tool_records["modified_files"]
164+
dict(path=str(file_path.relative_to(cwd)), diff="") for file_path in tool.tool_records["modified_files"]
167165
]
168166
modified_files.extend(tool_modified_files)
169167

@@ -174,7 +172,7 @@ def run(self):
174172
file = modified_file["path"]
175173
try:
176174
# Try to get the diff using git
177-
diff = self.repo.git.diff('HEAD', file)
175+
diff = self.repo.git.diff("HEAD", file)
178176
modified_file["diff"] = diff or ""
179177
except Exception as e:
180178
# Git-specific errors (untracked files, etc) - keep empty diff

patchwork/steps/FixIssue/typed.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing_extensions import Annotated, Dict, List, TypedDict
1+
from typing_extensions import Annotated, List, TypedDict
22

33
from patchwork.common.constants import TOKEN_URL
44
from patchwork.common.utils.step_typing import StepTypeConfig
@@ -37,19 +37,21 @@ class FixIssueInputs(__FixIssueRequiredInputs, total=False):
3737

3838
class ModifiedFile(TypedDict):
    """Represents a file that has been modified by the FixIssue step.

    Attributes:
        path: The path of the modified file, made relative to the ``cwd``
            value used in ``FixIssue.run``.
        diff: The output of ``git diff HEAD <path>`` for the file as obtained
            by ``FixIssue.run``, or an empty string when git returns nothing
            or raises.

    Note:
        The diff is produced through the repository's git interface, so it
        depends on version control; it is no longer computed in memory with
        difflib.
    """

    path: str
    diff: str
5354

55+
5456
class FixIssueOutputs(TypedDict):
    """Outputs produced by the FixIssue step."""

    # One entry per file modified during the run, each with its path and diff.
    modified_files: List[ModifiedFile]

0 commit comments

Comments
 (0)