From 801a9b251d87539c6d9e036a4572147c44669029 Mon Sep 17 00:00:00 2001
From: Julia Turc
Date: Sun, 10 Nov 2024 19:44:23 -0800
Subject: [PATCH] Adapt to Anthropic's new count_tokens API

---
 sage/config.py    |  5 ++++-
 sage/retriever.py | 12 ++++++++----
 2 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/sage/config.py b/sage/config.py
index 3ec21d9..00a92e7 100644
--- a/sage/config.py
+++ b/sage/config.py
@@ -163,7 +163,7 @@ def add_vector_store_args(parser: ArgumentParser) -> Callable:
     parser.add(
         "--llm-retriever",
         action=argparse.BooleanOptionalAction,
-        default=False,
+        default=True,
         help="When set to True, we use an LLM for retrieval: we pass the repository file hierarchy together with the "
         "user query and ask the LLM to choose relevant files solely based on their paths. No indexing will be done, so "
         "all the vector store / embedding arguments will be ignored.",
@@ -358,6 +358,9 @@ def _validate_gemini_embedding_args(args):
 
 def validate_embedding_args(args):
     """Validates the configuration of the batch embedder and sets defaults."""
+    if args.llm_retriever:
+        # When using an LLM to retrieve, we are not running the embedder.
+        return True
     if args.embedding_provider == "openai":
         _validate_openai_embedding_args(args)
     elif args.embedding_provider == "voyage":
diff --git a/sage/retriever.py b/sage/retriever.py
index 4275775..e718134 100644
--- a/sage/retriever.py
+++ b/sage/retriever.py
@@ -91,22 +91,26 @@ def repo_hierarchy(self):
         render = LLMRetriever._render_file_hierarchy(self.repo_metadata, include_classes=True, include_methods=True)
         max_tokens = CLAUDE_MODEL_CONTEXT_SIZE - 50_000  # 50,000 tokens for other parts of the prompt.
         client = anthropic.Anthropic()
-        if client.count_tokens(render) > max_tokens:
+
+        def count_tokens(x):
+            count = client.beta.messages.count_tokens(model=CLAUDE_MODEL, messages=[{"role": "user", "content": x}])
+            return count.input_tokens
+
+        if count_tokens(render) > max_tokens:
             logging.info("File hierarchy is too large; excluding methods.")
             render = LLMRetriever._render_file_hierarchy(
                 self.repo_metadata, include_classes=True, include_methods=False
             )
-        if client.count_tokens(render) > max_tokens:
+        if count_tokens(render) > max_tokens:
             logging.info("File hierarchy is still too large; excluding classes.")
             render = LLMRetriever._render_file_hierarchy(
                 self.repo_metadata, include_classes=False, include_methods=False
             )
-        if client.count_tokens(render) > max_tokens:
+        if count_tokens(render) > max_tokens:
             logging.info("File hierarchy is still too large; truncating.")
             tokenizer = anthropic.Tokenizer()
             tokens = tokenizer.tokenize(render)[:max_tokens]
             render = tokenizer.detokenize(tokens)
-        logging.info("Number of tokens in render hierarchy: %d", client.count_tokens(render))
         self.cached_repo_hierarchy = render
         return self.cached_repo_hierarchy
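
Note on the new API (a minimal, standalone sketch, not part of the patch): the removed `client.count_tokens(text)` helper counted tokens locally, while its replacement calls Anthropic's message-based token-counting endpoint. The snippet below assumes `anthropic>=0.39` and an `ANTHROPIC_API_KEY` in the environment; the model id is an illustrative stand-in for the repo's `CLAUDE_MODEL` constant.

```python
# Sketch of the token-counting call this patch adopts (assumptions noted above).
import anthropic

client = anthropic.Anthropic()  # picks up ANTHROPIC_API_KEY from the environment
count = client.beta.messages.count_tokens(
    model="claude-3-5-sonnet-20241022",  # assumed model id; the patch passes CLAUDE_MODEL
    messages=[{"role": "user", "content": "render of the repository file hierarchy"}],
)
print(count.input_tokens)  # the count the retriever compares against max_tokens
```

Unlike the old local tokenizer, this is a network round-trip per call, which is presumably why the patch wraps it in a small `count_tokens` helper that is invoked at most three times per hierarchy render and drops the final informational `logging.info` call.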