Skip to content

Commit

Permalink
Optimize memory usage in FileEditObservation (#6622)
Browse files Browse the repository at this point in the history
Co-authored-by: openhands <[email protected]>
Co-authored-by: Xingyao Wang <[email protected]>
  • Loading branch information
3 people authored Feb 7, 2025
1 parent ff48f8b commit 93d2e4a
Show file tree
Hide file tree
Showing 4 changed files with 324 additions and 32 deletions.
101 changes: 70 additions & 31 deletions openhands/events/observation/files.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""File-related observation classes for tracking file operations."""

from dataclasses import dataclass
from difflib import SequenceMatcher

Expand All @@ -16,79 +18,100 @@ class FileReadObservation(Observation):

@property
def message(self) -> str:
"""Get a human-readable message describing the file read operation."""
return f'I read the file {self.path}.'

def __str__(self) -> str:
return f'[Read from {self.path} is successful.]\n' f'{self.content}'
"""Get a string representation of the file read observation."""
return f'[Read from {self.path} is successful.]\n{self.content}'


@dataclass
class FileWriteObservation(Observation):
"""This data class represents a file write operation"""
"""This data class represents a file write operation."""

path: str
observation: str = ObservationType.WRITE

@property
def message(self) -> str:
"""Get a human-readable message describing the file write operation."""
return f'I wrote to the file {self.path}.'

def __str__(self) -> str:
return f'[Write to {self.path} is successful.]\n' f'{self.content}'
"""Get a string representation of the file write observation."""
return f'[Write to {self.path} is successful.]\n{self.content}'


@dataclass
class FileEditObservation(Observation):
"""This data class represents a file edit operation"""
"""This data class represents a file edit operation.
The observation includes both the old and new content of the file, and can
generate a diff visualization showing the changes. The diff is computed lazily
and cached to improve performance.
"""

# content: str will be a unified diff patch string that includes NO context lines
path: str
prev_exist: bool
old_content: str
new_content: str
observation: str = ObservationType.EDIT
impl_source: FileEditSource = FileEditSource.LLM_BASED_EDIT
formatted_output_and_error: str = ''
_diff_cache: str | None = None # Cache for the diff visualization

@property
def message(self) -> str:
"""Get a human-readable message describing the file edit operation."""
return f'I edited the file {self.path}.'

def get_edit_groups(self, n_context_lines: int = 2) -> list[dict[str, list[str]]]:
"""Get the edit groups of the file edit."""
"""Get the edit groups showing changes between old and new content.
Args:
n_context_lines: Number of context lines to show around each change.
Returns:
A list of edit groups, where each group contains before/after edits.
"""
old_lines = self.old_content.split('\n')
new_lines = self.new_content.split('\n')
# Borrowed from difflib.unified_diff to directly parse into structured format.
# Borrowed from difflib.unified_diff to directly parse into structured format
edit_groups: list[dict] = []
for group in SequenceMatcher(None, old_lines, new_lines).get_grouped_opcodes(
n_context_lines
):
# take the max line number in the group
_indent_pad_size = len(str(group[-1][3])) + 1 # +1 for the "*" prefix
# Take the max line number in the group
_indent_pad_size = len(str(group[-1][3])) + 1 # +1 for "*" prefix
cur_group: dict[str, list[str]] = {
'before_edits': [],
'after_edits': [],
}
for tag, i1, i2, j1, j2 in group:
if tag == 'equal':
for idx, line in enumerate(old_lines[i1:i2]):
line_num = i1 + idx + 1
cur_group['before_edits'].append(
f'{i1+idx+1:>{_indent_pad_size}}|{line}'
f'{line_num:>{_indent_pad_size}}|{line}'
)
for idx, line in enumerate(new_lines[j1:j2]):
line_num = j1 + idx + 1
cur_group['after_edits'].append(
f'{j1+idx+1:>{_indent_pad_size}}|{line}'
f'{line_num:>{_indent_pad_size}}|{line}'
)
continue
if tag in {'replace', 'delete'}:
for idx, line in enumerate(old_lines[i1:i2]):
line_num = i1 + idx + 1
cur_group['before_edits'].append(
f'-{i1+idx+1:>{_indent_pad_size-1}}|{line}'
f'-{line_num:>{_indent_pad_size-1}}|{line}'
)
if tag in {'replace', 'insert'}:
for idx, line in enumerate(new_lines[j1:j2]):
line_num = j1 + idx + 1
cur_group['after_edits'].append(
f'+{j1+idx+1:>{_indent_pad_size-1}}|{line}'
f'+{line_num:>{_indent_pad_size-1}}|{line}'
)
edit_groups.append(cur_group)
return edit_groups
Expand All @@ -100,24 +123,37 @@ def visualize_diff(
) -> str:
"""Visualize the diff of the file edit.
Instead of showing the diff line by line, this function
shows each hunk of changes as a separate entity.
Instead of showing the diff line by line, this function shows each hunk
of changes as a separate entity.
Args:
n_context_lines: The number of lines of context to show before and after the changes.
change_applied: Whether the changes are applied to the file. If true, the file have been modified. If not, the file is not modified (due to linting errors).
n_context_lines: Number of context lines to show before/after changes.
change_applied: Whether changes are applied. If false, shows as
attempted edit.
Returns:
A string containing the formatted diff visualization.
"""
if change_applied and self.content.strip() == '':
# diff patch is empty
return '(no changes detected. Please make sure your edits changes the content of the existing file.)\n'
# Use cached diff if available
if self._diff_cache is not None:
return self._diff_cache

# Check if there are any changes
if change_applied and self.old_content == self.new_content:
msg = '(no changes detected. Please make sure your edits change '
msg += 'the content of the existing file.)\n'
self._diff_cache = msg
return self._diff_cache

edit_groups = self.get_edit_groups(n_context_lines=n_context_lines)

result = [
f'[Existing file {self.path} is edited with {len(edit_groups)} changes.]'
if change_applied
else f"[Changes are NOT applied to {self.path} - Here's how the file looks like if changes are applied.]"
]
if change_applied:
header = f'[Existing file {self.path} is edited with '
header += f'{len(edit_groups)} changes.]'
else:
header = f"[Changes are NOT applied to {self.path} - Here's how "
header += 'the file looks like if changes are applied.]'
result = [header]

op_type = 'edit' if change_applied else 'ATTEMPTED edit'
for i, cur_edit_group in enumerate(edit_groups):
Expand All @@ -129,18 +165,21 @@ def visualize_diff(
result.append(f'(content after {op_type})')
result.extend(cur_edit_group['after_edits'])
result.append(f'[end of {op_type} {i+1} / {len(edit_groups)}]')
return '\n'.join(result)

# Cache the result
self._diff_cache = '\n'.join(result)
return self._diff_cache

def __str__(self) -> str:
"""Get a string representation of the file edit observation."""
if self.impl_source == FileEditSource.OH_ACI:
return self.formatted_output_and_error

ret = ''
if not self.prev_exist:
assert (
self.old_content == ''
), 'old_content should be empty if the file is new (prev_exist=False).'
ret += f'[New file {self.path} is created with the provided content.]\n'
return ret.rstrip() + '\n'
ret += self.visualize_diff()
return ret.rstrip() + '\n'
return f'[New file {self.path} is created with the provided content.]\n'

# Use cached diff if available, otherwise compute it
return self.visualize_diff().rstrip() + '\n'
4 changes: 4 additions & 0 deletions openhands/events/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@ def close(self):
for callback_id in callback_ids:
self._clean_up_subscriber(subscriber_id, callback_id)

# Clear queue
while not self._queue.empty():
self._queue.get()

def _clean_up_subscriber(self, subscriber_id: str, callback_id: str):
if subscriber_id not in self._subscribers:
logger.warning(f'Subscriber not found during cleanup: {subscriber_id}')
Expand Down
116 changes: 115 additions & 1 deletion tests/unit/test_event_stream.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,29 @@
import gc
import json
import os

import psutil
import pytest
from pytest import TempPathFactory

from openhands.core.schema.observation import ObservationType
from openhands.core.schema import ActionType, ObservationType
from openhands.events import EventSource, EventStream
from openhands.events.action import (
NullAction,
)
from openhands.events.action.files import (
FileEditAction,
FileReadAction,
FileWriteAction,
)
from openhands.events.action.message import MessageAction
from openhands.events.event import FileEditSource, FileReadSource
from openhands.events.observation import NullObservation
from openhands.events.observation.files import (
FileEditObservation,
FileReadObservation,
FileWriteObservation,
)
from openhands.storage import get_file_store


Expand Down Expand Up @@ -185,3 +199,103 @@ def test_get_matching_events_limit_validation(temp_dir: str):
assert len(events) == 1
events = event_stream.get_matching_events(limit=100)
assert len(events) == 1


def test_memory_usage_file_operations(temp_dir: str):
    """Test memory usage during file operations in EventStream.

    Runs 20 rounds of read/write/edit events (each carrying ~100 KB of content)
    through a fresh EventStream, closing the stream and forcing garbage
    collection after every round. The test asserts that the process RSS never
    grows by 50 MB or more over the baseline taken before the first round.

    Args:
        temp_dir: Directory used for the scratch file and the file store.
    """
    # One handle for the current process is enough; only memory_info() has to
    # be re-queried on each measurement.
    process = psutil.Process(os.getpid())

    def get_memory_mb() -> float:
        """Get the current resident set size of this process in MB."""
        return process.memory_info().rss / 1024 / 1024

    # Create a test file with 100kb content
    test_file = os.path.join(temp_dir, 'test_file.txt')
    test_content = 'x' * (100 * 1024)  # 100kb of data
    with open(test_file, 'w', encoding='utf-8') as f:
        f.write(test_content)

    # Initialize the file store; a fresh EventStream is created per iteration
    file_store = get_file_store('local', temp_dir)

    # Record initial memory usage
    gc.collect()
    initial_memory = get_memory_mb()
    max_memory_increase = 0

    # Perform operations 20 times
    for _ in range(20):
        event_stream = EventStream('test_session', file_store)
        try:
            # 1. Read file
            read_action = FileReadAction(
                path=test_file,
                start=0,
                end=-1,
                thought='Reading file',
                action=ActionType.READ,
                impl_source=FileReadSource.DEFAULT,
            )
            event_stream.add_event(read_action, EventSource.AGENT)

            read_obs = FileReadObservation(
                path=test_file, impl_source=FileReadSource.DEFAULT, content=test_content
            )
            event_stream.add_event(read_obs, EventSource.ENVIRONMENT)

            # 2. Write file
            write_action = FileWriteAction(
                path=test_file,
                content=test_content,
                start=0,
                end=-1,
                thought='Writing file',
                action=ActionType.WRITE,
            )
            event_stream.add_event(write_action, EventSource.AGENT)

            write_obs = FileWriteObservation(path=test_file, content=test_content)
            event_stream.add_event(write_obs, EventSource.ENVIRONMENT)

            # 3. Edit file
            edit_action = FileEditAction(
                path=test_file,
                content=test_content,
                start=1,
                end=-1,
                thought='Editing file',
                action=ActionType.EDIT,
                impl_source=FileEditSource.LLM_BASED_EDIT,
            )
            event_stream.add_event(edit_action, EventSource.AGENT)

            edit_obs = FileEditObservation(
                path=test_file,
                prev_exist=True,
                old_content=test_content,
                new_content=test_content,
                impl_source=FileEditSource.LLM_BASED_EDIT,
                content=test_content,
            )
            event_stream.add_event(edit_obs, EventSource.ENVIRONMENT)
        finally:
            # Always close the stream, even if an add_event call fails, so a
            # failing round does not leave the stream open for later tests.
            event_stream.close()

        # Force garbage collection before measuring
        gc.collect()

        # Check memory usage and keep the worst case seen across all rounds
        current_memory = get_memory_mb()
        memory_increase = current_memory - initial_memory
        max_memory_increase = max(max_memory_increase, memory_increase)

    # Clean up
    os.remove(test_file)

    # Memory increase should be reasonable (less than 50MB after 20 iterations)
    assert (
        max_memory_increase < 50
    ), f'Memory increase of {max_memory_increase:.1f}MB exceeds limit of 50MB'
Loading

0 comments on commit 93d2e4a

Please sign in to comment.