Skip to content

Commit

Permalink
v1 of new search agent (#4094)
Browse files Browse the repository at this point in the history
new search agent

Co-authored-by: Martin Ye <[email protected]>
  • Loading branch information
MartinYe1234 and Martin Ye authored Jun 21, 2024
1 parent 14c8e63 commit 2fc2496
Show file tree
Hide file tree
Showing 7 changed files with 442 additions and 25 deletions.
2 changes: 1 addition & 1 deletion sweepai/agents/agent_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def get_xml(
{parameters_xml}
</{self.name}>"""
if include_function_call_tags:
function_xml += f"<function_call>\n{function_xml}\n</function_call>"
function_xml = f"<function_call>\n{function_xml}\n</function_call>"
if include_description and self.description:
function_xml = f"{self.name} - {self.description}\n\n{function_xml}"
return function_xml
Expand Down
406 changes: 400 additions & 6 deletions sweepai/agents/question_answerer.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion sweepai/config/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def is_file_bad(self, file_name: str, repo_dir: str) -> tuple[bool, str]:
if bool(match):
return True, "The filename means that this file is likely auto generated."
except Exception as e:
logger.error(f"Error when checking if file is autogenerated: {e}, run `sudo apt-get install cmake pkg-config libicu-dev zlib1g-dev libcurl4-openssl-dev libssl-dev ruby-dev && gem install github-linguist`")
logger.error(f"Error when checking if file {file_name} is autogenerated: {e}, run `sudo apt-get install cmake pkg-config libicu-dev zlib1g-dev libcurl4-openssl-dev libssl-dev ruby-dev && gem install github-linguist`")
posthog.capture(
"is_file_auto_generated_or_vendored",
"is_file_auto_generated_or_vendored error",
Expand Down
19 changes: 18 additions & 1 deletion sweepai/core/snippet_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,21 @@ def convert_lines_to_and_merge_ranges(
ranges[-1] = (previous_start, end)
else:
ranges.append((start, end))
return ranges
return ranges


def merge_snippet_ranges(ranges: list[tuple[int, int]]) -> list[tuple[int, int]]:
    """
    Merge overlapping ranges into a minimal, sorted list of ranges.

    Two ranges are merged when the later one starts at or before the end of
    the earlier one (so ranges that merely touch, e.g. (1, 3) and (3, 5), are
    merged as well). The input list is left untouched; a new list is returned.

    Args:
        ranges: (start, end) pairs in any order.

    Returns:
        The merged ranges, sorted by start. An empty input yields [].
    """
    if not ranges:
        return []
    # Sort a copy: calling ranges.sort() would silently reorder the caller's list.
    ordered = sorted(ranges)
    merged_ranges = [ordered[0]]
    for current_start, current_end in ordered[1:]:
        previous_start, previous_end = merged_ranges[-1]
        if current_start <= previous_end:
            # Overlap (or touch): extend the last merged range; max() covers
            # the case where the current range is fully contained in it.
            merged_ranges[-1] = (previous_start, max(previous_end, current_end))
        else:
            merged_ranges.append((current_start, current_end))
    return merged_ranges
25 changes: 18 additions & 7 deletions sweepai/utils/cohere_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
import backoff
from loguru import logger
import voyageai
import cohere
from sweepai.config.server import COHERE_API_KEY, VOYAGE_API_KEY
from sweepai.logn.cache import file_cache


@backoff.on_exception(
backoff.expo,
Exception,
max_tries=3,
jitter=backoff.random_jitter,
)
@file_cache()
def cohere_rerank_call(
query: str,
Expand All @@ -13,12 +20,16 @@ def cohere_rerank_call(
):
# Cohere API call with caching
co = cohere.Client(COHERE_API_KEY)
return co.rerank(
model=model,
query=query,
documents=documents,
**kwargs
)
try:
return co.rerank(
model=model,
query=query,
documents=documents,
**kwargs
)
except Exception as e:
logger.error(f"Cohere rerank failed: {e}")
raise e

@file_cache()
def voyage_rerank_call(
Expand Down
6 changes: 1 addition & 5 deletions sweepai/utils/github_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,11 +689,7 @@ def dfs_helper(directory):
return files

def get_file_contents(self, file_path, ref=None):
local_path = (
f"{self.repo_dir}{file_path}"
if file_path.startswith("/")
else f"{self.repo_dir}/{file_path}"
)
local_path = os.path.join(self.repo_dir, file_path.lstrip("/"))
if os.path.exists(local_path) and os.path.isfile(local_path):
with open(local_path, "r", encoding="utf-8", errors="replace") as f:
contents = f.read()
Expand Down
7 changes: 3 additions & 4 deletions sweepai/utils/ticket_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from loguru import logger
from tqdm import tqdm
import networkx as nx
from sweepai.utils.chat_logger import ChatLogger
from sweepai.utils.streamable_functions import streamable

from sweepai.utils.timer import Timer
Expand Down Expand Up @@ -509,10 +510,8 @@ def fetch_relevant_files(
username,
metadata,
on_ticket_start_time,
tracking_id,
is_paying_user,
issue_url,
chat_logger,
tracking_id: str = "",
chat_logger: ChatLogger | None = None,
images = None
):
logger.info("Fetching relevant files...")
Expand Down

0 comments on commit 2fc2496

Please sign in to comment.