From 0fe88f8c42e2d37ab5d9b4a45f51d871ee233b66 Mon Sep 17 00:00:00 2001 From: Julia Turc Date: Tue, 22 Oct 2024 21:33:41 -0700 Subject: [PATCH] Ensure the happy path works well (#94) --- sage/code_symbols.py | 49 ++++++++++ sage/configs/remote.yaml | 5 +- sage/data_manager.py | 21 ++--- sage/reranker.py | 66 +------------- sage/retriever.py | 189 +++++++++++++++++++++++++++++++-------- 5 files changed, 220 insertions(+), 110 deletions(-) create mode 100644 sage/code_symbols.py diff --git a/sage/code_symbols.py b/sage/code_symbols.py new file mode 100644 index 0000000..19b7b1c --- /dev/null +++ b/sage/code_symbols.py @@ -0,0 +1,49 @@ +"""Utilities to extract code symbols (class and method names) from code files.""" + +import logging +from typing import List, Tuple + +from tree_sitter import Node + +from sage.chunker import CodeFileChunker + + +def _extract_classes_and_methods(node: Node, acc: List[Tuple[str, str]], parent_class: str = None): + """Extracts classes and methods from a tree-sitter node and places them in the `acc` accumulator.""" + if node.type in ["class_definition", "class_declaration"]: + class_name_node = node.child_by_field_name("name") + if class_name_node: + class_name = class_name_node.text.decode("utf-8") + acc.append((class_name, None)) + for child in node.children: + _extract_classes_and_methods(child, acc, class_name) + elif node.type in ["function_definition", "method_definition"]: + function_name_node = node.child_by_field_name("name") + if function_name_node: + acc.append((parent_class, function_name_node.text.decode("utf-8"))) + # We're not going deeper into a method. This means we're missing nested functions. + else: + for child in node.children: + _extract_classes_and_methods(child, acc, parent_class) + + +def get_code_symbols(file_path: str, content: str) -> List[Tuple[str, str]]: + """Extracts code symbols from a file. + + Code symbols are tuples of the form (class_name, method_name). For classes, method_name is None. For methods + that do not belong to a class, class_name is None. + """ + if not CodeFileChunker.is_code_file(file_path): + return [] + + if not content: + return [] + + logging.info(f"Extracting code symbols from {file_path}") + tree = CodeFileChunker.parse_tree(file_path, content) + if not tree: + return [] + + classes_and_methods = [] + _extract_classes_and_methods(tree.root_node, classes_and_methods) + return classes_and_methods diff --git a/sage/configs/remote.yaml b/sage/configs/remote.yaml index cbb42e7..77dddcf 100644 --- a/sage/configs/remote.yaml +++ b/sage/configs/remote.yaml @@ -1,6 +1,9 @@ llm-retriever: true llm-provider: anthropic -reranker-provider: anthropic +# Here we optimize for ease of setup, so we skip the reranker which would require an extra API key. +reranker-provider: none +# Since we skipped the reranker, we can't afford to feed the retriever with too many candidates. +retriever-top-k: 5 # The settings below (embeddings and vector store) are only relevant when setting --no-llm-retriever diff --git a/sage/data_manager.py b/sage/data_manager.py index 8e48e1c..680f3fd 100644 --- a/sage/data_manager.py +++ b/sage/data_manager.py @@ -217,12 +217,8 @@ def walk(self, get_content: bool = True) -> Generator[Tuple[Any, Dict], None, No yield metadata continue - with open(file_path, "r") as f: - try: - contents = f.read() - except UnicodeDecodeError: - logging.warning("Unable to decode file %s. 
Skipping.", file_path) - continue + contents = self.read_file(relative_file_path) + if contents: yield contents, metadata def url_for_file(self, file_path: str) -> str: @@ -231,10 +227,15 @@ def url_for_file(self, file_path: str) -> str: return f"https://github.com/{self.repo_id}/blob/{self.default_branch}/{file_path}" def read_file(self, relative_file_path: str) -> str: - """Reads the content of the file at the given path.""" - file_path = os.path.join(self.local_dir, relative_file_path) - with open(file_path, "r") as f: - return f.read() + """Reads the contents of a file in the repository.""" + absolute_file_path = os.path.join(self.local_dir, relative_file_path) + with open(absolute_file_path, "r") as f: + try: + contents = f.read() + return contents + except UnicodeDecodeError: + logging.warning("Unable to decode file %s.", absolute_file_path) + return None def from_args(args: Dict): """Creates a GitHubRepoManager from command-line arguments and clones the underlying repository.""" diff --git a/sage/reranker.py b/sage/reranker.py index 23a266b..afd5286 100644 --- a/sage/reranker.py +++ b/sage/reranker.py @@ -1,21 +1,14 @@ -import logging import os from enum import Enum -from typing import List, Optional +from typing import Optional from langchain.retrievers.document_compressors import CrossEncoderReranker from langchain_cohere import CohereRerank from langchain_community.cross_encoders import HuggingFaceCrossEncoder from langchain_community.document_compressors import JinaRerank -from langchain_core.callbacks.manager import Callbacks -from langchain_core.documents import BaseDocumentCompressor, Document -from langchain_core.language_models import BaseLanguageModel -from langchain_core.prompts import PromptTemplate +from langchain_core.documents import BaseDocumentCompressor from langchain_nvidia_ai_endpoints import NVIDIARerank from langchain_voyageai import VoyageAIRerank -from pydantic import ConfigDict, Field - -from sage.llm import build_llm_via_langchain class RerankerProvider(Enum): @@ -25,58 +18,6 @@ class RerankerProvider(Enum): NVIDIA = "nvidia" JINA = "jina" VOYAGE = "voyage" - # Anthropic doesn't provide an explicit reranker; we simply prompt the LLM with the user query and the content of - # the top k documents. - ANTHROPIC = "anthropic" - - -class LLMReranker(BaseDocumentCompressor): - """Reranker that passes the user query and top N documents to a language model to order them. - - Note that Langchain's RerankLLM does not support LLMs from Anthropic. - https://python.langchain.com/api_reference/community/document_compressors/langchain_community.document_compressors.rankllm_rerank.RankLLMRerank.html - Also, they rely on https://github.com/castorini/rank_llm, which doesn't run on Apple Silicon (M1/M2 chips). - """ - - llm: BaseLanguageModel = Field(...) - top_k: int = Field(...) - - model_config = ConfigDict( - arbitrary_types_allowed=True, - extra="forbid", - ) - - @property - def prompt(self): - return PromptTemplate.from_template( - "Given the following query: '{query}'\n\n" - "And these documents:\n\n{documents}\n\n" - "Rank the documents based on their relevance to the query. " - "Return only the document numbers in order of relevance, separated by commas. For example: 2,5,1,3,4. " - "Return absolutely nothing else." 
- ) - - def compress_documents( - self, - documents: List[Document], - query: str, - callbacks: Optional[Callbacks] = None, - ) -> List[Document]: - if len(documents) <= self.top_k: - return documents - - doc_texts = [f"Document {i+1}:\n{doc.page_content}\n" for i, doc in enumerate(documents)] - docs_str = "\n".join(doc_texts) - - llm_input = self.prompt.format(query=query, documents=docs_str) - result = self.llm.predict(llm_input) - - try: - ranked_indices = [int(idx) - 1 for idx in result.strip().split(",")][: self.top_k] - return [documents[i] for i in ranked_indices] - except ValueError: - logging.warning("Failed to parse reranker output. Returning original order. LLM responded with: %s", result) - return documents[: self.top_k] def build_reranker(provider: str, model: Optional[str] = None, top_k: Optional[int] = 5) -> BaseDocumentCompressor: @@ -105,7 +46,4 @@ def build_reranker(provider: str, model: Optional[str] = None, top_k: Optional[i raise ValueError("Please set the VOYAGE_API_KEY environment variable") model = model or "rerank-1" return VoyageAIRerank(model=model, api_key=os.environ.get("VOYAGE_API_KEY"), top_k=top_k) - if provider == RerankerProvider.ANTHROPIC.value: - llm = build_llm_via_langchain("anthropic", model) - return LLMReranker(llm=llm, top_k=1) raise ValueError(f"Invalid reranker provider: {provider}") diff --git a/sage/retriever.py b/sage/retriever.py index b45580a..c287ae5 100644 --- a/sage/retriever.py +++ b/sage/retriever.py @@ -1,6 +1,6 @@ import logging import os -from typing import List, Optional +from typing import Dict, List, Optional import anthropic import Levenshtein @@ -9,12 +9,12 @@ from langchain.retrievers import ContextualCompressionRetriever from langchain.retrievers.multi_query import MultiQueryRetriever from langchain.schema import BaseRetriever, Document -from langchain_core.output_parsers import CommaSeparatedListOutputParser from langchain_google_genai import GoogleGenerativeAIEmbeddings from langchain_openai import OpenAIEmbeddings from langchain_voyageai import VoyageAIEmbeddings from pydantic import Field +from sage.code_symbols import get_code_symbols from sage.data_manager import DataManager, GitHubRepoManager from sage.llm import build_llm_via_langchain from sage.reranker import build_reranker @@ -24,6 +24,9 @@ logger = logging.getLogger() logger.setLevel(logging.INFO) +CLAUDE_MODEL = "claude-3-5-sonnet-20240620" +CLAUDE_MODEL_CONTEXT_SIZE = 200_000 + class LLMRetriever(BaseRetriever): """Custom Langchain retriever based on an LLM. @@ -37,21 +40,76 @@ class LLMRetriever(BaseRetriever): repo_manager: GitHubRepoManager = Field(...) top_k: int = Field(...) - all_repo_files: List[str] = Field(...) - repo_hierarchy: str = Field(...) + + cached_repo_metadata: List[Dict] = Field(...) + cached_repo_files: List[str] = Field(...) + cached_repo_hierarchy: str = Field(...) def __init__(self, repo_manager: GitHubRepoManager, top_k: int): super().__init__() self.repo_manager = repo_manager self.top_k = top_k - # Best practice would be to make these fields @cached_property, but that impedes class serialization. - self.all_repo_files = [metadata["file_path"] for metadata in self.repo_manager.walk(get_content=False)] - self.repo_hierarchy = LLMRetriever._render_file_hierarchy(self.all_repo_files) + # We cached these fields manually because: + # 1. Pydantic doesn't work with functools's @cached_property. + # 2. We can't use Pydantic's @computed_field because these fields depend on each other. + # 3. 
We can't use functools's @lru_cache because LLMRetriever needs to be hashable. + self.cached_repo_metadata = None + self.cached_repo_files = None + self.cached_repo_hierarchy = None if not os.environ.get("ANTHROPIC_API_KEY"): raise ValueError("Please set the ANTHROPIC_API_KEY environment variable for the LLMRetriever.") + @property + def repo_metadata(self): + if not self.cached_repo_metadata: + self.cached_repo_metadata = [metadata for metadata in self.repo_manager.walk(get_content=False)] + + # Extracting code symbols takes quite a while, since we need to read each file from disk. + # As a compromise, we do it for small codebases only. + small_codebase = len(self.repo_files) <= 200 + if small_codebase: + for metadata in self.cached_repo_metadata: + file_path = metadata["file_path"] + content = self.repo_manager.read_file(file_path) + metadata["code_symbols"] = get_code_symbols(file_path, content) + + return self.cached_repo_metadata + + @property + def repo_files(self): + if not self.cached_repo_files: + self.cached_repo_files = set(metadata["file_path"] for metadata in self.repo_metadata) + return self.cached_repo_files + + @property + def repo_hierarchy(self): + """Produces a string that describes the structure of the repository. Depending on how big the codebase is, it + might include class and method names.""" + if self.cached_repo_hierarchy is None: + render = LLMRetriever._render_file_hierarchy(self.repo_metadata, include_classes=True, include_methods=True) + max_tokens = CLAUDE_MODEL_CONTEXT_SIZE - 50_000 # 50,000 tokens for other parts of the prompt. + client = anthropic.Anthropic() + if client.count_tokens(render) > max_tokens: + logging.info("File hierarchy is too large; excluding methods.") + render = LLMRetriever._render_file_hierarchy( + self.repo_metadata, include_classes=True, include_methods=False + ) + if client.count_tokens(render) > max_tokens: + logging.info("File hierarchy is still too large; excluding classes.") + render = LLMRetriever._render_file_hierarchy( + self.repo_metadata, include_classes=False, include_methods=False + ) + if client.count_tokens(render) > max_tokens: + logging.info("File hierarchy is still too large; truncating.") + tokenizer = anthropic.Tokenizer() + tokens = tokenizer.tokenize(render)[:max_tokens] + render = tokenizer.detokenize(tokens) + logging.info("Number of tokens in render hierarchy: %d", client.count_tokens(render)) + self.cached_repo_hierarchy = render + return self.cached_repo_hierarchy + def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun) -> List[Document]: """Retrieve relevant documents for a given query.""" filenames = self._ask_llm_to_retrieve(user_query=query, top_k=self.top_k) @@ -66,13 +124,26 @@ def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerFor def _ask_llm_to_retrieve(self, user_query: str, top_k: int) -> List[str]: """Feeds the file hierarchy and user query to the LLM and asks which files might be relevant.""" + repo_hierarchy = str(self.repo_hierarchy) sys_prompt = f""" -You are a retriever system. You will be given a user query and a list of files in a GitHub repository. Your task is to determine the top {top_k} files that are most relevant to the user query. +You are a retriever system. You will be given a user query and a list of files in a GitHub repository, together with the class names in each file. 
+ +For instance: +folder1 + folder2 + folder3 + file123.py + ClassName1 + ClassName2 + ClassName3 +means that there is a file with path folder1/folder2/folder3/file123.py, which contains classes ClassName1, ClassName2, and ClassName3. + +Your task is to determine the top {top_k} files that are most relevant to the user query. DO NOT RESPOND TO THE USER QUERY DIRECTLY. Instead, respond with full paths to relevant files that could contain the answer to the query. Say absolutely nothing else other than the file paths. -Here is the file hierarchy of the GitHub repository: +Here is the file hierarchy of the GitHub repository, together with the class names in each file: -{self.repo_hierarchy} +{repo_hierarchy} """ # We are deliberately repeating the "DO NOT RESPOND TO THE USER QUERY DIRECTLY" instruction here. @@ -82,20 +153,21 @@ def _ask_llm_to_retrieve(self, user_query: str, top_k: int) -> List[str]: DO NOT RESPOND TO THE USER QUERY DIRECTLY. Instead, respond with full paths to relevant files that could contain the answer to the query. Say absolutely nothing else other than the file paths. """ response = LLMRetriever._call_via_anthropic_with_prompt_caching(sys_prompt, augmented_user_query) + files_from_llm = response.content[0].text.strip().split("\n") validated_files = [] for filename in files_from_llm: - if filename not in self.all_repo_files: + if filename not in self.repo_files: if "/" not in filename: # This is most likely some natural language excuse from the LLM; skip it. continue # Try a few heuristics to fix the filename. filename = LLMRetriever._fix_filename(filename, self.repo_manager.repo_id) - if filename not in self.all_repo_files: + if filename not in self.repo_files: # The heuristics failed; try to find the closest filename in the repo. - filename = LLMRetriever._find_closest_filename(filename, self.all_repo_files) - if filename in self.all_repo_files: + filename = LLMRetriever._find_closest_filename(filename, self.repo_files) + if filename in self.repo_files: validated_files.append(filename) return validated_files @@ -108,13 +180,10 @@ def _call_via_anthropic_with_prompt_caching(system_prompt: str, user_prompt: str We're circumventing LangChain for now, because the feature is < 1 week old at the time of writing and has no documentation: https://github.com/langchain-ai/langchain/pull/27087 """ - CLAUDE_MODEL = "claude-3-5-sonnet-20240620" - client = anthropic.Anthropic() - system_message = {"type": "text", "text": system_prompt, "cache_control": {"type": "ephemeral"}} user_message = {"role": "user", "content": user_prompt} - response = client.beta.prompt_caching.messages.create( + response = anthropic.Anthropic().beta.prompt_caching.messages.create( model=CLAUDE_MODEL, max_tokens=1024, # The maximum number of *output* tokens to generate. system=[system_message], @@ -126,34 +195,66 @@ def _call_via_anthropic_with_prompt_caching(system_prompt: str, user_prompt: str return response @staticmethod - def _render_file_hierarchy(file_paths: List[str]) -> str: - """Given a list of files, produces a visualization of the file hierarchy. For instance: + def _render_file_hierarchy( + repo_metadata: List[Dict], include_classes: bool = True, include_methods: bool = True + ) -> str: + """Given a list of files, produces a visualization of the file hierarchy. This hierarchy optionally includes + class and method names, if available. + + For large codebases, including both classes and methods might exceed the token limit of the LLM. In that case, + try setting `include_methods=False` first. 
If that's still too long, try also setting `include_classes=False`. + + As a point of reference, the Transformers library requires setting `include_methods=False` to fit within + Claude's 200k context. + + Example: folder1 folder11 - file111.py + file111.md file112.py + ClassName1 + method_name1 + method_name2 + method_name3 folder12 file121.py + ClassName2 + ClassName3 folder2 file21.py """ # The "nodepath" is the path from root to the node (e.g. huggingface/transformers/examples) nodepath_to_node = {} - for path in file_paths: - items = path.split("/") - nodepath = "" - parent_node = None - for item in items: - nodepath = f"{nodepath}/{item}" - if nodepath in nodepath_to_node: - node = nodepath_to_node[nodepath] - else: - node = Node(item, parent=parent_node) - nodepath_to_node[nodepath] = node - parent_node = node - - root_path = f"/{file_paths[0].split('/')[0]}" + for metadata in repo_metadata: + path = metadata["file_path"] + paths = [path] + + if include_classes or include_methods: + # Add the code symbols to the path. For instance, "folder/myfile.py/ClassName/method_name". + for class_name, method_name in metadata.get("code_symbols", []): + if include_classes and class_name: + paths.append(path + "/" + class_name) + # We exclude private methods to save tokens. + if include_methods and method_name and not method_name.startswith("_"): + paths.append( + path + "/" + class_name + "/" + method_name if class_name else path + "/" + method_name + ) + + for path in paths: + items = path.split("/") + nodepath = "" + parent_node = None + for item in items: + nodepath = f"{nodepath}/{item}" + if nodepath in nodepath_to_node: + node = nodepath_to_node[nodepath] + else: + node = Node(item, parent=parent_node) + nodepath_to_node[nodepath] = node + parent_node = node + + root_path = "/" + repo_metadata[0]["file_path"].split("/")[0] full_render = "" root_node = nodepath_to_node[root_path] for pre, fill, node in RenderTree(root_node): @@ -200,6 +301,24 @@ def _find_closest_filename(filename: str, repo_filenames: List[str], max_edit_di return None +class RerankerWithErrorHandling(BaseRetriever): + """Wraps a `ContextualCompressionRetriever` to catch errors during inference. + + In practice, we see occasional `requests.exceptions.ReadTimeout` from the NVIDIA reranker, which crash the entire + pipeline. This wrapper catches such exceptions by simply returning the documents in the original order. + """ + + def __init__(self, reranker: ContextualCompressionRetriever): + self.reranker = reranker + + def _get_relevant_documents(self, query: str, *, run_manager=None) -> List[Document]: + try: + return self.reranker._get_relevant_documents(query, run_manager=run_manager) + except Exception as e: + logging.error(f"Error in reranker; preserving original document order from retriever. {e}") + return self.reranker.base_retriever._get_relevant_documents(query, run_manager=run_manager) + + def build_retriever_from_args(args, data_manager: Optional[DataManager] = None): """Builds a retriever (with optional reranking) from command-line arguments.""" if args.llm_retriever:
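
A usage sketch for the new code-symbol extraction (not part of the patch; it assumes the `sage` package from this repository is importable and that the target file is a code file supported by `CodeFileChunker`):

    from sage.code_symbols import get_code_symbols

    path = "sage/retriever.py"  # any code file in the repository
    with open(path, "r") as f:
        content = f.read()

    # Each entry is a (class_name, method_name) tuple; per the docstring in
    # code_symbols.py, method_name is None for a class entry and class_name is
    # None for a function defined outside a class.
    for class_name, method_name in get_code_symbols(path, content):
        print(class_name, method_name)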
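The happy path this change targets is the LLM-based retriever running without a reranker. A minimal sketch of that flow, assuming a `GitHubRepoManager` that already points at a downloaded repository (its construction is not shown in this patch), an `ANTHROPIC_API_KEY` in the environment, and LangChain's standard retriever interface:

    from sage.retriever import LLMRetriever

    repo_manager = ...  # a GitHubRepoManager for an already-downloaded repository
    # top_k=5 mirrors the new retriever-top-k default in remote.yaml.
    retriever = LLMRetriever(repo_manager, top_k=5)

    # invoke() dispatches to _get_relevant_documents, which calls
    # _ask_llm_to_retrieve: it renders the file hierarchy (with code symbols for
    # small codebases), asks Claude for the most relevant file paths, and
    # validates them against the repository before building Documents.
    docs = retriever.invoke("Where is the reranker configured?")
    for doc in docs:
        print(doc.metadata)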