Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Let the user add their own system prompts #643

Merged
merged 8 commits into from
Jan 20, 2025
26 changes: 26 additions & 0 deletions migrations/versions/a692c8b52308_add_workspace_system_prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""add_workspace_system_prompt

Revision ID: a692c8b52308
Revises: 5c2f3eee5f90
Create Date: 2025-01-17 16:33:58.464223

"""

from typing import Sequence, Union

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "a692c8b52308"
down_revision: Union[str, None] = "5c2f3eee5f90"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# Add column to workspaces table
op.execute("ALTER TABLE workspaces ADD COLUMN system_prompt TEXT DEFAULT NULL;")


def downgrade() -> None:
op.execute("ALTER TABLE workspaces DROP COLUMN system_prompt;")
26 changes: 20 additions & 6 deletions src/codegate/db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,16 +248,15 @@ async def record_context(self, context: Optional[PipelineContext]) -> None:
except Exception as e:
logger.error(f"Failed to record context: {context}.", error=str(e))

async def add_workspace(self, workspace_name: str) -> Optional[Workspace]:
async def add_workspace(self, workspace_name: str) -> Workspace:
"""Add a new workspace to the DB.

This handles validation and insertion of a new workspace.

It may raise a ValidationError if the workspace name is invalid.
or a AlreadyExistsError if the workspace already exists.
"""
workspace = Workspace(id=str(uuid.uuid4()), name=workspace_name)

workspace = Workspace(id=str(uuid.uuid4()), name=workspace_name, system_prompt=None)
sql = text(
"""
INSERT INTO workspaces (id, name)
Expand All @@ -275,6 +274,21 @@ async def add_workspace(self, workspace_name: str) -> Optional[Workspace]:
raise AlreadyExistsError(f"Workspace {workspace_name} already exists.")
return added_workspace

async def update_workspace(self, workspace: Workspace) -> Workspace:
sql = text(
"""
UPDATE workspaces SET
name = :name,
system_prompt = :system_prompt
WHERE id = :id
RETURNING *
"""
)
updated_workspace = await self._execute_update_pydantic_model(
workspace, sql, should_raise=True
)
return updated_workspace

async def update_session(self, session: Session) -> Optional[Session]:
sql = text(
"""
Expand Down Expand Up @@ -392,11 +406,11 @@ async def get_workspaces(self) -> List[WorkspaceActive]:
workspaces = await self._execute_select_pydantic_model(WorkspaceActive, sql)
return workspaces

async def get_workspace_by_name(self, name: str) -> List[Workspace]:
async def get_workspace_by_name(self, name: str) -> Optional[Workspace]:
sql = text(
"""
SELECT
id, name
id, name, system_prompt
FROM workspaces
WHERE name = :name
"""
Expand All @@ -422,7 +436,7 @@ async def get_active_workspace(self) -> Optional[ActiveWorkspace]:
sql = text(
"""
SELECT
w.id, w.name, s.id as session_id, s.last_update
w.id, w.name, w.system_prompt, s.id as session_id, s.last_update
FROM sessions s
INNER JOIN workspaces w ON w.id = s.active_workspace_id
"""
Expand Down
2 changes: 2 additions & 0 deletions src/codegate/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class Setting(BaseModel):
class Workspace(BaseModel):
id: str
name: str
system_prompt: Optional[str]

@field_validator("name", mode="plain")
@classmethod
Expand Down Expand Up @@ -98,5 +99,6 @@ class WorkspaceActive(BaseModel):
class ActiveWorkspace(BaseModel):
id: str
name: str
system_prompt: Optional[str]
session_id: str
last_update: datetime.datetime
3 changes: 2 additions & 1 deletion src/codegate/pipeline/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
PipelineResult,
PipelineStep,
)
from codegate.pipeline.cli.commands import Version, Workspace
from codegate.pipeline.cli.commands import SystemPrompt, Version, Workspace

HELP_TEXT = """
## CodeGate CLI\n
Expand All @@ -32,6 +32,7 @@ async def codegate_cli(command):
available_commands = {
"version": Version().exec,
"workspace": Workspace().exec,
"system-prompt": SystemPrompt().exec,
}
out_func = available_commands.get(command[0])
if out_func is None:
Expand Down
Loading
Loading