diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index 85892c43..358e4859 100644 --- a/src/codegate/api/v1.py +++ b/src/codegate/api/v1.py @@ -1,17 +1,15 @@ -from fastapi import APIRouter, Response -from fastapi.exceptions import HTTPException +from fastapi import APIRouter, HTTPException, Response from fastapi.routing import APIRoute from pydantic import ValidationError from codegate.api import v1_models -from codegate.db.connection import AlreadyExistsError -from codegate.workspaces.crud import WorkspaceCrud from codegate.api.dashboard.dashboard import dashboard_router +from codegate.db.connection import AlreadyExistsError +from codegate.workspaces import crud v1 = APIRouter() v1.include_router(dashboard_router) - -wscrud = WorkspaceCrud() +wscrud = crud.WorkspaceCrud() def uniq_name(route: APIRoute): @@ -44,21 +42,24 @@ async def list_active_workspaces() -> v1_models.ListActiveWorkspacesResponse: @v1.post("/workspaces/active", tags=["Workspaces"], generate_unique_id_function=uniq_name) async def activate_workspace(request: v1_models.ActivateWorkspaceRequest, status_code=204): """Activate a workspace by name.""" - activated = await wscrud.activate_workspace(request.name) - - # TODO: Refactor - if not activated: + try: + await wscrud.activate_workspace(request.name) + except crud.WorkspaceAlreadyActiveError: return HTTPException(status_code=409, detail="Workspace already active") + except crud.WorkspaceDoesNotExistError: + return HTTPException(status_code=404, detail="Workspace does not exist") + except Exception: + return HTTPException(status_code=500, detail="Internal server error") return Response(status_code=204) @v1.post("/workspaces", tags=["Workspaces"], generate_unique_id_function=uniq_name, status_code=201) -async def create_workspace(request: v1_models.CreateWorkspaceRequest): +async def create_workspace(request: v1_models.CreateWorkspaceRequest) -> v1_models.Workspace: """Create a new workspace.""" # Input validation is done in the model try: - created = await wscrud.add_workspace(request.name) + _ = await wscrud.add_workspace(request.name) except AlreadyExistsError: raise HTTPException(status_code=409, detail="Workspace already exists") except ValidationError: @@ -68,8 +69,7 @@ async def create_workspace(request: v1_models.CreateWorkspaceRequest): except Exception: raise HTTPException(status_code=500, detail="Internal server error") - if created: - return v1_models.Workspace(name=created.name) + return v1_models.Workspace(name=request.name, is_active=False) @v1.delete( diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 111c582f..6efe4586 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -284,7 +284,7 @@ async def update_session(self, session: Session) -> Optional[Session]: """ ) # We only pass an object to respect the signature of the function - active_session = await self._execute_update_pydantic_model(session, sql) + active_session = await self._execute_update_pydantic_model(session, sql, should_raise=True) return active_session @@ -317,7 +317,8 @@ async def _execute_select_pydantic_model( return None async def _exec_select_conditions_to_pydantic( - self, model_type: Type[BaseModel], sql_command: TextClause, conditions: dict + self, model_type: Type[BaseModel], sql_command: TextClause, conditions: dict, + should_raise: bool = False ) -> Optional[List[BaseModel]]: async with self._async_db_engine.begin() as conn: try: @@ -325,6 +326,9 @@ async def _exec_select_conditions_to_pydantic( return await self._dump_result_to_pydantic_model(model_type, result) except Exception as e: logger.error(f"Failed to select model with conditions: {model_type}.", error=str(e)) + # Exposes errors to the caller + if should_raise: + raise e return None async def get_prompts_with_output(self) -> List[GetPromptWithOutputsRow]: @@ -392,7 +396,8 @@ async def get_workspace_by_name(self, name: str) -> List[Workspace]: """ ) conditions = {"name": name} - workspaces = await self._exec_select_conditions_to_pydantic(Workspace, sql, conditions) + workspaces = await self._exec_select_conditions_to_pydantic( + Workspace, sql, conditions, should_raise=True) return workspaces[0] if workspaces else None async def get_sessions(self) -> List[Session]: @@ -453,7 +458,11 @@ def init_session_if_not_exists(db_path: Optional[str] = None): last_update=datetime.datetime.now(datetime.timezone.utc), ) db_recorder = DbRecorder(db_path) - asyncio.run(db_recorder.update_session(session)) + try: + asyncio.run(db_recorder.update_session(session)) + except Exception as e: + logger.error(f"Failed to initialize session in DB: {e}") + return logger.info("Session in DB initialized successfully.") diff --git a/src/codegate/pipeline/cli/commands.py b/src/codegate/pipeline/cli/commands.py index bc33bf18..f5a5d694 100644 --- a/src/codegate/pipeline/cli/commands.py +++ b/src/codegate/pipeline/cli/commands.py @@ -5,7 +5,7 @@ from codegate import __version__ from codegate.db.connection import AlreadyExistsError -from codegate.workspaces.crud import WorkspaceCrud +from codegate.workspaces import crud class CodegateCommand(ABC): @@ -41,7 +41,7 @@ def help(self) -> str: class Workspace(CodegateCommand): def __init__(self): - self.workspace_crud = WorkspaceCrud() + self.workspace_crud = crud.WorkspaceCrud() self.commands = { "list": self._list_workspaces, "add": self._add_workspace, @@ -94,12 +94,14 @@ async def _activate_workspace(self, args: List[str]) -> str: if not workspace_name: return "Please provide a name. Use `codegate workspace activate workspace_name`" - was_activated = await self.workspace_crud.activate_workspace(workspace_name) - if not was_activated: - return ( - f"Workspace **{workspace_name}** does not exist or was already active. " - f"Use `codegate workspace add {workspace_name}` to add it" - ) + try: + await self.workspace_crud.activate_workspace(workspace_name) + except crud.WorkspaceAlreadyActiveError: + return f"Workspace **{workspace_name}** is already active" + except crud.WorkspaceDoesNotExistError: + return f"Workspace **{workspace_name}** does not exist" + except Exception: + return "An error occurred while activating the workspace" return f"Workspace **{workspace_name}** has been activated" async def run(self, args: List[str]) -> str: diff --git a/src/codegate/workspaces/crud.py b/src/codegate/workspaces/crud.py index 6097c395..0e215cbb 100644 --- a/src/codegate/workspaces/crud.py +++ b/src/codegate/workspaces/crud.py @@ -8,6 +8,12 @@ class WorkspaceCrudError(Exception): pass +class WorkspaceDoesNotExistError(WorkspaceCrudError): + pass + +class WorkspaceAlreadyActiveError(WorkspaceCrudError): + pass + class WorkspaceCrud: def __init__(self): @@ -36,19 +42,17 @@ async def get_active_workspace(self) -> Optional[ActiveWorkspace]: """ return await self._db_reader.get_active_workspace() - async def _is_workspace_active_or_not_exist( + async def _is_workspace_active( self, workspace_name: str ) -> Tuple[bool, Optional[Session], Optional[Workspace]]: """ - Check if the workspace is active - - Will return: - - True if the workspace was activated - - False if the workspace is already active or does not exist + Check if the workspace is active alongside the session and workspace objects """ + # TODO: All of this should be done within a transaction. + selected_workspace = await self._db_reader.get_workspace_by_name(workspace_name) if not selected_workspace: - return True, None, None + raise WorkspaceDoesNotExistError(f"Workspace {workspace_name} does not exist.") sessions = await self._db_reader.get_sessions() # The current implementation expects only one active session @@ -56,11 +60,10 @@ async def _is_workspace_active_or_not_exist( raise RuntimeError("Something went wrong. No active session found.") session = sessions[0] - if session.active_workspace_id == selected_workspace.id: - return True, None, None - return False, session, selected_workspace + return (session.active_workspace_id == selected_workspace.id, + session, selected_workspace) - async def activate_workspace(self, workspace_name: str) -> bool: + async def activate_workspace(self, workspace_name: str): """ Activate a workspace @@ -68,12 +71,12 @@ async def activate_workspace(self, workspace_name: str) -> bool: - True if the workspace was activated - False if the workspace is already active or does not exist """ - is_active, session, workspace = await self._is_workspace_active_or_not_exist(workspace_name) + is_active, session, workspace = await self._is_workspace_active(workspace_name) if is_active: - return False + raise WorkspaceAlreadyActiveError(f"Workspace {workspace_name} is already active.") session.active_workspace_id = workspace.id session.last_update = datetime.datetime.now(datetime.timezone.utc) db_recorder = DbRecorder() await db_recorder.update_session(session) - return True + return diff --git a/tests/pipeline/workspace/test_workspace.py b/tests/pipeline/workspace/test_workspace.py index 039b69fe..1caf2be2 100644 --- a/tests/pipeline/workspace/test_workspace.py +++ b/tests/pipeline/workspace/test_workspace.py @@ -78,7 +78,7 @@ async def test_add_workspaces(args, existing_workspaces, expected_message): # We'll also patch DbRecorder to ensure no real DB operations happen with patch( - "codegate.pipeline.cli.commands.WorkspaceCrud", autospec=True + "codegate.workspaces.crud.WorkspaceCrud", autospec=True ) as mock_recorder_cls: mock_recorder = mock_recorder_cls.return_value workspace_commands.workspace_crud = mock_recorder