From a2d109fe067308336e19f481c96b9d14afd266d8 Mon Sep 17 00:00:00 2001 From: Jake Koenig Date: Tue, 20 Feb 2024 11:38:29 -0800 Subject: [PATCH] Only refresh code context once a loop (#527) --- mentat/code_context.py | 7 +------ mentat/cost_tracker.py | 1 - mentat/diff_context.py | 1 - mentat/session.py | 2 +- mentat/session_input.py | 6 +++--- 5 files changed, 5 insertions(+), 12 deletions(-) diff --git a/mentat/code_context.py b/mentat/code_context.py index abfdfd4fd..8da2efb8f 100644 --- a/mentat/code_context.py +++ b/mentat/code_context.py @@ -75,8 +75,7 @@ def __init__( def refresh_context_display(self): """ - Sends a message to the client with the new updated code context - Must be called whenever the context changes! + Sends a message to the client with the code context. It is called in the main loop. """ ctx = SESSION_CONTEXT.get() @@ -194,7 +193,6 @@ async def get_code_message( self.auto_features = list( set(self.auto_features) | set(await feature_filter.filter(features)) ) - self.refresh_context_display() # Merge include file features and auto features and add to code message code_message += get_code_message_from_features( @@ -244,7 +242,6 @@ def clear_auto_context(self): Clears all auto-features added to the conversation so far. """ self._auto_features = [] - self.refresh_context_display() def include_features(self, code_features: Iterable[CodeFeature]): """ @@ -272,7 +269,6 @@ def include_features(self, code_features: Iterable[CodeFeature]): self.include_files[code_feature.path] = [] self.include_files[code_feature.path].append(code_feature) included_paths.add(Path(str(code_feature))) - self.refresh_context_display() return included_paths def include( @@ -428,7 +424,6 @@ def exclude(self, path: Path | str) -> Set[Path]: except PathValidationError as e: session_context.stream.send(str(e), style="error") - self.refresh_context_display() return excluded_paths async def search( diff --git a/mentat/cost_tracker.py b/mentat/cost_tracker.py index 8a5ced2c8..159799f63 100644 --- a/mentat/cost_tracker.py +++ b/mentat/cost_tracker.py @@ -50,7 +50,6 @@ def log_api_call_stats( costs_logger = logging.getLogger("costs") costs_logger.info(speed_and_cost_string) self.last_api_call = speed_and_cost_string - session_context.code_context.refresh_context_display() def display_last_api_call(self): """ diff --git a/mentat/diff_context.py b/mentat/diff_context.py index 75b09d366..f024adea8 100644 --- a/mentat/diff_context.py +++ b/mentat/diff_context.py @@ -169,7 +169,6 @@ def refresh_diff_files(self): (session_context.cwd / f).resolve() for f in get_files_in_diff(self.target) ] - session_context.code_context.refresh_context_display() def get_annotations(self, rel_path: Path) -> list[DiffAnnotation]: diff = get_diff_for_file(self.target, rel_path) diff --git a/mentat/session.py b/mentat/session.py index 607c8bb25..a95e1e401 100644 --- a/mentat/session.py +++ b/mentat/session.py @@ -151,12 +151,12 @@ async def _main(self): ensure_ctags_installed() session_context.llm_api_handler.initialize_client() - code_context.refresh_context_display() await conversation.display_token_count() stream.send("Type 'q' or use Ctrl-C to quit at any time.") need_user_request = True while True: + code_context.refresh_context_display() try: if need_user_request: # Normally, the code_file_manager pushes the edits; but when agent mode is on, we want all diff --git a/mentat/session_input.py b/mentat/session_input.py index ad0c0e97f..b546c67c3 100644 --- a/mentat/session_input.py +++ b/mentat/session_input.py @@ -51,8 +51,7 @@ async def ask_yes_no(default_yes: bool) -> bool: async def collect_input_with_commands() -> StreamMessage: - session_context = SESSION_CONTEXT.get() - stream = session_context.stream + ctx = SESSION_CONTEXT.get() response = await collect_user_input(command_autocomplete=True) while isinstance(response.data, str) and response.data.startswith("/"): @@ -61,8 +60,9 @@ async def collect_input_with_commands() -> StreamMessage: arguments = shlex.split(" ".join(response.data.split(" ")[1:])) command = Command.create_command(response.data[1:].split(" ")[0]) await command.apply(*arguments) + ctx.code_context.refresh_context_display() except ValueError as e: - stream.send(f"Error processing command arguments: {e}", style="error") + ctx.stream.send(f"Error processing command arguments: {e}", style="error") response = await collect_user_input(command_autocomplete=True) return response