Skip to content
This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

Commit

Permalink
Changes for 1.0.18 (#570)
Browse files Browse the repository at this point in the history
  • Loading branch information
granawkins authored Apr 18, 2024
1 parent 522107e commit e51fdfb
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 10 deletions.
28 changes: 21 additions & 7 deletions mentat/code_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from mentat.code_feature import CodeFeature, get_consolidated_feature_refs
from mentat.diff_context import DiffContext
from mentat.errors import PathValidationError
from mentat.git_handler import get_git_root_for_path
from mentat.include_files import (
PathType,
get_code_features_for_path,
Expand Down Expand Up @@ -67,6 +68,11 @@ async def refresh_daemon(self):
cwd = ctx.cwd
llm_api_handler = ctx.llm_api_handler

# Use print because stream is not initialized yet
print("Scanning codebase for updates...")
if not get_git_root_for_path(cwd, raise_error=False):
print("\033[93mWarning: Not a git repository (this might take a while)\033[0m")

annotators: dict[str, dict[str, Any]] = {
"hierarchy": {"ignore_patterns": [str(p) for p in self.ignore_patterns]},
"chunker_line": {"lines_per_chunk": 50},
Expand Down Expand Up @@ -185,11 +191,15 @@ async def get_code_message(
auto_tokens=auto_tokens,
)
for ref in context_builder.to_refs():
new_features = list[CodeFeature]() # Save ragdaemon context back to include_files
path, interval_str = split_intervals_from_path(Path(ref))
intervals = parse_intervals(interval_str)
for interval in intervals:
feature = CodeFeature(cwd / path, interval)
self.include_features([feature]) # Save ragdaemon context back to include_files
if not interval_str:
new_features.append(CodeFeature(cwd / path))
else:
intervals = parse_intervals(interval_str)
for interval in intervals:
new_features.append(CodeFeature(cwd / path, interval))
self.include_features(new_features)

# The context message is rendered by ragdaemon (ContextBuilder.render())
context_message = context_builder.render()
Expand Down Expand Up @@ -417,10 +427,14 @@ async def search(
continue
distance = node["distance"]
path, interval = split_intervals_from_path(Path(node["ref"]))
intervals = parse_intervals(interval)
for _interval in intervals:
feature = CodeFeature(cwd / path, _interval)
if not interval:
feature = CodeFeature(cwd / path)
all_features_sorted.append((feature, distance))
else:
intervals = parse_intervals(interval)
for _interval in intervals:
feature = CodeFeature(cwd / path, _interval)
all_features_sorted.append((feature, distance))
if max_results is None:
return all_features_sorted
else:
Expand Down
2 changes: 1 addition & 1 deletion mentat/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class Config:
)
provider: Optional[str] = attr.field(default=None, metadata={"auto_completions": ["openai", "anthropic", "azure"]})
embedding_model: str = attr.field(
default="text-embedding-ada-002",
default="text-embedding-3-large",
metadata={"auto_completions": [model.name for model in models if isinstance(model, EmbeddingModel)]},
)
embedding_provider: Optional[str] = attr.field(
Expand Down
2 changes: 1 addition & 1 deletion mentat/llm_api_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ async def call_llm_api(
return response

@api_guard
def call_embedding_api(self, input_texts: list[str], model: str = "text-embedding-ada-002") -> EmbeddingResponse:
def call_embedding_api(self, input_texts: list[str], model: str = "text-embedding-3-large") -> EmbeddingResponse:
ctx = SESSION_CONTEXT.get()
return self.spice.get_embeddings_sync(input_texts, model, provider=ctx.config.embedding_provider)

Expand Down
1 change: 0 additions & 1 deletion mentat/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,6 @@ async def _main(self):

await session_context.llm_api_handler.initialize_client()

print("Scanning codebase for updates...")
await code_context.refresh_daemon()

check_model()
Expand Down

0 comments on commit e51fdfb

Please sign in to comment.