Skip to content

Commit 93a5600

Browse files
authored
Merge pull request #585 from stacklok/issue-580
fix: use the latest user messages block instead of single message
2 parents 7593f72 + 2397720 commit 93a5600

File tree

9 files changed

+265
-80
lines changed

9 files changed

+265
-80
lines changed

Diff for: data/archived.jsonl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{"name":"@prefix/archived-npm-dummy","type":"npm","description":"Dummy archived to test with encoded package name on npm"}
22
{"name":"archived-npm-dummy","type":"npm","description":"Dummy archived to test with simple package name on npm"}
33
{"name":"@prefix/archived-pypi-dummy","type":"pypi","description":"Dummy archived to test with encoded package name on pypi"}
4-
{"name":"archived-pypi-dummy","type":"pypi","description":"Dummy archived to test with simple package name on pypi"}
4+
{"name":"archived_pypi_dummy","type":"pypi","description":"Dummy archived to test with simple package name on pypi"}
55
{"name":"@prefix/archived-maven-dummy","type":"maven","description":"Dummy archived to test with encoded package name on maven"}
66
{"name":"archived-maven-dummy","type":"maven","description":"Dummy archived to test with simple package name on maven"}
77
{"name":"github.com/archived-go-dummy","type":"npm","description":"Dummy archived to test with encoded package name on go"}

Diff for: data/deprecated.jsonl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{"name":"@prefix/deprecated-npm-dummy","type":"npm","description":"Dummy deprecated to test with encoded package name on npm"}
22
{"name":"deprecated-npm-dummy","type":"npm","description":"Dummy deprecated to test with simple package name on npm"}
33
{"name":"@prefix/deprecated-pypi-dummy","type":"pypi","description":"Dummy deprecated to test with encoded package name on pypi"}
4-
{"name":"deprecated-pypi-dummy","type":"pypi","description":"Dummy deprecated to test with simple package name on pypi"}
4+
{"name":"deprecated_pypi_dummy","type":"pypi","description":"Dummy deprecated to test with simple package name on pypi"}
55
{"name":"@prefix/deprecated-maven-dummy","type":"maven","description":"Dummy deprecated to test with encoded package name on maven"}
66
{"name":"deprecated-maven-dummy","type":"maven","description":"Dummy deprecated to test with simple package name on maven"}
77
{"name":"github.com/deprecated-go-dummy","type":"npm","description":"Dummy deprecated to test with encoded package name on go"}

Diff for: data/malicious.jsonl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{"name":"@prefix/malicious-npm-dummy","type":"npm","description":"Dummy malicious to test with encoded package name on npm"}
22
{"name":"malicious-npm-dummy","type":"npm","description":"Dummy malicious to test with simple package name on npm"}
33
{"name":"@prefix/malicious-pypi-dummy","type":"pypi","description":"Dummy malicious to test with encoded package name on pypi"}
4-
{"name":"malicious-pypi-dummy","type":"pypi","description":"Dummy malicious to test with simple package name on pypi"}
4+
{"name":"malicious_pypi_dummy","type":"pypi","description":"Dummy malicious to test with simple package name on pypi"}
55
{"name":"@prefix/malicious-maven-dummy","type":"maven","description":"Dummy malicious to test with encoded package name on maven"}
66
{"name":"malicious-maven-dummy","type":"maven","description":"Dummy malicious to test with simple package name on maven"}
77
{"name":"github.com/malicious-go-dummy","type":"go","description":"Dummy malicious to test with encoded package name on go"}

Diff for: poetry.lock

+54-56
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Diff for: pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ tree-sitter-python = ">=0.23.6"
2828
tree-sitter-rust = ">=0.23.2"
2929
sqlite-vec-sl-tmp = "^0.0.4"
3030
alembic = ">=1.14.0"
31+
pygments = "^2.19.1"
3132

3233
[tool.poetry.group.dev.dependencies]
3334
pytest = ">=7.4.0"

Diff for: src/codegate/pipeline/base.py

+48-2
Original file line numberDiff line numberDiff line change
@@ -231,8 +231,54 @@ def get_last_user_message(
231231
return None
232232
for i in reversed(range(len(request["messages"]))):
233233
if request["messages"][i]["role"] == "user":
234-
content = request["messages"][i]["content"]
235-
return content, i
234+
content = request["messages"][i]["content"] # type: ignore
235+
return str(content), i
236+
237+
return None
238+
239+
@staticmethod
240+
def get_last_user_message_block(
241+
request: ChatCompletionRequest,
242+
) -> Optional[str]:
243+
"""
244+
Get the last block of consecutive 'user' messages from the request.
245+
246+
Args:
247+
request (ChatCompletionRequest): The chat completion request to process
248+
249+
Returns:
250+
Optional[str]: A string containing all consecutive user messages in the
251+
last user message block, separated by newlines, or None if
252+
no user message block is found.
253+
"""
254+
if request.get("messages") is None:
255+
return None
256+
257+
user_messages = []
258+
messages = request["messages"]
259+
260+
# Iterate in reverse to find the last block of consecutive 'user' messages
261+
for i in reversed(range(len(messages))):
262+
if messages[i]["role"] == "user" or messages[i]["role"] == "assistant":
263+
content_str = None
264+
if "content" in messages[i]:
265+
content_str = messages[i]["content"] # type: ignore
266+
else:
267+
continue
268+
269+
if messages[i]["role"] == "user":
270+
user_messages.append(content_str)
271+
# specifically for Aider, when "ok." block is found, stop
272+
if content_str == "Ok." and messages[i]["role"] == "assistant":
273+
break
274+
else:
275+
# Stop when a message with a different role is encountered
276+
if user_messages:
277+
break
278+
279+
# Reverse the collected user messages to preserve the original order
280+
if user_messages:
281+
return "\n".join(reversed(user_messages))
236282

237283
return None
238284

Diff for: src/codegate/pipeline/codegate_context_retriever/codegate.py

+25-16
Original file line numberDiff line numberDiff line change
@@ -59,44 +59,53 @@ async def process(
5959
"""
6060
Use RAG DB to add context to the user request
6161
"""
62-
# Get the latest user messages
63-
user_messages = self.get_latest_user_messages(request)
64-
65-
# Nothing to do if the user_messages string is empty
66-
if len(user_messages) == 0:
62+
# Get the latest user message
63+
user_message = self.get_last_user_message_block(request)
64+
if not user_message:
6765
return PipelineResult(request=request)
6866

6967
# Create storage engine object
7068
storage_engine = StorageEngine()
7169

7270
# Extract any code snippets
73-
snippets = extract_snippets(user_messages)
71+
snippets = extract_snippets(user_message)
7472

7573
bad_snippet_packages = []
7674
if len(snippets) > 0:
75+
snippet_language = snippets[0].language
7776
# Collect all packages referenced in the snippets
7877
snippet_packages = []
7978
for snippet in snippets:
8079
snippet_packages.extend(
81-
PackageExtractor.extract_packages(snippet.code, snippet.language)
80+
PackageExtractor.extract_packages(snippet.code, snippet.language) # type: ignore
8281
)
83-
logger.info(f"Found {len(snippet_packages)} packages in code snippets.")
8482

83+
logger.info(
84+
f"Found {len(snippet_packages)} packages "
85+
f"for language {snippet_language} in code snippets."
86+
)
8587
# Find bad packages in the snippets
8688
bad_snippet_packages = await storage_engine.search(
87-
language=snippets[0].language, packages=snippet_packages
88-
)
89+
language=snippet_language, packages=snippet_packages
90+
) # type: ignore
8991
logger.info(f"Found {len(bad_snippet_packages)} bad packages in code snippets.")
9092

9193
# Remove code snippets from the user messages and search for bad packages
9294
# in the rest of the user query/messsages
93-
user_messages = re.sub(r"```.*?```", "", user_messages, flags=re.DOTALL)
94-
95-
# Vector search to find bad packages
96-
bad_packages = await storage_engine.search(query=user_messages, distance=0.5, limit=100)
95+
user_messages = re.sub(r"```.*?```", "", user_message, flags=re.DOTALL)
96+
user_messages = re.sub(r"⋮...*?⋮...\n\n", "", user_messages, flags=re.DOTALL)
97+
98+
# split messages into double newlines, to avoid passing so many content in the search
99+
split_messages = user_messages.split("\n\n")
100+
collected_bad_packages = []
101+
for item_message in split_messages:
102+
# Vector search to find bad packages
103+
bad_packages = await storage_engine.search(query=item_message, distance=0.5, limit=100)
104+
if bad_packages and len(bad_packages) > 0:
105+
collected_bad_packages.extend(bad_packages)
97106

98107
# All bad packages
99-
all_bad_packages = bad_snippet_packages + bad_packages
108+
all_bad_packages = bad_snippet_packages + collected_bad_packages
100109

101110
logger.info(f"Adding {len(all_bad_packages)} bad packages to the context.")
102111

@@ -119,7 +128,7 @@ async def process(
119128
# Add the context to the last user message
120129
# Format: "Context: {context_str} \n Query: {last user message content}"
121130
message = new_request["messages"][last_user_idx]
122-
context_msg = f'Context: {context_str} \n\n Query: {message["content"]}'
131+
context_msg = f'Context: {context_str} \n\n Query: {message["content"]}' # type: ignore
123132
message["content"] = context_msg
124133

125134
logger.debug("Final context message", context_message=context_msg)

Diff for: src/codegate/pipeline/extract_snippets/extract_snippets.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import structlog
66
from litellm.types.llms.openai import ChatCompletionRequest
7+
from pygments.lexers import guess_lexer
78

89
from codegate.pipeline.base import CodeSnippet, PipelineContext, PipelineResult, PipelineStep
910

@@ -65,6 +66,8 @@ def ecosystem_from_message(message: str) -> Optional[str]:
6566
"ts": "typescript",
6667
"tsx": "typescript",
6768
"go": "go",
69+
"rs": "rust",
70+
"java": "java",
6871
}
6972
return language_mapping.get(message, None)
7073

@@ -82,6 +85,7 @@ def extract_snippets(message: str) -> List[CodeSnippet]:
8285
# Regular expression to find code blocks
8386

8487
snippets: List[CodeSnippet] = []
88+
available_languages = ["python", "javascript", "typescript", "go", "rust", "java"]
8589

8690
# Find all code block matches
8791
for match in CODE_BLOCK_PATTERN.finditer(message):
@@ -105,6 +109,14 @@ def extract_snippets(message: str) -> List[CodeSnippet]:
105109
filename = filename.strip()
106110
# Determine language from the filename
107111
lang = ecosystem_from_filepath(filename)
112+
if lang is None:
113+
# try to guess it from the code
114+
lexer = guess_lexer(content)
115+
if lexer and lexer.name:
116+
lang = lexer.name.lower()
117+
# only add available languages
118+
if lang not in available_languages:
119+
lang = None
108120

109121
snippets.append(CodeSnippet(filepath=filename, code=content, language=lang))
110122

@@ -129,10 +141,9 @@ async def process(
129141
request: ChatCompletionRequest,
130142
context: PipelineContext,
131143
) -> PipelineResult:
132-
last_user_message = self.get_last_user_message(request)
133-
if not last_user_message:
144+
msg_content = self.get_last_user_message_block(request)
145+
if not msg_content:
134146
return PipelineResult(request=request, context=context)
135-
msg_content, _ = last_user_message
136147
snippets = extract_snippets(msg_content)
137148

138149
logger.info(f"Extracted {len(snippets)} code snippets from the user message")

0 commit comments

Comments
 (0)