Skip to content
This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

Commit

Permalink
Only refresh code context once a loop (#527)
Browse files Browse the repository at this point in the history
  • Loading branch information
jakethekoenig authored Feb 20, 2024
1 parent 66e0bdd commit a2d109f
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 12 deletions.
7 changes: 1 addition & 6 deletions mentat/code_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@ def __init__(

def refresh_context_display(self):
"""
Sends a message to the client with the new updated code context
Must be called whenever the context changes!
Sends a message to the client with the code context. It is called in the main loop.
"""
ctx = SESSION_CONTEXT.get()

Expand Down Expand Up @@ -194,7 +193,6 @@ async def get_code_message(
self.auto_features = list(
set(self.auto_features) | set(await feature_filter.filter(features))
)
self.refresh_context_display()

# Merge include file features and auto features and add to code message
code_message += get_code_message_from_features(
Expand Down Expand Up @@ -244,7 +242,6 @@ def clear_auto_context(self):
Clears all auto-features added to the conversation so far.
"""
self._auto_features = []
self.refresh_context_display()

def include_features(self, code_features: Iterable[CodeFeature]):
"""
Expand Down Expand Up @@ -272,7 +269,6 @@ def include_features(self, code_features: Iterable[CodeFeature]):
self.include_files[code_feature.path] = []
self.include_files[code_feature.path].append(code_feature)
included_paths.add(Path(str(code_feature)))
self.refresh_context_display()
return included_paths

def include(
Expand Down Expand Up @@ -428,7 +424,6 @@ def exclude(self, path: Path | str) -> Set[Path]:
except PathValidationError as e:
session_context.stream.send(str(e), style="error")

self.refresh_context_display()
return excluded_paths

async def search(
Expand Down
1 change: 0 additions & 1 deletion mentat/cost_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def log_api_call_stats(
costs_logger = logging.getLogger("costs")
costs_logger.info(speed_and_cost_string)
self.last_api_call = speed_and_cost_string
session_context.code_context.refresh_context_display()

def display_last_api_call(self):
"""
Expand Down
1 change: 0 additions & 1 deletion mentat/diff_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,6 @@ def refresh_diff_files(self):
(session_context.cwd / f).resolve()
for f in get_files_in_diff(self.target)
]
session_context.code_context.refresh_context_display()

def get_annotations(self, rel_path: Path) -> list[DiffAnnotation]:
diff = get_diff_for_file(self.target, rel_path)
Expand Down
2 changes: 1 addition & 1 deletion mentat/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,12 @@ async def _main(self):
ensure_ctags_installed()

session_context.llm_api_handler.initialize_client()
code_context.refresh_context_display()
await conversation.display_token_count()

stream.send("Type 'q' or use Ctrl-C to quit at any time.")
need_user_request = True
while True:
code_context.refresh_context_display()
try:
if need_user_request:
# Normally, the code_file_manager pushes the edits; but when agent mode is on, we want all
Expand Down
6 changes: 3 additions & 3 deletions mentat/session_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ async def ask_yes_no(default_yes: bool) -> bool:


async def collect_input_with_commands() -> StreamMessage:
session_context = SESSION_CONTEXT.get()
stream = session_context.stream
ctx = SESSION_CONTEXT.get()

response = await collect_user_input(command_autocomplete=True)
while isinstance(response.data, str) and response.data.startswith("/"):
Expand All @@ -61,8 +60,9 @@ async def collect_input_with_commands() -> StreamMessage:
arguments = shlex.split(" ".join(response.data.split(" ")[1:]))
command = Command.create_command(response.data[1:].split(" ")[0])
await command.apply(*arguments)
ctx.code_context.refresh_context_display()
except ValueError as e:
stream.send(f"Error processing command arguments: {e}", style="error")
ctx.stream.send(f"Error processing command arguments: {e}", style="error")
response = await collect_user_input(command_autocomplete=True)
return response

Expand Down

0 comments on commit a2d109f

Please sign in to comment.