Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Polish errors and miscelaneous fixes for workspaces #649

Merged
merged 5 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions src/codegate/api/v1.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
from fastapi import APIRouter, Response
from fastapi.exceptions import HTTPException
from fastapi import APIRouter, HTTPException, Response
from fastapi.routing import APIRoute
from pydantic import ValidationError

from codegate.api import v1_models
from codegate.db.connection import AlreadyExistsError
from codegate.workspaces.crud import WorkspaceCrud
from codegate.api.dashboard.dashboard import dashboard_router
from codegate.db.connection import AlreadyExistsError
from codegate.workspaces import crud

v1 = APIRouter()
v1.include_router(dashboard_router)

wscrud = WorkspaceCrud()
wscrud = crud.WorkspaceCrud()


def uniq_name(route: APIRoute):
Expand Down Expand Up @@ -44,21 +42,24 @@ async def list_active_workspaces() -> v1_models.ListActiveWorkspacesResponse:
@v1.post("/workspaces/active", tags=["Workspaces"], generate_unique_id_function=uniq_name)
async def activate_workspace(request: v1_models.ActivateWorkspaceRequest, status_code=204):
"""Activate a workspace by name."""
activated = await wscrud.activate_workspace(request.name)

# TODO: Refactor
if not activated:
try:
await wscrud.activate_workspace(request.name)
except crud.WorkspaceAlreadyActiveError:
return HTTPException(status_code=409, detail="Workspace already active")
except crud.WorkspaceDoesNotExistError:
return HTTPException(status_code=404, detail="Workspace does not exist")
except Exception:
return HTTPException(status_code=500, detail="Internal server error")

return Response(status_code=204)


@v1.post("/workspaces", tags=["Workspaces"], generate_unique_id_function=uniq_name, status_code=201)
async def create_workspace(request: v1_models.CreateWorkspaceRequest):
async def create_workspace(request: v1_models.CreateWorkspaceRequest) -> v1_models.Workspace:
"""Create a new workspace."""
# Input validation is done in the model
try:
created = await wscrud.add_workspace(request.name)
_ = await wscrud.add_workspace(request.name)
except AlreadyExistsError:
raise HTTPException(status_code=409, detail="Workspace already exists")
except ValidationError:
Expand All @@ -68,8 +69,7 @@ async def create_workspace(request: v1_models.CreateWorkspaceRequest):
except Exception:
raise HTTPException(status_code=500, detail="Internal server error")

if created:
return v1_models.Workspace(name=created.name)
return v1_models.Workspace(name=request.name, is_active=False)


@v1.delete(
Expand Down
17 changes: 13 additions & 4 deletions src/codegate/db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ async def update_session(self, session: Session) -> Optional[Session]:
"""
)
# We only pass an object to respect the signature of the function
active_session = await self._execute_update_pydantic_model(session, sql)
active_session = await self._execute_update_pydantic_model(session, sql, should_raise=True)
return active_session


Expand Down Expand Up @@ -317,14 +317,18 @@ async def _execute_select_pydantic_model(
return None

async def _exec_select_conditions_to_pydantic(
self, model_type: Type[BaseModel], sql_command: TextClause, conditions: dict
self, model_type: Type[BaseModel], sql_command: TextClause, conditions: dict,
should_raise: bool = False
) -> Optional[List[BaseModel]]:
async with self._async_db_engine.begin() as conn:
try:
result = await conn.execute(sql_command, conditions)
return await self._dump_result_to_pydantic_model(model_type, result)
except Exception as e:
logger.error(f"Failed to select model with conditions: {model_type}.", error=str(e))
# Exposes errors to the caller
if should_raise:
raise e
return None

async def get_prompts_with_output(self) -> List[GetPromptWithOutputsRow]:
Expand Down Expand Up @@ -392,7 +396,8 @@ async def get_workspace_by_name(self, name: str) -> List[Workspace]:
"""
)
conditions = {"name": name}
workspaces = await self._exec_select_conditions_to_pydantic(Workspace, sql, conditions)
workspaces = await self._exec_select_conditions_to_pydantic(
Workspace, sql, conditions, should_raise=True)
return workspaces[0] if workspaces else None

async def get_sessions(self) -> List[Session]:
Expand Down Expand Up @@ -453,7 +458,11 @@ def init_session_if_not_exists(db_path: Optional[str] = None):
last_update=datetime.datetime.now(datetime.timezone.utc),
)
db_recorder = DbRecorder(db_path)
asyncio.run(db_recorder.update_session(session))
try:
asyncio.run(db_recorder.update_session(session))
except Exception as e:
logger.error(f"Failed to initialize session in DB: {e}")
return
logger.info("Session in DB initialized successfully.")


Expand Down
18 changes: 10 additions & 8 deletions src/codegate/pipeline/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from codegate import __version__
from codegate.db.connection import AlreadyExistsError
from codegate.workspaces.crud import WorkspaceCrud
from codegate.workspaces import crud


class CodegateCommand(ABC):
Expand Down Expand Up @@ -41,7 +41,7 @@ def help(self) -> str:
class Workspace(CodegateCommand):

def __init__(self):
self.workspace_crud = WorkspaceCrud()
self.workspace_crud = crud.WorkspaceCrud()
self.commands = {
"list": self._list_workspaces,
"add": self._add_workspace,
Expand Down Expand Up @@ -94,12 +94,14 @@ async def _activate_workspace(self, args: List[str]) -> str:
if not workspace_name:
return "Please provide a name. Use `codegate workspace activate workspace_name`"

was_activated = await self.workspace_crud.activate_workspace(workspace_name)
if not was_activated:
return (
f"Workspace **{workspace_name}** does not exist or was already active. "
f"Use `codegate workspace add {workspace_name}` to add it"
)
try:
await self.workspace_crud.activate_workspace(workspace_name)
except crud.WorkspaceAlreadyActiveError:
return f"Workspace **{workspace_name}** is already active"
except crud.WorkspaceDoesNotExistError:
return f"Workspace **{workspace_name}** does not exist"
except Exception:
return "An error occurred while activating the workspace"
return f"Workspace **{workspace_name}** has been activated"

async def run(self, args: List[str]) -> str:
Expand Down
31 changes: 17 additions & 14 deletions src/codegate/workspaces/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@
class WorkspaceCrudError(Exception):
pass

class WorkspaceDoesNotExistError(WorkspaceCrudError):
pass

class WorkspaceAlreadyActiveError(WorkspaceCrudError):
pass

class WorkspaceCrud:

def __init__(self):
Expand Down Expand Up @@ -36,44 +42,41 @@ async def get_active_workspace(self) -> Optional[ActiveWorkspace]:
"""
return await self._db_reader.get_active_workspace()

async def _is_workspace_active_or_not_exist(
async def _is_workspace_active(
self, workspace_name: str
) -> Tuple[bool, Optional[Session], Optional[Workspace]]:
"""
Check if the workspace is active

Will return:
- True if the workspace was activated
- False if the workspace is already active or does not exist
Check if the workspace is active alongside the session and workspace objects
"""
# TODO: All of this should be done within a transaction.

selected_workspace = await self._db_reader.get_workspace_by_name(workspace_name)
if not selected_workspace:
return True, None, None
raise WorkspaceDoesNotExistError(f"Workspace {workspace_name} does not exist.")

sessions = await self._db_reader.get_sessions()
# The current implementation expects only one active session
if len(sessions) != 1:
raise RuntimeError("Something went wrong. No active session found.")

session = sessions[0]
if session.active_workspace_id == selected_workspace.id:
return True, None, None
return False, session, selected_workspace
return (session.active_workspace_id == selected_workspace.id,
session, selected_workspace)

async def activate_workspace(self, workspace_name: str) -> bool:
async def activate_workspace(self, workspace_name: str):
"""
Activate a workspace

Will return:
- True if the workspace was activated
- False if the workspace is already active or does not exist
"""
is_active, session, workspace = await self._is_workspace_active_or_not_exist(workspace_name)
is_active, session, workspace = await self._is_workspace_active(workspace_name)
if is_active:
return False
raise WorkspaceAlreadyActiveError(f"Workspace {workspace_name} is already active.")

session.active_workspace_id = workspace.id
session.last_update = datetime.datetime.now(datetime.timezone.utc)
db_recorder = DbRecorder()
await db_recorder.update_session(session)
return True
return
2 changes: 1 addition & 1 deletion tests/pipeline/workspace/test_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ async def test_add_workspaces(args, existing_workspaces, expected_message):

# We'll also patch DbRecorder to ensure no real DB operations happen
with patch(
"codegate.pipeline.cli.commands.WorkspaceCrud", autospec=True
"codegate.workspaces.crud.WorkspaceCrud", autospec=True
) as mock_recorder_cls:
mock_recorder = mock_recorder_cls.return_value
workspace_commands.workspace_crud = mock_recorder
Expand Down
Loading