diff --git a/frontend/src/api/open-hands.ts b/frontend/src/api/open-hands.ts index 92da6c4d2956..0f8771a161b3 100644 --- a/frontend/src/api/open-hands.ts +++ b/frontend/src/api/open-hands.ts @@ -229,6 +229,7 @@ class OpenHands { ): Promise { const body = { selected_repository: selectedRepository, + selected_branch: undefined, initial_user_msg: initialUserMsg, image_urls: imageUrls, }; diff --git a/openhands/runtime/base.py b/openhands/runtime/base.py index b7c93eea6df2..4f1e37f471dc 100644 --- a/openhands/runtime/base.py +++ b/openhands/runtime/base.py @@ -249,20 +249,37 @@ async def _handle_action(self, event: Action) -> None: source = event.source if event.source else EventSource.AGENT self.event_stream.add_event(observation, source) # type: ignore[arg-type] - def clone_repo(self, github_token: SecretStr, selected_repository: str) -> str: + def clone_repo( + self, + github_token: SecretStr, + selected_repository: str, + selected_branch: str | None, + ) -> str: if not github_token or not selected_repository: raise ValueError( 'github_token and selected_repository must be provided to clone a repository' ) url = f'https://{github_token.get_secret_value()}@github.com/{selected_repository}.git' dir_name = selected_repository.split('/')[1] - # add random branch name to avoid conflicts + + # Generate a random branch name to avoid conflicts random_str = ''.join( random.choices(string.ascii_lowercase + string.digits, k=8) ) - branch_name = f'openhands-workspace-{random_str}' + openhands_workspace_branch = f'openhands-workspace-{random_str}' + + # Clone repository command + clone_command = f'git clone {url} {dir_name}' + + # Checkout to appropriate branch + checkout_command = ( + f'git checkout {selected_branch}' + if selected_branch + else f'git checkout -b {openhands_workspace_branch}' + ) + action = CmdRunAction( - command=f'git clone {url} {dir_name} ; cd {dir_name} ; git checkout -b {branch_name}', + command=f'{clone_command} ; cd {dir_name} ; {checkout_command}', ) self.log('info', f'Cloning repo: {selected_repository}') self.run_action(action) diff --git a/openhands/server/routes/manage_conversations.py b/openhands/server/routes/manage_conversations.py index 4edfb47c5177..29db83007656 100644 --- a/openhands/server/routes/manage_conversations.py +++ b/openhands/server/routes/manage_conversations.py @@ -38,6 +38,7 @@ class InitSessionRequest(BaseModel): selected_repository: str | None = None + selected_branch: str | None = None initial_user_msg: str | None = None image_urls: list[str] | None = None @@ -46,6 +47,7 @@ async def _create_new_conversation( user_id: str | None, token: SecretStr | None, selected_repository: str | None, + selected_branch: str | None, initial_user_msg: str | None, image_urls: list[str] | None, ): @@ -74,6 +76,7 @@ async def _create_new_conversation( session_init_args['github_token'] = token or SecretStr('') session_init_args['selected_repository'] = selected_repository + session_init_args['selected_branch'] = selected_branch conversation_init_data = ConversationInitData(**session_init_args) logger.info('Loading conversation store') conversation_store = await ConversationStoreImpl.get_instance(config, user_id) @@ -135,6 +138,7 @@ async def new_conversation(request: Request, data: InitSessionRequest): github_token = await gh_client.get_latest_token() selected_repository = data.selected_repository + selected_branch = data.selected_branch initial_user_msg = data.initial_user_msg image_urls = data.image_urls or [] @@ -144,6 +148,7 @@ async def new_conversation(request: Request, data: InitSessionRequest): user_id, github_token, selected_repository, + selected_branch, initial_user_msg, image_urls, ) diff --git a/openhands/server/session/agent_session.py b/openhands/server/session/agent_session.py index 298474b884a2..31a31bd151fe 100644 --- a/openhands/server/session/agent_session.py +++ b/openhands/server/session/agent_session.py @@ -76,6 +76,7 @@ async def start( agent_configs: dict[str, AgentConfig] | None = None, github_token: SecretStr | None = None, selected_repository: str | None = None, + selected_branch: str | None = None, initial_message: MessageAction | None = None, ): """Starts the Agent session @@ -105,6 +106,7 @@ async def start( agent=agent, github_token=github_token, selected_repository=selected_repository, + selected_branch=selected_branch, ) self.controller = self._create_controller( @@ -184,6 +186,7 @@ async def _create_runtime( agent: Agent, github_token: SecretStr | None = None, selected_repository: str | None = None, + selected_branch: str | None = None, ): """Creates a runtime instance @@ -239,7 +242,10 @@ async def _create_runtime( repo_directory = None if selected_repository: repo_directory = await call_sync_from_async( - self.runtime.clone_repo, github_token, selected_repository + self.runtime.clone_repo, + github_token, + selected_repository, + selected_branch, ) if agent.prompt_manager: diff --git a/openhands/server/session/conversation_init_data.py b/openhands/server/session/conversation_init_data.py index 8773f48b326a..4cb6acd50f22 100644 --- a/openhands/server/session/conversation_init_data.py +++ b/openhands/server/session/conversation_init_data.py @@ -10,3 +10,4 @@ class ConversationInitData(Settings): github_token: SecretStr | None = Field(default=None) selected_repository: str | None = Field(default=None) + selected_branch: str | None = Field(default=None) diff --git a/openhands/server/session/session.py b/openhands/server/session/session.py index 5d34baf4f6e5..d7807fc94740 100644 --- a/openhands/server/session/session.py +++ b/openhands/server/session/session.py @@ -123,9 +123,11 @@ async def initialize_agent( github_token = None selected_repository = None + selected_branch = None if isinstance(settings, ConversationInitData): github_token = settings.github_token selected_repository = settings.selected_repository + selected_branch = settings.selected_branch try: await self.agent_session.start( @@ -138,6 +140,7 @@ async def initialize_agent( agent_configs=self.config.get_agent_configs(), github_token=github_token, selected_repository=selected_repository, + selected_branch=selected_branch, initial_message=initial_message, ) except Exception as e: