Skip to content

Commit

Permalink
Ensure the happy path works well (#94)
Browse files Browse the repository at this point in the history
  • Loading branch information
iuliaturc authored Oct 23, 2024
1 parent a3f3f19 commit 0fe88f8
Show file tree
Hide file tree
Showing 5 changed files with 220 additions and 110 deletions.
49 changes: 49 additions & 0 deletions sage/code_symbols.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Utilities to extract code symbols (class and method names) from code files."""

import logging
from typing import List, Tuple

from tree_sitter import Node

from sage.chunker import CodeFileChunker


def _extract_classes_and_methods(node: Node, acc: List[Tuple[str, str]], parent_class: str = None):
"""Extracts classes and methods from a tree-sitter node and places them in the `acc` accumulator."""
if node.type in ["class_definition", "class_declaration"]:
class_name_node = node.child_by_field_name("name")
if class_name_node:
class_name = class_name_node.text.decode("utf-8")
acc.append((class_name, None))
for child in node.children:
_extract_classes_and_methods(child, acc, class_name)
elif node.type in ["function_definition", "method_definition"]:
function_name_node = node.child_by_field_name("name")
if function_name_node:
acc.append((parent_class, function_name_node.text.decode("utf-8")))
# We're not going deeper into a method. This means we're missing nested functions.
else:
for child in node.children:
_extract_classes_and_methods(child, acc, parent_class)


def get_code_symbols(file_path: str, content: str) -> List[Tuple[str, str]]:
"""Extracts code symbols from a file.
Code symbols are tuples of the form (class_name, method_name). For classes, method_name is None. For methods
that do not belong to a class, class_name is None.
"""
if not CodeFileChunker.is_code_file(file_path):
return []

if not content:
return []

logging.info(f"Extracting code symbols from {file_path}")
tree = CodeFileChunker.parse_tree(file_path, content)
if not tree:
return []

classes_and_methods = []
_extract_classes_and_methods(tree.root_node, classes_and_methods)
return classes_and_methods
5 changes: 4 additions & 1 deletion sage/configs/remote.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
llm-retriever: true
llm-provider: anthropic
reranker-provider: anthropic
# Here we optimize for ease of setup, so we skip the reranker which would require an extra API key.
reranker-provider: none
# Since we skipped the reranker, we can't afford to feed the retriever with too many candidates.
retriever-top-k: 5

# The settings below (embeddings and vector store) are only relevant when setting --no-llm-retriever

Expand Down
21 changes: 11 additions & 10 deletions sage/data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,12 +217,8 @@ def walk(self, get_content: bool = True) -> Generator[Tuple[Any, Dict], None, No
yield metadata
continue

with open(file_path, "r") as f:
try:
contents = f.read()
except UnicodeDecodeError:
logging.warning("Unable to decode file %s. Skipping.", file_path)
continue
contents = self.read_file(relative_file_path)
if contents:
yield contents, metadata

def url_for_file(self, file_path: str) -> str:
Expand All @@ -231,10 +227,15 @@ def url_for_file(self, file_path: str) -> str:
return f"https://github.com/{self.repo_id}/blob/{self.default_branch}/{file_path}"

def read_file(self, relative_file_path: str) -> str:
"""Reads the content of the file at the given path."""
file_path = os.path.join(self.local_dir, relative_file_path)
with open(file_path, "r") as f:
return f.read()
"""Reads the contents of a file in the repository."""
absolute_file_path = os.path.join(self.local_dir, relative_file_path)
with open(absolute_file_path, "r") as f:
try:
contents = f.read()
return contents
except UnicodeDecodeError:
logging.warning("Unable to decode file %s.", absolute_file_path)
return None

def from_args(args: Dict):
"""Creates a GitHubRepoManager from command-line arguments and clones the underlying repository."""
Expand Down
66 changes: 2 additions & 64 deletions sage/reranker.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,14 @@
import logging
import os
from enum import Enum
from typing import List, Optional
from typing import Optional

from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_cohere import CohereRerank
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain_community.document_compressors import JinaRerank
from langchain_core.callbacks.manager import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
from langchain_core.documents import BaseDocumentCompressor
from langchain_nvidia_ai_endpoints import NVIDIARerank
from langchain_voyageai import VoyageAIRerank
from pydantic import ConfigDict, Field

from sage.llm import build_llm_via_langchain


class RerankerProvider(Enum):
Expand All @@ -25,58 +18,6 @@ class RerankerProvider(Enum):
NVIDIA = "nvidia"
JINA = "jina"
VOYAGE = "voyage"
# Anthropic doesn't provide an explicit reranker; we simply prompt the LLM with the user query and the content of
# the top k documents.
ANTHROPIC = "anthropic"


class LLMReranker(BaseDocumentCompressor):
"""Reranker that passes the user query and top N documents to a language model to order them.
Note that Langchain's RerankLLM does not support LLMs from Anthropic.
https://python.langchain.com/api_reference/community/document_compressors/langchain_community.document_compressors.rankllm_rerank.RankLLMRerank.html
Also, they rely on https://github.com/castorini/rank_llm, which doesn't run on Apple Silicon (M1/M2 chips).
"""

llm: BaseLanguageModel = Field(...)
top_k: int = Field(...)

model_config = ConfigDict(
arbitrary_types_allowed=True,
extra="forbid",
)

@property
def prompt(self):
return PromptTemplate.from_template(
"Given the following query: '{query}'\n\n"
"And these documents:\n\n{documents}\n\n"
"Rank the documents based on their relevance to the query. "
"Return only the document numbers in order of relevance, separated by commas. For example: 2,5,1,3,4. "
"Return absolutely nothing else."
)

def compress_documents(
self,
documents: List[Document],
query: str,
callbacks: Optional[Callbacks] = None,
) -> List[Document]:
if len(documents) <= self.top_k:
return documents

doc_texts = [f"Document {i+1}:\n{doc.page_content}\n" for i, doc in enumerate(documents)]
docs_str = "\n".join(doc_texts)

llm_input = self.prompt.format(query=query, documents=docs_str)
result = self.llm.predict(llm_input)

try:
ranked_indices = [int(idx) - 1 for idx in result.strip().split(",")][: self.top_k]
return [documents[i] for i in ranked_indices]
except ValueError:
logging.warning("Failed to parse reranker output. Returning original order. LLM responded with: %s", result)
return documents[: self.top_k]


def build_reranker(provider: str, model: Optional[str] = None, top_k: Optional[int] = 5) -> BaseDocumentCompressor:
Expand Down Expand Up @@ -105,7 +46,4 @@ def build_reranker(provider: str, model: Optional[str] = None, top_k: Optional[i
raise ValueError("Please set the VOYAGE_API_KEY environment variable")
model = model or "rerank-1"
return VoyageAIRerank(model=model, api_key=os.environ.get("VOYAGE_API_KEY"), top_k=top_k)
if provider == RerankerProvider.ANTHROPIC.value:
llm = build_llm_via_langchain("anthropic", model)
return LLMReranker(llm=llm, top_k=1)
raise ValueError(f"Invalid reranker provider: {provider}")
Loading

0 comments on commit 0fe88f8

Please sign in to comment.