
Adapt to Anthropic's new count_tokens API
Julia Turc authored and committed Nov 11, 2024
1 parent a667945 commit 801a9b2
Showing 2 changed files with 12 additions and 5 deletions.
sage/config.py: 4 additions, 1 deletion
@@ -163,7 +163,7 @@ def add_vector_store_args(parser: ArgumentParser) -> Callable:
     parser.add(
         "--llm-retriever",
         action=argparse.BooleanOptionalAction,
-        default=False,
+        default=True,
         help="When set to True, we use an LLM for retrieval: we pass the repository file hierarchy together with the "
         "user query and ask the LLM to choose relevant files solely based on their paths. No indexing will be done, so "
         "all the vector store / embedding arguments will be ignored.",
@@ -358,6 +358,9 @@ def _validate_gemini_embedding_args(args):
 
 def validate_embedding_args(args):
     """Validates the configuration of the batch embedder and sets defaults."""
+    if args.llm_retriever:
+        # When using an LLM to retrieve, we are not running the embedder.
+        return True
     if args.embedding_provider == "openai":
         _validate_openai_embedding_args(args)
     elif args.embedding_provider == "voyage":
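Because the flag is declared with argparse.BooleanOptionalAction, flipping the default to True still leaves users an explicit opt-out switch. A minimal sketch of that behavior, using stdlib argparse in place of the parser.add(...) style seen in sage/config.py:

import argparse

# BooleanOptionalAction auto-generates a matching --no-llm-retriever flag.
parser = argparse.ArgumentParser()
parser.add_argument("--llm-retriever", action=argparse.BooleanOptionalAction, default=True)

print(parser.parse_args([]).llm_retriever)                      # True: the new default
print(parser.parse_args(["--no-llm-retriever"]).llm_retriever)  # False: explicit opt-out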
sage/retriever.py: 8 additions, 4 deletions
@@ -91,22 +91,26 @@ def repo_hierarchy(self):
         render = LLMRetriever._render_file_hierarchy(self.repo_metadata, include_classes=True, include_methods=True)
         max_tokens = CLAUDE_MODEL_CONTEXT_SIZE - 50_000  # 50,000 tokens for other parts of the prompt.
         client = anthropic.Anthropic()
-        if client.count_tokens(render) > max_tokens:
+
+        def count_tokens(x):
+            count = client.beta.messages.count_tokens(model=CLAUDE_MODEL, messages=[{"role": "user", "content": x}])
+            return count.input_tokens
+
+        if count_tokens(render) > max_tokens:
             logging.info("File hierarchy is too large; excluding methods.")
             render = LLMRetriever._render_file_hierarchy(
                 self.repo_metadata, include_classes=True, include_methods=False
             )
-            if client.count_tokens(render) > max_tokens:
+            if count_tokens(render) > max_tokens:
                 logging.info("File hierarchy is still too large; excluding classes.")
                 render = LLMRetriever._render_file_hierarchy(
                     self.repo_metadata, include_classes=False, include_methods=False
                 )
-                if client.count_tokens(render) > max_tokens:
+                if count_tokens(render) > max_tokens:
                     logging.info("File hierarchy is still too large; truncating.")
                     tokenizer = anthropic.Tokenizer()
                     tokens = tokenizer.tokenize(render)[:max_tokens]
                     render = tokenizer.detokenize(tokens)
-        logging.info("Number of tokens in render hierarchy: %d", client.count_tokens(render))
         self.cached_repo_hierarchy = render
         return self.cached_repo_hierarchy

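For context: older releases of the anthropic SDK shipped a local client.count_tokens(text) helper, which newer releases no longer provide; this commit switches to the server-side beta endpoint, which requires a model name and a messages list. A minimal standalone sketch of the new call, assuming anthropic>=0.39, ANTHROPIC_API_KEY set in the environment, and an illustrative model name standing in for sage's CLAUDE_MODEL constant:

import anthropic

client = anthropic.Anthropic()  # picks up ANTHROPIC_API_KEY from the environment
count = client.beta.messages.count_tokens(
    model="claude-3-5-sonnet-20241022",  # placeholder for CLAUDE_MODEL
    messages=[{"role": "user", "content": "How many tokens will this prompt use?"}],
)
print(count.input_tokens)  # server-side count of prompt tokens

One consequence is that each count is now a network round trip rather than a local tokenizer call, which would explain why the commit also drops the final logging statement that re-counted the full hierarchy.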
