Skip to content

Commit

Permalink
Add GitHub repository information to system prompt (#6057)
Browse files Browse the repository at this point in the history
Co-authored-by: openhands <[email protected]>
  • Loading branch information
rbren and openhands-agent authored Jan 15, 2025
1 parent 3d9b4c4 commit fa6792e
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 12 deletions.
9 changes: 7 additions & 2 deletions openhands/agenthub/codeact_agent/prompts/system_prompt.j2
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,14 @@ You are OpenHands agent, a helpful AI assistant that can interact with a compute
* The assistant MUST NOT include comments in the code unless they are necessary to describe non-obvious behavior.
{{ runtime_info }}
</IMPORTANT>
{% if repo_instructions -%}
{% if repository_info %}
<REPOSITORY_INFO>
At the user's request, repository {{ repository_info.repo_name }} has been cloned to directory {{ repository_info.repo_directory }}.
</REPOSITORY_INFO>
{% endif %}
{% if repository_instructions -%}
<REPOSITORY_INSTRUCTIONS>
{{ repo_instructions }}
{{ repository_instructions }}
</REPOSITORY_INSTRUCTIONS>
{% endif %}
{% if runtime_info and runtime_info.available_hosts -%}
Expand Down
3 changes: 2 additions & 1 deletion openhands/runtime/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ async def _handle_action(self, event: Action) -> None:
source = event.source if event.source else EventSource.AGENT
self.event_stream.add_event(observation, source) # type: ignore[arg-type]

def clone_repo(self, github_token: str, selected_repository: str):
def clone_repo(self, github_token: str, selected_repository: str) -> str:
if not github_token or not selected_repository:
raise ValueError(
'github_token and selected_repository must be provided to clone a repository'
Expand All @@ -223,6 +223,7 @@ def clone_repo(self, github_token: str, selected_repository: str):
)
self.log('info', f'Cloning repo: {selected_repository}')
self.run_action(action)
return dir_name

def get_microagents_from_selected_repo(
self, selected_repository: str | None
Expand Down
7 changes: 6 additions & 1 deletion openhands/server/session/agent_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,9 @@ async def _create_runtime(
)
return

repo_directory = None
if selected_repository:
await call_sync_from_async(
repo_directory = await call_sync_from_async(
self.runtime.clone_repo, github_token, selected_repository
)

Expand All @@ -223,6 +224,10 @@ async def _create_runtime(
self.runtime.get_microagents_from_selected_repo, selected_repository
)
agent.prompt_manager.load_microagents(microagents)
if selected_repository and repo_directory:
agent.prompt_manager.set_repository_info(
selected_repository, repo_directory
)

logger.debug(
f'Runtime initialized with plugins: {[plugin.name for plugin in self.runtime.plugins]}'
Expand Down
36 changes: 31 additions & 5 deletions openhands/utils/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ class RuntimeInfo:
available_hosts: dict[str, int]


@dataclass
class RepositoryInfo:
"""Information about a GitHub repository that has been cloned."""

repo_name: str | None = None
repo_directory: str | None = None


class PromptManager:
"""
Manages prompt templates and micro-agents for AI interactions.
Expand All @@ -42,7 +50,7 @@ def __init__(
):
self.disabled_microagents: list[str] = disabled_microagents or []
self.prompt_dir: str = prompt_dir

self.repository_info: RepositoryInfo | None = None
self.system_template: Template = self._load_template('system_prompt')
self.user_template: Template = self._load_template('user_prompt')
self.runtime_info = RuntimeInfo(available_hosts={})
Expand Down Expand Up @@ -80,9 +88,6 @@ def load_microagents(self, microagents: list[BaseMicroAgent]):
elif isinstance(microagent, RepoMicroAgent):
self.repo_microagents[microagent.name] = microagent

def set_runtime_info(self, runtime: Runtime):
self.runtime_info.available_hosts = runtime.web_hosts

def _load_template(self, template_name: str) -> Template:
if self.prompt_dir is None:
raise ValueError('Prompt directory is not set')
Expand All @@ -102,10 +107,31 @@ def get_system_message(self) -> str:
if repo_instructions:
repo_instructions += '\n\n'
repo_instructions += microagent.content

return self.system_template.render(
runtime_info=self.runtime_info, repo_instructions=repo_instructions
repository_instructions=repo_instructions,
repository_info=self.repository_info,
runtime_info=self.runtime_info,
).strip()

def set_runtime_info(self, runtime: Runtime):
self.runtime_info.available_hosts = runtime.web_hosts

def set_repository_info(
self,
repo_name: str,
repo_directory: str,
) -> None:
"""Sets information about the GitHub repository that has been cloned.
Args:
repo_name: The name of the GitHub repository (e.g. 'owner/repo')
repo_directory: The directory where the repository has been cloned
"""
self.repository_info = RepositoryInfo(
repo_name=repo_name, repo_directory=repo_directory
)

def get_example_user_message(self) -> str:
"""This is the initial user message provided to the agent
before *actual* user instructions are provided.
Expand Down
51 changes: 48 additions & 3 deletions tests/unit/test_prompt_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from openhands.core.message import Message, TextContent
from openhands.microagent import BaseMicroAgent
from openhands.utils.prompt import PromptManager
from openhands.utils.prompt import PromptManager, RepositoryInfo


@pytest.fixture
Expand Down Expand Up @@ -39,6 +39,7 @@ def test_prompt_manager_with_microagent(prompt_dir):
with open(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md'), 'w') as f:
f.write(microagent_content)

# Test without GitHub repo
manager = PromptManager(
prompt_dir=prompt_dir,
microagent_dir=os.path.join(prompt_dir, 'micro'),
Expand All @@ -53,6 +54,14 @@ def test_prompt_manager_with_microagent(prompt_dir):
'You are OpenHands agent, a helpful AI assistant that can interact with a computer to solve tasks.'
in manager.get_system_message()
)
assert '<REPOSITORY_INFO>' not in manager.get_system_message()

# Test with GitHub repo
manager.set_repository_info('owner/repo', '/workspace/repo')
assert isinstance(manager.get_system_message(), str)
assert '<REPOSITORY_INFO>' in manager.get_system_message()
assert 'owner/repo' in manager.get_system_message()
assert '/workspace/repo' in manager.get_system_message()

assert isinstance(manager.get_example_user_message(), str)

Expand All @@ -76,20 +85,56 @@ def test_prompt_manager_file_not_found(prompt_dir):
def test_prompt_manager_template_rendering(prompt_dir):
# Create temporary template files
with open(os.path.join(prompt_dir, 'system_prompt.j2'), 'w') as f:
f.write('System prompt: bar')
f.write("""System prompt: bar
{% if repository_info %}
<REPOSITORY_INFO>
At the user's request, repository {{ repository_info.repo_name }} has been cloned to directory {{ repository_info.repo_directory }}.
</REPOSITORY_INFO>
{% endif %}
{{ repo_instructions }}""")
with open(os.path.join(prompt_dir, 'user_prompt.j2'), 'w') as f:
f.write('User prompt: foo')

# Test without GitHub repo
manager = PromptManager(prompt_dir, microagent_dir='')

assert manager.get_system_message() == 'System prompt: bar'
assert manager.get_example_user_message() == 'User prompt: foo'

# Test with GitHub repo
manager = PromptManager(prompt_dir=prompt_dir, microagent_dir='')
manager.set_repository_info('owner/repo', '/workspace/repo')
assert manager.repository_info.repo_name == 'owner/repo'
system_msg = manager.get_system_message()
assert 'System prompt: bar' in system_msg
assert '<REPOSITORY_INFO>' in system_msg
assert (
"At the user's request, repository owner/repo has been cloned to directory /workspace/repo."
in system_msg
)
assert '</REPOSITORY_INFO>' in system_msg
assert manager.get_example_user_message() == 'User prompt: foo'

# Clean up temporary files
os.remove(os.path.join(prompt_dir, 'system_prompt.j2'))
os.remove(os.path.join(prompt_dir, 'user_prompt.j2'))


def test_prompt_manager_repository_info(prompt_dir):
# Test RepositoryInfo defaults
repo_info = RepositoryInfo()
assert repo_info.repo_name is None
assert repo_info.repo_directory is None

# Test setting repository info
manager = PromptManager(prompt_dir=prompt_dir, microagent_dir='')
assert manager.repository_info is None

# Test setting repository info with both name and directory
manager.set_repository_info('owner/repo2', '/workspace/repo2')
assert manager.repository_info.repo_name == 'owner/repo2'
assert manager.repository_info.repo_directory == '/workspace/repo2'


def test_prompt_manager_disabled_microagents(prompt_dir):
# Create test microagent files
microagent1_name = 'test_microagent1'
Expand Down

0 comments on commit fa6792e

Please sign in to comment.