diff --git a/.github/workflows/py-unit-tests.yml b/.github/workflows/py-unit-tests.yml index f6f68051409f..a7dcfb84acd9 100644 --- a/.github/workflows/py-unit-tests.yml +++ b/.github/workflows/py-unit-tests.yml @@ -48,7 +48,7 @@ jobs: - name: Build Environment run: make build - name: Run Tests - run: poetry run pytest --forked -n auto --cov=openhands --cov-report=xml -svv ./tests/unit --ignore=tests/unit/test_memory.py + run: poetry run pytest --forked -n auto --cov=openhands --cov-report=xml -svv ./tests/unit --ignore=tests/unit/test_long_term_memory.py - name: Upload coverage to Codecov uses: codecov/codecov-action@v5 env: diff --git a/config.template.toml b/config.template.toml index 744dfc7953a4..331663c619f4 100644 --- a/config.template.toml +++ b/config.template.toml @@ -17,6 +17,12 @@ #modal_api_token_id = "" #modal_api_token_secret = "" +# API key for Daytona +#daytona_api_key = "" + +# Daytona Target +#daytona_target = "" + # Base path for the workspace workspace_base = "./workspace" diff --git a/frontend/src/components/features/settings/styled-switch-component.tsx b/frontend/src/components/features/settings/styled-switch-component.tsx index c00a55d18ce5..9299b08b30fb 100644 --- a/frontend/src/components/features/settings/styled-switch-component.tsx +++ b/frontend/src/components/features/settings/styled-switch-component.tsx @@ -12,13 +12,14 @@ export function StyledSwitchComponent({ className={cn( "w-12 h-6 rounded-xl flex items-center p-1.5 cursor-pointer", isToggled && "justify-end bg-primary", - !isToggled && "justify-start bg-[#1F2228] border border-tertiary-alt", + !isToggled && + "justify-start bg-base-secondary border border-tertiary-light", )} >
diff --git a/openhands/core/cli.py b/openhands/core/cli.py index 05a390d8b815..351ced0802bf 100644 --- a/openhands/core/cli.py +++ b/openhands/core/cli.py @@ -100,6 +100,7 @@ async def main(loop: asyncio.AbstractEventLoop): initial_user_action = MessageAction(content=task_str) if task_str else None sid = str(uuid4()) + display_message(f'Session ID: {sid}') runtime = create_runtime(config, sid=sid, headless_mode=True) await runtime.connect() diff --git a/openhands/core/config/app_config.py b/openhands/core/config/app_config.py index 0d736c6c9670..e3e44c6880f6 100644 --- a/openhands/core/config/app_config.py +++ b/openhands/core/config/app_config.py @@ -76,6 +76,9 @@ class AppConfig(BaseModel): file_uploads_restrict_file_types: bool = Field(default=False) file_uploads_allowed_extensions: list[str] = Field(default_factory=lambda: ['.*']) runloop_api_key: SecretStr | None = Field(default=None) + daytona_api_key: SecretStr | None = Field(default=None) + daytona_api_url: str = Field(default='https://app.daytona.io/api') + daytona_target: str = Field(default='us') cli_multiline_input: bool = Field(default=False) conversation_max_age_seconds: int = Field(default=864000) # 10 days in seconds microagents_dir: str = Field( diff --git a/openhands/core/config/config_utils.py b/openhands/core/config/config_utils.py index 44893e119b5a..63f04ebd9086 100644 --- a/openhands/core/config/config_utils.py +++ b/openhands/core/config/config_utils.py @@ -25,14 +25,20 @@ def get_field_info(field: FieldInfo) -> dict[str, Any]: # Note: this only works for UnionTypes with None as one of the types if get_origin(field_type) is UnionType: types = get_args(field_type) - non_none_arg = next((t for t in types if t is not type(None)), None) + non_none_arg = next( + (t for t in types if t is not None and t is not type(None)), None + ) if non_none_arg is not None: field_type = non_none_arg optional = True # type name in a pretty format type_name = ( - field_type.__name__ if hasattr(field_type, '__name__') else str(field_type) + str(field_type) + if field_type is None + else ( + field_type.__name__ if hasattr(field_type, '__name__') else str(field_type) + ) ) # default is always present diff --git a/openhands/core/exceptions.py b/openhands/core/exceptions.py index c80ab15d2bb6..342e0db0e7c5 100644 --- a/openhands/core/exceptions.py +++ b/openhands/core/exceptions.py @@ -10,17 +10,17 @@ class AgentError(Exception): class AgentNoInstructionError(AgentError): - def __init__(self, message='Instruction must be provided'): + def __init__(self, message: str = 'Instruction must be provided') -> None: super().__init__(message) class AgentEventTypeError(AgentError): - def __init__(self, message='Event must be a dictionary'): + def __init__(self, message: str = 'Event must be a dictionary') -> None: super().__init__(message) class AgentAlreadyRegisteredError(AgentError): - def __init__(self, name=None): + def __init__(self, name: str | None = None) -> None: if name is not None: message = f"Agent class already registered under '{name}'" else: @@ -29,7 +29,7 @@ def __init__(self, name=None): class AgentNotRegisteredError(AgentError): - def __init__(self, name=None): + def __init__(self, name: str | None = None) -> None: if name is not None: message = f"No agent class registered under '{name}'" else: @@ -38,7 +38,7 @@ def __init__(self, name=None): class AgentStuckInLoopError(AgentError): - def __init__(self, message='Agent got stuck in a loop'): + def __init__(self, message: str = 'Agent got stuck in a loop') -> None: super().__init__(message) @@ -48,7 +48,7 @@ def __init__(self, message='Agent got stuck in a loop'): class TaskInvalidStateError(Exception): - def __init__(self, state=None): + def __init__(self, state: str | None = None) -> None: if state is not None: message = f'Invalid state {state}' else: @@ -64,45 +64,47 @@ def __init__(self, state=None): # This exception gets sent back to the LLM # It might be malformed JSON class LLMMalformedActionError(Exception): - def __init__(self, message='Malformed response'): + def __init__(self, message: str = 'Malformed response') -> None: self.message = message super().__init__(message) - def __str__(self): + def __str__(self) -> str: return self.message # This exception gets sent back to the LLM # For some reason, the agent did not return an action class LLMNoActionError(Exception): - def __init__(self, message='Agent must return an action'): + def __init__(self, message: str = 'Agent must return an action') -> None: super().__init__(message) # This exception gets sent back to the LLM # The LLM output did not include an action, or the action was not the expected type class LLMResponseError(Exception): - def __init__(self, message='Failed to retrieve action from LLM response'): + def __init__( + self, message: str = 'Failed to retrieve action from LLM response' + ) -> None: super().__init__(message) class UserCancelledError(Exception): - def __init__(self, message='User cancelled the request'): + def __init__(self, message: str = 'User cancelled the request') -> None: super().__init__(message) class OperationCancelled(Exception): """Exception raised when an operation is cancelled (e.g. by a keyboard interrupt).""" - def __init__(self, message='Operation was cancelled'): + def __init__(self, message: str = 'Operation was cancelled') -> None: super().__init__(message) class LLMContextWindowExceedError(RuntimeError): def __init__( self, - message='Conversation history longer than LLM context window limit. Consider turning on enable_history_truncation config to avoid this error', - ): + message: str = 'Conversation history longer than LLM context window limit. Consider turning on enable_history_truncation config to avoid this error', + ) -> None: super().__init__(message) @@ -117,7 +119,7 @@ class FunctionCallConversionError(Exception): This typically happens when there's a malformed message (e.g., missing tags). But not due to LLM output. """ - def __init__(self, message): + def __init__(self, message: str) -> None: super().__init__(message) @@ -127,14 +129,14 @@ class FunctionCallValidationError(Exception): This typically happens when the LLM outputs unrecognized function call / parameter names / values. """ - def __init__(self, message): + def __init__(self, message: str) -> None: super().__init__(message) class FunctionCallNotExistsError(Exception): """Exception raised when an LLM call a tool that is not registered.""" - def __init__(self, message): + def __init__(self, message: str) -> None: super().__init__(message) @@ -191,15 +193,17 @@ class AgentRuntimeNotFoundError(AgentRuntimeUnavailableError): class BrowserInitException(Exception): - def __init__(self, message='Failed to initialize browser environment'): + def __init__( + self, message: str = 'Failed to initialize browser environment' + ) -> None: super().__init__(message) class BrowserUnavailableException(Exception): def __init__( self, - message='Browser environment is not available, please check if has been initialized', - ): + message: str = 'Browser environment is not available, please check if has been initialized', + ) -> None: super().__init__(message) @@ -217,5 +221,5 @@ class MicroAgentError(Exception): class MicroAgentValidationError(MicroAgentError): """Raised when there's a validation error in microagent metadata.""" - def __init__(self, message='Micro agent validation failed'): + def __init__(self, message: str = 'Micro agent validation failed') -> None: super().__init__(message) diff --git a/openhands/core/logger.py b/openhands/core/logger.py index dfed3db0af6d..9820ab564d4f 100644 --- a/openhands/core/logger.py +++ b/openhands/core/logger.py @@ -74,10 +74,11 @@ class StackInfoFilter(logging.Filter): - def filter(self, record): + def filter(self, record: logging.LogRecord) -> bool: if record.levelno >= logging.ERROR: - record.stack_info = True - record.exc_info = True + # LogRecord attributes are dynamically typed + setattr(record, 'stack_info', True) + setattr(record, 'exc_info', sys.exc_info()) return True @@ -107,9 +108,9 @@ def strip_ansi(s: str) -> str: class ColoredFormatter(logging.Formatter): - def format(self, record): - msg_type = record.__dict__.get('msg_type') - event_source = record.__dict__.get('event_source') + def format(self, record: logging.LogRecord) -> str: + msg_type = record.__dict__.get('msg_type', '') + event_source = record.__dict__.get('event_source', '') if event_source: new_msg_type = f'{event_source.upper()}_{msg_type}' if new_msg_type in LOG_COLORS: @@ -136,12 +137,13 @@ def format(self, record): return super().format(new_record) -def _fix_record(record: logging.LogRecord): +def _fix_record(record: logging.LogRecord) -> logging.LogRecord: new_record = copy.copy(record) # The formatter expects non boolean values, and will raise an exception if there is a boolean - so we fix these - if new_record.exc_info is True and not new_record.exc_text: # type: ignore - new_record.exc_info = sys.exc_info() # type: ignore - new_record.stack_info = None # type: ignore + # LogRecord attributes are dynamically typed + if getattr(new_record, 'exc_info', None) is True: + setattr(new_record, 'exc_info', sys.exc_info()) + setattr(new_record, 'stack_info', None) return new_record @@ -158,32 +160,32 @@ class RollingLogger: log_lines: list[str] all_lines: str - def __init__(self, max_lines=10, char_limit=80): + def __init__(self, max_lines: int = 10, char_limit: int = 80) -> None: self.max_lines = max_lines self.char_limit = char_limit self.log_lines = [''] * self.max_lines self.all_lines = '' - def is_enabled(self): + def is_enabled(self) -> bool: return DEBUG and sys.stdout.isatty() - def start(self, message=''): + def start(self, message: str = '') -> None: if message: print(message) self._write('\n' * self.max_lines) self._flush() - def add_line(self, line): + def add_line(self, line: str) -> None: self.log_lines.pop(0) self.log_lines.append(line[: self.char_limit]) self.print_lines() self.all_lines += line + '\n' - def write_immediately(self, line): + def write_immediately(self, line: str) -> None: self._write(line) self._flush() - def print_lines(self): + def print_lines(self) -> None: """Display the last n log_lines in the console (not for file logging). This will create the effect of a rolling display in the console. @@ -192,31 +194,31 @@ def print_lines(self): for line in self.log_lines: self.replace_current_line(line) - def move_back(self, amount=-1): + def move_back(self, amount: int = -1) -> None: r"""'\033[F' moves the cursor up one line.""" if amount == -1: amount = self.max_lines self._write('\033[F' * (self.max_lines)) self._flush() - def replace_current_line(self, line=''): + def replace_current_line(self, line: str = '') -> None: r"""'\033[2K\r' clears the line and moves the cursor to the beginning of the line.""" self._write('\033[2K' + line + '\n') self._flush() - def _write(self, line): + def _write(self, line: str) -> None: if not self.is_enabled(): return sys.stdout.write(line) - def _flush(self): + def _flush(self) -> None: if not self.is_enabled(): return sys.stdout.flush() class SensitiveDataFilter(logging.Filter): - def filter(self, record): + def filter(self, record: logging.LogRecord) -> bool: # Gather sensitive values which should not ever appear in the logs. sensitive_values = [] for key, value in os.environ.items(): @@ -245,6 +247,7 @@ def filter(self, record): 'modal_api_token_secret', 'llm_api_key', 'sandbox_env_github_token', + 'daytona_api_key', ] # add env var names @@ -262,7 +265,9 @@ def filter(self, record): return True -def get_console_handler(log_level: int = logging.INFO, extra_info: str | None = None): +def get_console_handler( + log_level: int = logging.INFO, extra_info: str | None = None +) -> logging.StreamHandler: """Returns a console handler for logging.""" console_handler = logging.StreamHandler() console_handler.setLevel(log_level) @@ -273,7 +278,9 @@ def get_console_handler(log_level: int = logging.INFO, extra_info: str | None = return console_handler -def get_file_handler(log_dir: str, log_level: int = logging.INFO): +def get_file_handler( + log_dir: str, log_level: int = logging.INFO +) -> logging.FileHandler: """Returns a file handler for logging.""" os.makedirs(log_dir, exist_ok=True) timestamp = datetime.now().strftime('%Y-%m-%d') @@ -347,7 +354,13 @@ def log_uncaught_exceptions( class LlmFileHandler(logging.FileHandler): """LLM prompt and response logging.""" - def __init__(self, filename, mode='a', encoding='utf-8', delay=False): + def __init__( + self, + filename: str, + mode: str = 'a', + encoding: str = 'utf-8', + delay: bool = False, + ) -> None: """Initializes an instance of LlmFileHandler. Args: @@ -378,7 +391,7 @@ def __init__(self, filename, mode='a', encoding='utf-8', delay=False): self.baseFilename = os.path.join(self.log_directory, filename) super().__init__(self.baseFilename, mode, encoding, delay) - def emit(self, record): + def emit(self, record: logging.LogRecord) -> None: """Emits a log record. Args: @@ -393,7 +406,7 @@ def emit(self, record): self.message_counter += 1 -def _get_llm_file_handler(name: str, log_level: int): +def _get_llm_file_handler(name: str, log_level: int) -> LlmFileHandler: # The 'delay' parameter, when set to True, postpones the opening of the log file # until the first log message is emitted. llm_file_handler = LlmFileHandler(name, delay=True) @@ -402,7 +415,7 @@ def _get_llm_file_handler(name: str, log_level: int): return llm_file_handler -def _setup_llm_logger(name: str, log_level: int): +def _setup_llm_logger(name: str, log_level: int) -> logging.Logger: logger = logging.getLogger(name) logger.propagate = False logger.setLevel(log_level) diff --git a/openhands/core/message.py b/openhands/core/message.py index b508142242fd..73dd300e2894 100644 --- a/openhands/core/message.py +++ b/openhands/core/message.py @@ -15,7 +15,9 @@ class Content(BaseModel): cache_prompt: bool = False @model_serializer - def serialize_model(self): + def serialize_model( + self, + ) -> dict[str, str | dict[str, str]] | list[dict[str, str | dict[str, str]]]: raise NotImplementedError('Subclasses should implement this method.') @@ -24,7 +26,7 @@ class TextContent(Content): text: str @model_serializer - def serialize_model(self): + def serialize_model(self) -> dict[str, str | dict[str, str]]: data: dict[str, str | dict[str, str]] = { 'type': self.type, 'text': self.text, @@ -39,7 +41,7 @@ class ImageContent(Content): image_urls: list[str] @model_serializer - def serialize_model(self): + def serialize_model(self) -> list[dict[str, str | dict[str, str]]]: images: list[dict[str, str | dict[str, str]]] = [] for url in self.image_urls: images.append({'type': self.type, 'image_url': {'url': url}}) @@ -101,15 +103,22 @@ def _list_serializer(self) -> dict: # See discussion here for details: https://github.com/BerriAI/litellm/issues/6422#issuecomment-2438765472 if self.role == 'tool' and item.cache_prompt: role_tool_with_prompt_caching = True - if isinstance(d, dict): - d.pop('cache_control') - elif isinstance(d, list): - for d_item in d: - d_item.pop('cache_control') + if isinstance(item, TextContent): + d.pop('cache_control', None) + elif isinstance(item, ImageContent): + # ImageContent.model_dump() always returns a list + # We know d is a list of dicts for ImageContent + if hasattr(d, '__iter__'): + for d_item in d: + if hasattr(d_item, 'pop'): + d_item.pop('cache_control', None) + if isinstance(item, TextContent): content.append(d) elif isinstance(item, ImageContent) and self.vision_enabled: - content.extend(d) + # ImageContent.model_dump() always returns a list + # We know d is a list for ImageContent + content.extend([d] if isinstance(d, dict) else d) message_dict: dict = {'content': content, 'role': self.role} diff --git a/openhands/core/message_utils.py b/openhands/core/message_utils.py index 1ce4b4f84b81..edb8902f2c4d 100644 --- a/openhands/core/message_utils.py +++ b/openhands/core/message_utils.py @@ -160,7 +160,7 @@ def get_action_message( ) llm_response: ModelResponse = tool_metadata.model_response - assistant_msg = llm_response.choices[0].message + assistant_msg = getattr(llm_response.choices[0], 'message') # Add the LLM message (assistant) that initiated the tool calls # (overwrites any previous message with the same response_id) @@ -168,7 +168,7 @@ def get_action_message( f'Tool calls type: {type(assistant_msg.tool_calls)}, value: {assistant_msg.tool_calls}' ) pending_tool_call_action_messages[llm_response.id] = Message( - role=assistant_msg.role, + role=getattr(assistant_msg, 'role', 'assistant'), # tool call content SHOULD BE a string content=[TextContent(text=assistant_msg.content or '')] if assistant_msg.content is not None @@ -185,7 +185,7 @@ def get_action_message( tool_metadata = action.tool_call_metadata if tool_metadata is not None: # take the response message from the tool call - assistant_msg = tool_metadata.model_response.choices[0].message + assistant_msg = getattr(tool_metadata.model_response.choices[0], 'message') content = assistant_msg.content or '' # save content if any, to thought @@ -197,9 +197,11 @@ def get_action_message( # remove the tool call metadata action.tool_call_metadata = None + if role not in ('user', 'system', 'assistant', 'tool'): + raise ValueError(f'Invalid role: {role}') return [ Message( - role=role, + role=role, # type: ignore[arg-type] content=[TextContent(text=action.thought)], ) ] @@ -208,9 +210,11 @@ def get_action_message( content = [TextContent(text=action.content or '')] if vision_is_active and action.image_urls: content.append(ImageContent(image_urls=action.image_urls)) + if role not in ('user', 'system', 'assistant', 'tool'): + raise ValueError(f'Invalid role: {role}') return [ Message( - role=role, + role=role, # type: ignore[arg-type] content=content, ) ] @@ -218,7 +222,7 @@ def get_action_message( content = [TextContent(text=f'User executed the command:\n{action.command}')] return [ Message( - role='user', + role='user', # Always user for CmdRunAction content=content, ) ] diff --git a/openhands/llm/llm.py b/openhands/llm/llm.py index 66bc6f99cb09..944e4660b47e 100644 --- a/openhands/llm/llm.py +++ b/openhands/llm/llm.py @@ -8,6 +8,7 @@ import requests from openhands.core.config import LLMConfig +from openhands.utils.ensure_httpx_close import EnsureHttpxClose with warnings.catch_warnings(): warnings.simplefilter('ignore') @@ -230,9 +231,9 @@ def wrapper(*args, **kwargs): # Record start time for latency measurement start_time = time.time() - - # we don't support streaming here, thus we get a ModelResponse - resp: ModelResponse = self._completion_unwrapped(*args, **kwargs) + with EnsureHttpxClose(): + # we don't support streaming here, thus we get a ModelResponse + resp: ModelResponse = self._completion_unwrapped(*args, **kwargs) # Calculate and record latency latency = time.time() - start_time @@ -287,7 +288,11 @@ def wrapper(*args, **kwargs): 'messages': messages, 'response': resp, 'args': args, - 'kwargs': {k: v for k, v in kwargs.items() if k != 'messages'}, + 'kwargs': { + k: v + for k, v in kwargs.items() + if k not in ('messages', 'client') + }, 'timestamp': time.time(), 'cost': cost, } diff --git a/openhands/memory/memory.py b/openhands/memory/memory.py deleted file mode 100644 index 537bc2e3d409..000000000000 --- a/openhands/memory/memory.py +++ /dev/null @@ -1,202 +0,0 @@ -from openhands.core.logger import openhands_logger as logger -from openhands.events.action.agent import RecallAction -from openhands.events.action.message import MessageAction -from openhands.events.event import Event, EventSource -from openhands.events.observation.agent import ( - RecallObservation, -) -from openhands.events.stream import EventStream, EventStreamSubscriber -from openhands.microagent import ( - BaseMicroAgent, - KnowledgeMicroAgent, - RepoMicroAgent, - load_microagents_from_dir, -) -from openhands.utils.prompt import PromptManager, RepositoryInfo, RuntimeInfo - - -class Memory: - """ - Memory is a component that listens to the EventStream for either user MessageAction (to create - a RecallAction) or a RecallAction (to produce a RecallObservation). - """ - - def __init__( - self, - event_stream: EventStream, - microagents_dir: str, - disabled_microagents: list[str] | None = None, - ): - self.event_stream = event_stream - self.microagents_dir = microagents_dir - self.disabled_microagents = disabled_microagents or [] - # Subscribe to events - self.event_stream.subscribe( - EventStreamSubscriber.MEMORY, - self.on_event, - 'Memory', - ) - # Load global microagents (Knowledge + Repo). - self._load_global_microagents() - - # Additional placeholders to store user workspace microagents if needed - self.repo_microagents: dict[str, RepoMicroAgent] = {} - self.knowledge_microagents: dict[str, KnowledgeMicroAgent] = {} - - # Track whether we've seen the first user message - self._first_user_message_seen = False - - # Store repository / runtime info to send them to the templating later - self.repository_info: RepositoryInfo | None = None - self.runtime_info: RuntimeInfo | None = None - - # TODO: enable_prompt_extensions - - def _load_global_microagents(self) -> None: - """ - Loads microagents from the global microagents_dir. - This is effectively what used to happen in PromptManager. - """ - repo_agents, knowledge_agents, _ = load_microagents_from_dir( - self.microagents_dir - ) - for name, agent in knowledge_agents.items(): - if name in self.disabled_microagents: - continue - if isinstance(agent, KnowledgeMicroAgent): - self.knowledge_microagents[name] = agent - for name, agent in repo_agents.items(): - if name in self.disabled_microagents: - continue - if isinstance(agent, RepoMicroAgent): - self.repo_microagents[name] = agent - - def set_repository_info(self, repo_name: str, repo_directory: str) -> None: - """Store repository info so we can reference it in an observation.""" - self.repository_info = RepositoryInfo(repo_name, repo_directory) - self.prompt_manager.set_repository_info(self.repository_info) - - def set_runtime_info(self, runtime_hosts: dict[str, int]) -> None: - """Store runtime info (web hosts, ports, etc.).""" - # e.g. { '127.0.0.1': 8080 } - self.runtime_info = RuntimeInfo(available_hosts=runtime_hosts) - self.prompt_manager.set_runtime_info(self.runtime_info) - - def on_event(self, event: Event): - """Handle an event from the event stream.""" - if isinstance(event, MessageAction): - if event.source == 'user': - # If this is the first user message, create and add a RecallObservation - # with info about repo and runtime. - if not self._first_user_message_seen: - self._first_user_message_seen = True - self._on_first_user_message(event) - # continue with the next handler, to include microagents if suitable for this user message - self._on_user_message_action(event) - elif isinstance(event, RecallAction): - self._on_recall_action(event) - - def _on_first_user_message(self, event: MessageAction): - """Create and add to the stream a RecallObservation carrying info about repo and runtime.""" - # Build the same text that used to be appended to the first user message - repo_instructions = '' - assert ( - len(self.repo_microagents) <= 1 - ), f'Expecting at most one repo microagent, but found {len(self.repo_microagents)}: {self.repo_microagents.keys()}' - for microagent in self.repo_microagents.values(): - # We assume these are the repo instructions - if repo_instructions: - repo_instructions += '\n\n' - repo_instructions += microagent.content - - # Now wrap it in a RecallObservation, rather than altering the user message: - obs = RecallObservation( - content=self.prompt_manager.build_additional_info_text(repo_instructions) - ) - self.event_stream.add_event(obs, EventSource.ENVIRONMENT) - - def _on_user_message_action(self, event: MessageAction): - """Replicates old microagent logic: if a microagent triggers on user text, - we embed it in an block and post a RecallObservation.""" - if event.source != 'user': - return - - # If there's no text, do nothing - user_text = event.content.strip() - if not user_text: - return - # Gather all triggered microagents - microagent_blocks = [] - for name, agent in self.knowledge_microagents.items(): - trigger = agent.match_trigger(user_text) - if trigger: - logger.info("Microagent '%s' triggered by keyword '%s'", name, trigger) - micro_text = ( - f'\n' - f'The following information has been included based on a keyword match for "{trigger}". ' - f"It may or may not be relevant to the user's request.\n\n" - f'{agent.content}\n' - f'' - ) - microagent_blocks.append(micro_text) - - if microagent_blocks: - # Combine all triggered microagents into a single RecallObservation - combined_text = '\n'.join(microagent_blocks) - obs = RecallObservation(content=combined_text) - self.event_stream.add_event( - obs, event.source if event.source else EventSource.ENVIRONMENT - ) - - def _on_recall_action(self, event: RecallAction): - """If a RecallAction explicitly arrives, handle it.""" - assert isinstance(event, RecallAction) - - user_query = event.query.get('keywords', []) - matched_content = self.find_microagent_content(user_query) - obs = RecallObservation(content=matched_content) - self.event_stream.add_event( - obs, event.source if event.source else EventSource.ENVIRONMENT - ) - - def find_microagent_content(self, keywords: list[str]) -> str: - """Replicate the same microagent logic.""" - matched_texts: list[str] = [] - for name, agent in self.knowledge_microagents.items(): - for kw in keywords: - trigger = agent.match_trigger(kw) - if trigger: - logger.info( - "Microagent '%s' triggered by explicit RecallAction keyword '%s'", - name, - trigger, - ) - block = ( - f'\n' - f"(via RecallAction) Included knowledge from microagent '{name}', triggered by '{trigger}'\n\n" - f'{agent.content}\n' - f'' - ) - matched_texts.append(block) - return '\n'.join(matched_texts) - - def load_user_workspace_microagents( - self, user_microagents: list[BaseMicroAgent] - ) -> None: - """ - If you want to load microagents from a user's cloned repo or workspace directory, - call this from agent_session or setup once the workspace is cloned. - """ - logger.info( - 'Loading user workspace microagents: %s', [m.name for m in user_microagents] - ) - for ma in user_microagents: - if ma.name in self.disabled_microagents: - continue - if isinstance(ma, KnowledgeMicroAgent): - self.knowledge_microagents[ma.name] = ma - elif isinstance(ma, RepoMicroAgent): - self.repo_microagents[ma.name] = ma - - def set_prompt_manager(self, prompt_manager: PromptManager): - self.prompt_manager = prompt_manager diff --git a/openhands/runtime/__init__.py b/openhands/runtime/__init__.py index 5ddf881fcfa3..0590d62b7c33 100644 --- a/openhands/runtime/__init__.py +++ b/openhands/runtime/__init__.py @@ -1,4 +1,5 @@ from openhands.core.logger import openhands_logger as logger +from openhands.runtime.impl.daytona.daytona_runtime import DaytonaRuntime from openhands.runtime.impl.docker.docker_runtime import ( DockerRuntime, ) @@ -24,6 +25,8 @@ def get_runtime_cls(name: str): return RunloopRuntime elif name == 'local': return LocalRuntime + elif name == 'daytona': + return DaytonaRuntime else: raise ValueError(f'Runtime {name} not supported') diff --git a/openhands/runtime/builder/docker.py b/openhands/runtime/builder/docker.py index 4c2c13f90f04..041ff622370f 100644 --- a/openhands/runtime/builder/docker.py +++ b/openhands/runtime/builder/docker.py @@ -67,7 +67,7 @@ def build( """ self.docker_client = docker.from_env() version_info = self.docker_client.version() - server_version = version_info.get('Version', '').replace('-', '.') + server_version = version_info.get('Version', '').split('+')[0].replace('-', '.') if tuple(map(int, server_version.split('.'))) < (18, 9): raise AgentRuntimeBuildError( 'Docker server version must be >= 18.09 to use BuildKit' diff --git a/openhands/runtime/impl/daytona/README.md b/openhands/runtime/impl/daytona/README.md new file mode 100644 index 000000000000..dfaa7e3dc02f --- /dev/null +++ b/openhands/runtime/impl/daytona/README.md @@ -0,0 +1,24 @@ +# Daytona Runtime + +[Daytona](https://www.daytona.io/) is a platform that provides a secure and elastic infrastructure for running AI-generated code. It provides all the necessary features for an AI Agent to interact with a codebase. It provides a Daytona SDK with official Python and TypeScript interfaces for interacting with Daytona, enabling you to programmatically manage development environments and execute code. + +## Getting started + +1. Sign in at https://app.daytona.io/ + +1. Generate and copy your API key + +1. Set the following environment variables before running the OpenHands app on your local machine or via a `docker run` command: + +```bash + RUNTIME="daytona" + DAYTONA_API_KEY="" +``` +Optionally, if you don't want your sandboxes to default to the US region, set: + +```bash + DAYTONA_TARGET="eu" +``` + +## Documentation +Read more by visiting our [documentation](https://www.daytona.io/docs/) page. diff --git a/openhands/runtime/impl/daytona/daytona_runtime.py b/openhands/runtime/impl/daytona/daytona_runtime.py new file mode 100644 index 000000000000..437f6eeebf52 --- /dev/null +++ b/openhands/runtime/impl/daytona/daytona_runtime.py @@ -0,0 +1,262 @@ +import json +from typing import Callable + +import tenacity +from daytona_sdk import ( + CreateWorkspaceParams, + Daytona, + DaytonaConfig, + SessionExecuteRequest, + Workspace, +) + +from openhands.core.config.app_config import AppConfig +from openhands.events.stream import EventStream +from openhands.runtime.impl.action_execution.action_execution_client import ( + ActionExecutionClient, +) +from openhands.runtime.plugins.requirement import PluginRequirement +from openhands.runtime.utils.command import get_action_execution_server_startup_command +from openhands.utils.async_utils import call_sync_from_async +from openhands.utils.tenacity_stop import stop_if_should_exit + +WORKSPACE_PREFIX = 'openhands-sandbox-' + + +class DaytonaRuntime(ActionExecutionClient): + """The DaytonaRuntime class is a DockerRuntime that utilizes Daytona workspace as a runtime environment.""" + + _sandbox_port: int = 4444 + _vscode_port: int = 4445 + + def __init__( + self, + config: AppConfig, + event_stream: EventStream, + sid: str = 'default', + plugins: list[PluginRequirement] | None = None, + env_vars: dict[str, str] | None = None, + status_callback: Callable | None = None, + attach_to_existing: bool = False, + headless_mode: bool = True, + ): + assert config.daytona_api_key, 'Daytona API key is required' + + self.config = config + self.sid = sid + self.workspace_id = WORKSPACE_PREFIX + sid + self.workspace: Workspace | None = None + self._vscode_url: str | None = None + + daytona_config = DaytonaConfig( + api_key=config.daytona_api_key.get_secret_value(), + server_url=config.daytona_api_url, + target=config.daytona_target, + ) + self.daytona = Daytona(daytona_config) + + # workspace_base cannot be used because we can't bind mount into a workspace. + if self.config.workspace_base is not None: + self.log( + 'warning', + 'Workspace mounting is not supported in the Daytona runtime.', + ) + + super().__init__( + config, + event_stream, + sid, + plugins, + env_vars, + status_callback, + attach_to_existing, + headless_mode, + ) + + def _get_workspace(self) -> Workspace | None: + try: + workspace = self.daytona.get_current_workspace(self.workspace_id) + self.log( + 'info', f'Attached to existing workspace with id: {self.workspace_id}' + ) + except Exception: + self.log( + 'warning', + f'Failed to attach to existing workspace with id: {self.workspace_id}', + ) + workspace = None + + return workspace + + def _get_creation_env_vars(self) -> dict[str, str]: + env_vars: dict[str, str] = { + 'port': str(self._sandbox_port), + 'PYTHONUNBUFFERED': '1', + 'VSCODE_PORT': str(self._vscode_port), + } + + if self.config.debug: + env_vars['DEBUG'] = 'true' + + return env_vars + + def _create_workspace(self) -> Workspace: + workspace_params = CreateWorkspaceParams( + id=self.workspace_id, + language='python', + image=self.config.sandbox.runtime_container_image, + public=True, + env_vars=self._get_creation_env_vars(), + ) + workspace = self.daytona.create(workspace_params) + return workspace + + def _get_workspace_status(self) -> str: + assert self.workspace is not None, 'Workspace is not initialized' + assert ( + self.workspace.instance.info is not None + ), 'Workspace info is not available' + assert ( + self.workspace.instance.info.provider_metadata is not None + ), 'Provider metadata is not available' + + provider_metadata = json.loads(self.workspace.instance.info.provider_metadata) + return provider_metadata.get('status', 'unknown') + + def _construct_api_url(self, port: int) -> str: + assert self.workspace is not None, 'Workspace is not initialized' + assert ( + self.workspace.instance.info is not None + ), 'Workspace info is not available' + assert ( + self.workspace.instance.info.provider_metadata is not None + ), 'Provider metadata is not available' + + node_domain = json.loads(self.workspace.instance.info.provider_metadata)[ + 'nodeDomain' + ] + return f'https://{port}-{self.workspace.id}.{node_domain}' + + def _get_action_execution_server_host(self) -> str: + return self.api_url + + def _start_action_execution_server(self) -> None: + assert self.workspace is not None, 'Workspace is not initialized' + + self.workspace.process.exec( + f'mkdir -p {self.config.workspace_mount_path_in_sandbox}' + ) + + start_command: list[str] = get_action_execution_server_startup_command( + server_port=self._sandbox_port, + plugins=self.plugins, + app_config=self.config, + override_user_id=1000, + override_username='openhands', + ) + start_command_str: str = ' '.join(start_command) + + self.log( + 'debug', + f'Starting action execution server with command: {start_command_str}', + ) + + exec_session_id = 'action-execution-server' + self.workspace.process.create_session(exec_session_id) + self.workspace.process.execute_session_command( + exec_session_id, + SessionExecuteRequest(command='cd /openhands/code', var_async=True), + ) + + exec_command = self.workspace.process.execute_session_command( + exec_session_id, + SessionExecuteRequest(command=start_command_str, var_async=True), + ) + + self.log('debug', f'exec_command_id: {exec_command.cmd_id}') + + @tenacity.retry( + stop=tenacity.stop_after_delay(120) | stop_if_should_exit(), + wait=tenacity.wait_fixed(1), + reraise=(ConnectionRefusedError,), + ) + def _wait_until_alive(self): + super().check_if_alive() + + async def connect(self): + self.send_status_message('STATUS$STARTING_RUNTIME') + + if self.attach_to_existing: + self.workspace = await call_sync_from_async(self._get_workspace) + + if self.workspace is None: + self.send_status_message('STATUS$PREPARING_CONTAINER') + self.workspace = await call_sync_from_async(self._create_workspace) + self.log('info', f'Created new workspace with id: {self.workspace_id}') + + if self._get_workspace_status() == 'stopped': + self.log('info', 'Starting Daytona workspace...') + await call_sync_from_async(self.workspace.start) + + self.api_url = await call_sync_from_async( + self._construct_api_url, self._sandbox_port + ) + + if not self.attach_to_existing: + await call_sync_from_async(self._start_action_execution_server) + self.log( + 'info', + f'Container started. Action execution server url: {self.api_url}', + ) + + self.log('info', 'Waiting for client to become ready...') + self.send_status_message('STATUS$WAITING_FOR_CLIENT') + await call_sync_from_async(self._wait_until_alive) + + if not self.attach_to_existing: + await call_sync_from_async(self.setup_initial_env) + + self.log( + 'info', + f'Container initialized with plugins: {[plugin.name for plugin in self.plugins]}', + ) + + if not self.attach_to_existing: + self.send_status_message(' ') + self._runtime_initialized = True + + def close(self): + super().close() + + if self.attach_to_existing: + return + + if self.workspace: + self.daytona.remove(self.workspace) + + @property + def vscode_url(self) -> str | None: + if self._vscode_url is not None: # cached value + return self._vscode_url + token = super().get_vscode_token() + if not token: + self.log( + 'warning', 'Failed to get VSCode token while trying to get VSCode URL' + ) + return None + if not self.workspace: + self.log( + 'warning', 'Workspace is not initialized while trying to get VSCode URL' + ) + return None + self._vscode_url = ( + self._construct_api_url(self._vscode_port) + + f'/?tkn={token}&folder={self.config.workspace_mount_path_in_sandbox}' + ) + + self.log( + 'debug', + f'VSCode URL: {self._vscode_url}', + ) + + return self._vscode_url diff --git a/openhands/security/invariant/analyzer.py b/openhands/security/invariant/analyzer.py index f843e9304359..540a9341b822 100644 --- a/openhands/security/invariant/analyzer.py +++ b/openhands/security/invariant/analyzer.py @@ -307,11 +307,17 @@ async def security_risk(self, event: Action) -> ActionSecurityRisk: new_elements = parse_element(self.trace, event) input = [e.model_dump(exclude_none=True) for e in new_elements] # type: ignore [call-overload] self.trace.extend(new_elements) - result, err = self.monitor.check(self.input, input) + check_result = self.monitor.check(self.input, input) self.input.extend(input) risk = ActionSecurityRisk.UNKNOWN - if err: - logger.warning(f'Error checking policy: {err}') + + if isinstance(check_result, tuple): + result, err = check_result + if err: + logger.warning(f'Error checking policy: {err}') + return risk + else: + logger.warning(f'Error checking policy: {check_result}') return risk risk = self.get_risk(result) diff --git a/openhands/security/invariant/client.py b/openhands/security/invariant/client.py index c41828745658..f2ccc78bd61f 100644 --- a/openhands/security/invariant/client.py +++ b/openhands/security/invariant/client.py @@ -50,7 +50,7 @@ def close_session(self) -> Union[None, Exception]: return None class _Policy: - def __init__(self, invariant): + def __init__(self, invariant: 'InvariantClient') -> None: self.server = invariant.server self.session_id = invariant.session_id @@ -77,7 +77,7 @@ def get_template(self) -> tuple[str | None, Exception | None]: except (ConnectionError, Timeout, HTTPError) as err: return None, err - def from_string(self, rule: str): + def from_string(self, rule: str) -> 'InvariantClient._Policy': policy_id, err = self._create_policy(rule) if err: raise err @@ -97,7 +97,7 @@ def analyze(self, trace: list[dict]) -> Union[Any, Exception]: return None, err class _Monitor: - def __init__(self, invariant): + def __init__(self, invariant: 'InvariantClient') -> None: self.server = invariant.server self.session_id = invariant.session_id self.policy = '' @@ -114,7 +114,7 @@ def _create_monitor(self, rule: str) -> tuple[str | None, Exception | None]: except (ConnectionError, Timeout, HTTPError) as err: return None, err - def from_string(self, rule: str): + def from_string(self, rule: str) -> 'InvariantClient._Monitor': monitor_id, err = self._create_monitor(rule) if err: raise err diff --git a/openhands/security/invariant/nodes.py b/openhands/security/invariant/nodes.py index 47410264743b..c3d7b9713bea 100644 --- a/openhands/security/invariant/nodes.py +++ b/openhands/security/invariant/nodes.py @@ -1,3 +1,4 @@ +from typing import Any, Iterable, Tuple from pydantic import BaseModel, Field from pydantic.dataclasses import dataclass @@ -10,7 +11,7 @@ class LLM: class Event(BaseModel): metadata: dict | None = Field( - default_factory=dict, description='Metadata associated with the event' + default_factory=lambda: dict(), description='Metadata associated with the event' ) @@ -30,7 +31,7 @@ class Message(Event): content: str | None tool_calls: list[ToolCall] | None = None - def __rich_repr__(self): + def __rich_repr__(self) -> Iterable[Any | tuple[Any] | tuple[str, Any] | tuple[str, Any, Any]]: # Print on separate line yield 'role', self.role yield 'content', self.content diff --git a/openhands/utils/ensure_httpx_close.py b/openhands/utils/ensure_httpx_close.py new file mode 100644 index 000000000000..f60a7189d565 --- /dev/null +++ b/openhands/utils/ensure_httpx_close.py @@ -0,0 +1,43 @@ +""" +LiteLLM currently have an issue where HttpHandlers are being created but not +closed. We have submitted a PR to them, (https://github.com/BerriAI/litellm/pull/8711) +and their dev team say they are in the process of a refactor that will fix this, but +in the meantime, we need to manage the lifecycle of the httpx.Client manually. + +We can't simply pass in our own client object, because all the different implementations use +different types of client object. + +So we monkey patch the httpx.Client class to track newly created instances and close these +when the operations complete. (This is relatively safe, as if the client is reused after this +then is will transparently reopen) + +Hopefully, this will be fixed soon and we can remove this abomination. +""" + +from dataclasses import dataclass, field +from functools import wraps +from typing import Callable + +from httpx import Client + + +@dataclass +class EnsureHttpxClose: + clients: list[Client] = field(default_factory=list) + original_init: Callable | None = None + + def __enter__(self): + self.original_init = Client.__init__ + + @wraps(Client.__init__) + def init_wrapper(*args, **kwargs): + self.clients.append(args[0]) + return self.original_init(*args, **kwargs) # type: ignore + + Client.__init__ = init_wrapper + + def __exit__(self, type, value, traceback): + Client.__init__ = self.original_init + while self.clients: + client = self.clients.pop() + client.close() diff --git a/poetry.lock b/poetry.lock index b7b7f8f689d9..c0151ab7545e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1507,6 +1507,47 @@ tests-numpy2 = ["Pillow (>=9.4.0)", "absl-py", "decorator", "elasticsearch (<8.0 torch = ["torch"] vision = ["Pillow (>=9.4.0)"] +[[package]] +name = "daytona-api-client" +version = "0.13.0" +description = "Daytona Workspaces" +optional = false +python-versions = "*" +groups = ["main"] +files = [ + {file = "daytona_api_client-0.13.0-py3-none-any.whl", hash = "sha256:c4d0dcb89a328c4d0a97d8f076eaf9a00ccc54a8b9f862f4b3302ae887d03c8f"}, + {file = "daytona_api_client-0.13.0.tar.gz", hash = "sha256:d62b7cb14361b2706df192d2da7dc2b5d02be6fd4259e9433cf2bfdc5807416d"}, +] + +[package.dependencies] +pydantic = ">=2" +python-dateutil = ">=2.8.2" +typing-extensions = ">=4.7.1" +urllib3 = ">=1.25.3,<3.0.0" + +[[package]] +name = "daytona-sdk" +version = "0.9.1" +description = "Python SDK for Daytona" +optional = false +python-versions = ">=3.7" +groups = ["main"] +files = [ + {file = "daytona_sdk-0.9.1-py3-none-any.whl", hash = "sha256:cce6c90cd3d578747b3c388e24c811cb0b21ad125d34b32836c50059a577a12a"}, + {file = "daytona_sdk-0.9.1.tar.gz", hash = "sha256:1e2f219f55130fc72d2f14a57d008b8d3e236d45294e0ca51e249106be5ca5de"}, +] + +[package.dependencies] +daytona_api_client = ">=0.13.0,<1.0.0" +environs = ">=9.5.0,<10.0.0" +marshmallow = ">=3.19.0,<4.0.0" +pydantic = ">=2.4.2,<3.0.0" +python-dateutil = ">=2.8.2,<3.0.0" +urllib3 = ">=2.0.7,<3.0.0" + +[package.extras] +dev = ["black (>=22.0.0)", "isort (>=5.10.0)", "pydoc-markdown (>=4.8.2)"] + [[package]] name = "debugpy" version = "1.8.12" @@ -1731,6 +1772,28 @@ files = [ {file = "english-words-2.0.1.tar.gz", hash = "sha256:a4105c57493bb757a3d8973fcf8e1dc05e7ca09c836dff467c3fb445f84bc43d"}, ] +[[package]] +name = "environs" +version = "9.5.0" +description = "simplified environment variable parsing" +optional = false +python-versions = ">=3.6" +groups = ["main"] +files = [ + {file = "environs-9.5.0-py2.py3-none-any.whl", hash = "sha256:1e549569a3de49c05f856f40bce86979e7d5ffbbc4398e7f338574c220189124"}, + {file = "environs-9.5.0.tar.gz", hash = "sha256:a76307b36fbe856bdca7ee9161e6c466fd7fcffc297109a118c59b54e27e30c9"}, +] + +[package.dependencies] +marshmallow = ">=3.0.0" +python-dotenv = "*" + +[package.extras] +dev = ["dj-database-url", "dj-email-url", "django-cache-url", "flake8 (==4.0.1)", "flake8-bugbear (==21.9.2)", "mypy (==0.910)", "pre-commit (>=2.4,<3.0)", "pytest", "tox"] +django = ["dj-database-url", "dj-email-url", "django-cache-url"] +lint = ["flake8 (==4.0.1)", "flake8-bugbear (==21.9.2)", "mypy (==0.910)", "pre-commit (>=2.4,<3.0)"] +tests = ["dj-database-url", "dj-email-url", "django-cache-url", "pytest"] + [[package]] name = "evaluate" version = "0.4.3" @@ -3132,14 +3195,14 @@ zstd = ["zstandard (>=0.18.0)"] [[package]] name = "huggingface-hub" -version = "0.28.1" +version = "0.29.0" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" optional = false python-versions = ">=3.8.0" groups = ["main", "evaluation", "llama-index"] files = [ - {file = "huggingface_hub-0.28.1-py3-none-any.whl", hash = "sha256:aa6b9a3ffdae939b72c464dbb0d7f99f56e649b55c3d52406f49e0a5a620c0a7"}, - {file = "huggingface_hub-0.28.1.tar.gz", hash = "sha256:893471090c98e3b6efbdfdacafe4052b20b84d59866fb6f54c33d9af18c303ae"}, + {file = "huggingface_hub-0.29.0-py3-none-any.whl", hash = "sha256:c02daa0b6bafbdacb1320fdfd1dc7151d0940825c88c4ef89837fdb1f6ea0afe"}, + {file = "huggingface_hub-0.29.0.tar.gz", hash = "sha256:64034c852be270cac16c5743fe1f659b14515a9de6342d6f42cbb2ede191fc80"}, ] [package.dependencies] @@ -4004,14 +4067,14 @@ files = [ [[package]] name = "kubernetes" -version = "32.0.0" +version = "32.0.1" description = "Kubernetes python client" optional = false python-versions = ">=3.6" groups = ["llama-index"] files = [ - {file = "kubernetes-32.0.0-py2.py3-none-any.whl", hash = "sha256:60fd8c29e8e43d9c553ca4811895a687426717deba9c0a66fb2dcc3f5ef96692"}, - {file = "kubernetes-32.0.0.tar.gz", hash = "sha256:319fa840345a482001ac5d6062222daeb66ec4d1bcb3087402aed685adf0aecb"}, + {file = "kubernetes-32.0.1-py2.py3-none-any.whl", hash = "sha256:35282ab8493b938b08ab5526c7ce66588232df00ef5e1dbe88a419107dc10998"}, + {file = "kubernetes-32.0.1.tar.gz", hash = "sha256:42f43d49abd437ada79a79a16bd48a604d3471a117a8347e87db693f2ba0ba28"}, ] [package.dependencies] @@ -4786,7 +4849,7 @@ version = "3.26.1" description = "A lightweight library for converting complex datatypes to and from native Python datatypes." optional = false python-versions = ">=3.9" -groups = ["evaluation", "llama-index"] +groups = ["main", "evaluation", "llama-index"] files = [ {file = "marshmallow-3.26.1-py3-none-any.whl", hash = "sha256:3350409f20a70a7e4e11a27661187b77cdcaeb20abca41c1454fe33636bea09c"}, {file = "marshmallow-3.26.1.tar.gz", hash = "sha256:e6d8affb6cb61d39d26402096dc0aee12d5a26d490a121f118d2e81dc0719dc6"}, @@ -4933,14 +4996,14 @@ urllib3 = "*" [[package]] name = "mistune" -version = "3.1.1" +version = "3.1.2" description = "A sane and fast Markdown parser with useful plugins and renderers" optional = false python-versions = ">=3.8" groups = ["runtime"] files = [ - {file = "mistune-3.1.1-py3-none-any.whl", hash = "sha256:02106ac2aa4f66e769debbfa028509a275069dcffce0dfa578edd7b991ee700a"}, - {file = "mistune-3.1.1.tar.gz", hash = "sha256:e0740d635f515119f7d1feb6f9b192ee60f0cc649f80a8f944f905706a21654c"}, + {file = "mistune-3.1.2-py3-none-any.whl", hash = "sha256:4b47731332315cdca99e0ded46fc0004001c1299ff773dfb48fbe1fd226de319"}, + {file = "mistune-3.1.2.tar.gz", hash = "sha256:733bf018ba007e8b5f2d3a9eb624034f6ee26c4ea769a98ec533ee111d504dff"}, ] [[package]] @@ -8391,33 +8454,34 @@ pathspec = ">=0.10.1" [[package]] name = "scikit-image" -version = "0.25.1" +version = "0.25.2" description = "Image processing in Python" optional = false python-versions = ">=3.10" groups = ["evaluation"] files = [ - {file = "scikit_image-0.25.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:40763a3a089617e6f00f92d46b3475368b9783588a165c2aa854da95b66bb4ff"}, - {file = "scikit_image-0.25.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:7c6b69f33e5512ee7fc53361b064430f146583f08dc75317667e81d5f8fcd0c6"}, - {file = "scikit_image-0.25.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d9187347d115776ff0ddba3e5d2a04638d291b1a62e3c315d17b71eea351cde8"}, - {file = "scikit_image-0.25.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cdfca713979ad1873a4b55d94bb1eb4bc713f0c10165b261bf6f7e606f44a00c"}, - {file = "scikit_image-0.25.1-cp310-cp310-win_amd64.whl", hash = "sha256:167fb146de80bb2a1493d1a760a9ac81644a8a5de254c3dd12a95d1b662d819c"}, - {file = "scikit_image-0.25.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c1bde2d5f1dfb23b3c72ef9fcdb2dd5f42fa353e8bd606aea63590eba5e79565"}, - {file = "scikit_image-0.25.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:5112d95cccaa45c434e57efc20c1f721ab439e516e2ed49709ddc2afb7c15c70"}, - {file = "scikit_image-0.25.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0f5e313b028f5d7a9f3888ad825ddf4fb78913d7762891abb267b99244b4dd31"}, - {file = "scikit_image-0.25.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39ad76aeff754048dabaff83db752aa0655dee425f006678d14485471bdb459d"}, - {file = "scikit_image-0.25.1-cp311-cp311-win_amd64.whl", hash = "sha256:8dc8b06176c1a2316fa8bc539fd7e96155721628ae5cf51bc1a2c62cb9786581"}, - {file = "scikit_image-0.25.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:ebf83699d60134909647395a0bf07db3859646de7192b088e656deda6bc15e95"}, - {file = "scikit_image-0.25.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:408086520eed036340e634ab7e4f648e00238f711bac61ab933efeb11464a238"}, - {file = "scikit_image-0.25.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5bd709faa87795869ccd21f32490c37989ca5846571495822f4b9430fb42c34c"}, - {file = "scikit_image-0.25.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a6b15c0265c072a46ff4720784d756d8f8e5d63567639aa8451f6673994d6846"}, - {file = "scikit_image-0.25.1-cp312-cp312-win_amd64.whl", hash = "sha256:a689a0d091e0bd97d7767309abdeb27c43be210d075abb34e71657add920c22b"}, - {file = "scikit_image-0.25.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f070f899d6572a125ab106c4b26d1a5fb784dc60ba6dea45c7816f08c3a4fb4d"}, - {file = "scikit_image-0.25.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:cc9538d8db7670878aa68ea79c0b1796b6c771085e8d50f5408ee617da3281b6"}, - {file = "scikit_image-0.25.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:caa08d4fa851e1f421fcad8eac24d32f2810971dc61f1d72dc950ca9e9ec39b1"}, - {file = "scikit_image-0.25.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c9923aa898b7921fbcf503d32574d48ed937a7cff45ce8587be4868b39676e18"}, - {file = "scikit_image-0.25.1-cp313-cp313-win_amd64.whl", hash = "sha256:6c7bba6773ab8c39ee8b1cbb17c7f98965bacdb8cd8da337942be6acc38fc562"}, - {file = "scikit_image-0.25.1.tar.gz", hash = "sha256:d4ab30540d114d37c35fe5c837f89b94aaba2a7643afae8354aa353319e9bbbb"}, + {file = "scikit_image-0.25.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d3278f586793176599df6a4cf48cb6beadae35c31e58dc01a98023af3dc31c78"}, + {file = "scikit_image-0.25.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:5c311069899ce757d7dbf1d03e32acb38bb06153236ae77fcd820fd62044c063"}, + {file = "scikit_image-0.25.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:be455aa7039a6afa54e84f9e38293733a2622b8c2fb3362b822d459cc5605e99"}, + {file = "scikit_image-0.25.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4c464b90e978d137330be433df4e76d92ad3c5f46a22f159520ce0fdbea8a09"}, + {file = "scikit_image-0.25.2-cp310-cp310-win_amd64.whl", hash = "sha256:60516257c5a2d2f74387c502aa2f15a0ef3498fbeaa749f730ab18f0a40fd054"}, + {file = "scikit_image-0.25.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f4bac9196fb80d37567316581c6060763b0f4893d3aca34a9ede3825bc035b17"}, + {file = "scikit_image-0.25.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:d989d64ff92e0c6c0f2018c7495a5b20e2451839299a018e0e5108b2680f71e0"}, + {file = "scikit_image-0.25.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b2cfc96b27afe9a05bc92f8c6235321d3a66499995675b27415e0d0c76625173"}, + {file = "scikit_image-0.25.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:24cc986e1f4187a12aa319f777b36008764e856e5013666a4a83f8df083c2641"}, + {file = "scikit_image-0.25.2-cp311-cp311-win_amd64.whl", hash = "sha256:b4f6b61fc2db6340696afe3db6b26e0356911529f5f6aee8c322aa5157490c9b"}, + {file = "scikit_image-0.25.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:8db8dd03663112783221bf01ccfc9512d1cc50ac9b5b0fe8f4023967564719fb"}, + {file = "scikit_image-0.25.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:483bd8cc10c3d8a7a37fae36dfa5b21e239bd4ee121d91cad1f81bba10cfb0ed"}, + {file = "scikit_image-0.25.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9d1e80107bcf2bf1291acfc0bf0425dceb8890abe9f38d8e94e23497cbf7ee0d"}, + {file = "scikit_image-0.25.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a17e17eb8562660cc0d31bb55643a4da996a81944b82c54805c91b3fe66f4824"}, + {file = "scikit_image-0.25.2-cp312-cp312-win_amd64.whl", hash = "sha256:bdd2b8c1de0849964dbc54037f36b4e9420157e67e45a8709a80d727f52c7da2"}, + {file = "scikit_image-0.25.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:7efa888130f6c548ec0439b1a7ed7295bc10105458a421e9bf739b457730b6da"}, + {file = "scikit_image-0.25.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:dd8011efe69c3641920614d550f5505f83658fe33581e49bed86feab43a180fc"}, + {file = "scikit_image-0.25.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:28182a9d3e2ce3c2e251383bdda68f8d88d9fff1a3ebe1eb61206595c9773341"}, + {file = "scikit_image-0.25.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8abd3c805ce6944b941cfed0406d88faeb19bab3ed3d4b50187af55cf24d147"}, + {file = "scikit_image-0.25.2-cp313-cp313-win_amd64.whl", hash = "sha256:64785a8acefee460ec49a354706db0b09d1f325674107d7fa3eadb663fb56d6f"}, + {file = "scikit_image-0.25.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:330d061bd107d12f8d68f1d611ae27b3b813b8cdb0300a71d07b1379178dd4cd"}, + {file = "scikit_image-0.25.2.tar.gz", hash = "sha256:e5a37e6cd4d0c018a7a55b9d601357e3382826d3888c10d0213fc63bff977dde"}, ] [package.dependencies] @@ -8427,16 +8491,16 @@ networkx = ">=3.0" numpy = ">=1.24" packaging = ">=21" pillow = ">=10.1" -scipy = ">=1.11.2" +scipy = ">=1.11.4" tifffile = ">=2022.8.12" [package.extras] -build = ["Cython (>=3.0.8)", "build (>=1.2.1)", "meson-python (>=0.16)", "ninja (>=1.11.1.1)", "numpy (>=2.0)", "pythran (>=0.16)", "setuptools (>=68)", "spin (==0.13)"] +build = ["Cython (>=3.0.8)", "build (>=1.2.1)", "meson-python (>=0.16)", "ninja (>=1.11.1.1)", "numpy (>=2.0)", "pythran (>=0.16)", "spin (==0.13)"] data = ["pooch (>=1.6.0)"] developer = ["ipython", "pre-commit", "tomli"] -docs = ["PyWavelets (>=1.6)", "dask[array] (>=2022.9.2)", "intersphinx-registry (>=0.2411.14)", "ipykernel", "ipywidgets", "kaleido (==0.2.1)", "matplotlib (>=3.7)", "myst-parser", "numpydoc (>=1.7)", "pandas (>=2.0)", "plotly (>=5.20)", "pooch (>=1.6)", "pydata-sphinx-theme (>=0.16)", "pytest-doctestplus", "scikit-learn (>=1.2)", "seaborn (>=0.11)", "sphinx (>=8.0)", "sphinx-copybutton", "sphinx-gallery[parallel] (>=0.18)", "sphinx_design (>=0.5)", "tifffile (>=2022.8.12)"] -optional = ["PyWavelets (>=1.6)", "SimpleITK", "astropy (>=5.0)", "cloudpickle (>=0.2.1)", "dask[array] (>=2021.1.0,!=2024.8.0)", "matplotlib (>=3.7)", "pooch (>=1.6.0)", "pyamg (>=5.2)", "scikit-learn (>=1.2)"] -test = ["asv", "numpydoc (>=1.7)", "pooch (>=1.6.0)", "pytest (>=7.0)", "pytest-cov (>=2.11.0)", "pytest-doctestplus", "pytest-faulthandler", "pytest-localserver"] +docs = ["PyWavelets (>=1.6)", "dask[array] (>=2023.2.0)", "intersphinx-registry (>=0.2411.14)", "ipykernel", "ipywidgets", "kaleido (==0.2.1)", "matplotlib (>=3.7)", "myst-parser", "numpydoc (>=1.7)", "pandas (>=2.0)", "plotly (>=5.20)", "pooch (>=1.6)", "pydata-sphinx-theme (>=0.16)", "pytest-doctestplus", "scikit-learn (>=1.2)", "seaborn (>=0.11)", "sphinx (>=8.0)", "sphinx-copybutton", "sphinx-gallery[parallel] (>=0.18)", "sphinx_design (>=0.5)", "tifffile (>=2022.8.12)"] +optional = ["PyWavelets (>=1.6)", "SimpleITK", "astropy (>=5.0)", "cloudpickle (>=1.1.1)", "dask[array] (>=2023.2.0)", "matplotlib (>=3.7)", "pooch (>=1.6.0)", "pyamg (>=5.2)", "scikit-learn (>=1.2)"] +test = ["asv", "numpydoc (>=1.7)", "pooch (>=1.6.0)", "pytest (>=8)", "pytest-cov (>=2.11.0)", "pytest-doctestplus", "pytest-faulthandler", "pytest-localserver"] [[package]] name = "scikit-learn" @@ -9206,14 +9270,14 @@ files = [ [[package]] name = "tifffile" -version = "2025.1.10" +version = "2025.2.18" description = "Read and write TIFF files" optional = false python-versions = ">=3.10" groups = ["evaluation"] files = [ - {file = "tifffile-2025.1.10-py3-none-any.whl", hash = "sha256:ed24cf4c99fb13b4f5fb29f8a0d5605e60558c950bccbdca2a6470732a27cfb3"}, - {file = "tifffile-2025.1.10.tar.gz", hash = "sha256:baaf0a3b87bf7ec375fa1537503353f70497eabe1bdde590f2e41cc0346e612f"}, + {file = "tifffile-2025.2.18-py3-none-any.whl", hash = "sha256:54b36c4d5e5b8d8920134413edfe5a7cfb1c7617bb50cddf7e2772edb7149043"}, + {file = "tifffile-2025.2.18.tar.gz", hash = "sha256:8d731789e691b468746c1615d989bc550ac93cf753e9210865222e90a5a95d11"}, ] [package.dependencies] @@ -10789,4 +10853,4 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"] [metadata] lock-version = "2.1" python-versions = "^3.12" -content-hash = "14998d54438fedacad9d82422003f46d0d7721bd50c2f8096657c15dce0f3edd" +content-hash = "39e0f069346a4d1e52193899989b79ea3e02f81d67fbb2ac0fdc87e70bd1008f" diff --git a/pyproject.toml b/pyproject.toml index 73e06eda02de..09d72d3a087c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,6 +76,7 @@ stripe = "^11.5.0" ipywidgets = "^8.1.5" qtconsole = "^5.6.1" memory-profiler = "^0.61.0" +daytona-sdk = "0.9.1" [tool.poetry.group.llama-index.dependencies] llama-index = "*" diff --git a/tests/runtime/conftest.py b/tests/runtime/conftest.py index 73b18680a11e..bb0c1eca696b 100644 --- a/tests/runtime/conftest.py +++ b/tests/runtime/conftest.py @@ -11,6 +11,7 @@ from openhands.core.logger import openhands_logger as logger from openhands.events import EventStream from openhands.runtime.base import Runtime +from openhands.runtime.impl.daytona.daytona_runtime import DaytonaRuntime from openhands.runtime.impl.docker.docker_runtime import DockerRuntime from openhands.runtime.impl.local.local_runtime import LocalRuntime from openhands.runtime.impl.remote.remote_runtime import RemoteRuntime @@ -130,6 +131,8 @@ def get_runtime_classes() -> list[type[Runtime]]: return [RemoteRuntime] elif runtime.lower() == 'runloop': return [RunloopRuntime] + elif runtime.lower() == 'daytona': + return [DaytonaRuntime] else: raise ValueError(f'Invalid runtime: {runtime}') diff --git a/tests/unit/test_cli_sid.py b/tests/unit/test_cli_sid.py new file mode 100644 index 000000000000..939e45ef2b98 --- /dev/null +++ b/tests/unit/test_cli_sid.py @@ -0,0 +1,101 @@ +import asyncio +from argparse import Namespace +from pathlib import Path +from unittest.mock import AsyncMock, patch + +import pytest + +from openhands.core.cli import main +from openhands.core.config import AppConfig +from openhands.core.schema import AgentState +from openhands.events.event import EventSource +from openhands.events.observation import AgentStateChangedObservation + + +@pytest.fixture +def mock_runtime(): + with patch('openhands.core.cli.create_runtime') as mock_create_runtime: + mock_runtime_instance = AsyncMock() + # Mock the event stream with proper async methods + mock_runtime_instance.event_stream = AsyncMock() + mock_runtime_instance.event_stream.subscribe = AsyncMock() + mock_runtime_instance.event_stream.add_event = AsyncMock() + # Mock connect method to return immediately + mock_runtime_instance.connect = AsyncMock() + # Ensure status_callback is None + mock_runtime_instance.status_callback = None + mock_create_runtime.return_value = mock_runtime_instance + yield mock_runtime_instance + + +@pytest.fixture +def mock_agent(): + with patch('openhands.core.cli.create_agent') as mock_create_agent: + mock_agent_instance = AsyncMock() + mock_create_agent.return_value = mock_agent_instance + yield mock_agent_instance + + +@pytest.fixture +def mock_controller(): + with patch('openhands.core.cli.create_controller') as mock_create_controller: + mock_controller_instance = AsyncMock() + # Mock run_until_done to finish immediately + mock_controller_instance.run_until_done = AsyncMock(return_value=None) + mock_create_controller.return_value = (mock_controller_instance, None) + yield mock_controller_instance + + +@pytest.fixture +def task_file(tmp_path: Path) -> Path: + # Create a temporary file with our task + task_file = tmp_path / 'task.txt' + task_file.write_text('Ask me what your task is') + return task_file + + +@pytest.fixture +def mock_config(task_file: Path): + with patch('openhands.core.cli.parse_arguments') as mock_parse_args: + # Create a proper Namespace with our temporary task file + args = Namespace(file=str(task_file), task=None, directory=None) + mock_parse_args.return_value = args + with patch('openhands.core.cli.setup_config_from_args') as mock_setup_config: + mock_config = AppConfig() + mock_setup_config.return_value = mock_config + yield mock_config + + +@pytest.mark.asyncio +async def test_cli_session_id_output( + mock_runtime, mock_agent, mock_controller, mock_config, capsys +): + # status_callback is set when initializing the runtime + mock_controller.status_callback = None + + # Use input patch just for the exit command + with patch('builtins.input', return_value='exit'): + # Create a task for main + main_task = asyncio.create_task(main(asyncio.get_event_loop())) + + # Give it a moment to display the session ID + await asyncio.sleep(0.1) + + # Trigger agent state change to STOPPED to end the main loop + event = AgentStateChangedObservation( + content='Stop', agent_state=AgentState.STOPPED + ) + event._source = EventSource.AGENT + await mock_runtime.event_stream.add_event(event) + + # Wait for main to finish with a timeout + try: + await asyncio.wait_for(main_task, timeout=1.0) + except asyncio.TimeoutError: + main_task.cancel() + + # Check the output + captured = capsys.readouterr() + assert 'Session ID:' in captured.out + # Also verify that our task message was processed + assert 'Ask me what your task is' in str(mock_runtime.mock_calls) diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 10f09447ba6c..0848dda676de 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -686,6 +686,7 @@ def test_api_keys_repr_str(): modal_api_token_id='my_modal_api_token_id', modal_api_token_secret='my_modal_api_token_secret', runloop_api_key='my_runloop_api_key', + daytona_api_key='my_daytona_api_key', ) assert 'my_e2b_api_key' not in repr(app_config) assert 'my_e2b_api_key' not in str(app_config) @@ -697,6 +698,8 @@ def test_api_keys_repr_str(): assert 'my_modal_api_token_secret' not in str(app_config) assert 'my_runloop_api_key' not in repr(app_config) assert 'my_runloop_api_key' not in str(app_config) + assert 'my_daytona_api_key' not in repr(app_config) + assert 'my_daytona_api_key' not in str(app_config) # Check that no other attrs in AppConfig have 'key' or 'token' in their name # This will fail when new attrs are added, and attract attention @@ -705,6 +708,7 @@ def test_api_keys_repr_str(): 'modal_api_token_id', 'modal_api_token_secret', 'runloop_api_key', + 'daytona_api_key', ] for attr_name in AppConfig.model_fields.keys(): if ( diff --git a/tests/unit/test_ensure_httpx_close.py b/tests/unit/test_ensure_httpx_close.py new file mode 100644 index 000000000000..7ef50b5535a5 --- /dev/null +++ b/tests/unit/test_ensure_httpx_close.py @@ -0,0 +1,84 @@ +from httpx import Client + +from openhands.utils.ensure_httpx_close import EnsureHttpxClose + + +def test_ensure_httpx_close_basic(): + """Test basic functionality of EnsureHttpxClose.""" + clients = [] + ctx = EnsureHttpxClose() + with ctx: + # Create a client - should be tracked + client = Client() + assert client in ctx.clients + assert len(ctx.clients) == 1 + clients.append(client) + + # After context exit, client should be closed + assert client.is_closed + + +def test_ensure_httpx_close_multiple_clients(): + """Test EnsureHttpxClose with multiple clients.""" + ctx = EnsureHttpxClose() + with ctx: + client1 = Client() + client2 = Client() + assert len(ctx.clients) == 2 + assert client1 in ctx.clients + assert client2 in ctx.clients + + assert client1.is_closed + assert client2.is_closed + + +def test_ensure_httpx_close_nested(): + """Test nested usage of EnsureHttpxClose.""" + outer_ctx = EnsureHttpxClose() + with outer_ctx: + client1 = Client() + assert client1 in outer_ctx.clients + + inner_ctx = EnsureHttpxClose() + with inner_ctx: + client2 = Client() + assert client2 in inner_ctx.clients + # Since both contexts are using the same monkey-patched __init__, + # both contexts will track all clients created while they are active + assert client2 in outer_ctx.clients + + # After inner context, client2 should be closed + assert client2.is_closed + # client1 should still be open since outer context is still active + assert not client1.is_closed + + # After outer context, both clients should be closed + assert client1.is_closed + assert client2.is_closed + + +def test_ensure_httpx_close_exception(): + """Test EnsureHttpxClose when an exception occurs.""" + client = None + ctx = EnsureHttpxClose() + try: + with ctx: + client = Client() + raise ValueError('Test exception') + except ValueError: + pass + + # Client should be closed even if an exception occurred + assert client is not None + assert client.is_closed + + +def test_ensure_httpx_close_restore_init(): + """Test that the original __init__ is restored after context exit.""" + original_init = Client.__init__ + ctx = EnsureHttpxClose() + with ctx: + assert Client.__init__ != original_init + + # Original __init__ should be restored + assert Client.__init__ == original_init diff --git a/tests/unit/test_llm.py b/tests/unit/test_llm.py index 0ec7fe252192..57906c2c776c 100644 --- a/tests/unit/test_llm.py +++ b/tests/unit/test_llm.py @@ -1,4 +1,6 @@ import copy +import tempfile +from pathlib import Path from unittest.mock import MagicMock, patch import pytest @@ -489,3 +491,27 @@ def test_llm_token_usage(mock_litellm_completion, default_config): assert usage_entry_2['cache_read_tokens'] == 1 assert usage_entry_2['cache_write_tokens'] == 3 assert usage_entry_2['response_id'] == 'test-response-usage-2' + + +@patch('openhands.llm.llm.litellm_completion') +def test_completion_with_log_completions(mock_litellm_completion, default_config): + with tempfile.TemporaryDirectory() as temp_dir: + default_config.log_completions = True + default_config.log_completions_folder = temp_dir + mock_response = { + 'choices': [{'message': {'content': 'This is a mocked response.'}}] + } + mock_litellm_completion.return_value = mock_response + + test_llm = LLM(config=default_config) + response = test_llm.completion( + messages=[{'role': 'user', 'content': 'Hello!'}], + stream=False, + drop_params=True, + ) + assert ( + response['choices'][0]['message']['content'] == 'This is a mocked response.' + ) + files = list(Path(temp_dir).iterdir()) + # Expect a log to be generated + assert len(files) == 1 diff --git a/tests/unit/test_memory.py b/tests/unit/test_long_term_memory.py similarity index 100% rename from tests/unit/test_memory.py rename to tests/unit/test_long_term_memory.py