Skip to content

Commit

Permalink
Merge branch 'main' into staging/deploy
Browse files Browse the repository at this point in the history
  • Loading branch information
scott-cohere authored Aug 22, 2024
2 parents dd5bc56 + 95d5639 commit 4aea9fe
Show file tree
Hide file tree
Showing 33 changed files with 688 additions and 112 deletions.
61 changes: 32 additions & 29 deletions src/backend/alembic/versions/2024_08_19_c301506b3676_.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,52 +5,55 @@
Create Date: 2024-08-19 12:36:34.118536
"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = 'c301506b3676'
down_revision: Union[str, None] = 'a76ebb869eb8'
revision: str = "c301506b3676"
down_revision: Union[str, None] = "a76ebb869eb8"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('groups',
sa.Column('display_name', sa.String(), nullable=False),
sa.Column('id', sa.String(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('display_name', name='unique_display_name')
op.create_table(
"groups",
sa.Column("display_name", sa.String(), nullable=False),
sa.Column("id", sa.String(), nullable=False),
sa.Column("created_at", sa.DateTime(), nullable=True),
sa.Column("updated_at", sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("display_name", name="unique_display_name"),
)
op.create_table('user_group',
sa.Column('user_id', sa.String(), nullable=False),
sa.Column('group_id', sa.String(), nullable=False),
sa.Column('display', sa.String(), nullable=False),
sa.Column('id', sa.String(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('user_id', 'group_id', 'id')
op.create_table(
"user_group",
sa.Column("user_id", sa.String(), nullable=False),
sa.Column("group_id", sa.String(), nullable=False),
sa.Column("display", sa.String(), nullable=False),
sa.Column("id", sa.String(), nullable=False),
sa.Column("created_at", sa.DateTime(), nullable=True),
sa.Column("updated_at", sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(["group_id"], ["groups.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("user_id", "group_id", "id"),
)
op.add_column('users', sa.Column('user_name', sa.String(), nullable=True))
op.add_column('users', sa.Column('external_id', sa.String(), nullable=True))
op.add_column('users', sa.Column('active', sa.Boolean(), nullable=True))
op.create_unique_constraint('unique_user_name', 'users', ['user_name'])
op.add_column("users", sa.Column("user_name", sa.String(), nullable=True))
op.add_column("users", sa.Column("external_id", sa.String(), nullable=True))
op.add_column("users", sa.Column("active", sa.Boolean(), nullable=True))
op.create_unique_constraint("unique_user_name", "users", ["user_name"])
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint('unique_user_name', 'users', type_='unique')
op.drop_column('users', 'active')
op.drop_column('users', 'external_id')
op.drop_column('users', 'user_name')
op.drop_table('user_group')
op.drop_table('groups')
op.drop_constraint("unique_user_name", "users", type_="unique")
op.drop_column("users", "active")
op.drop_column("users", "external_id")
op.drop_column("users", "user_name")
op.drop_table("user_group")
op.drop_table("groups")
# ### end Alembic commands ###
23 changes: 23 additions & 0 deletions src/backend/config/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
)
from backend.services.request_validators import (
validate_chat_request,
validate_organization_header,
validate_user_header,
)

Expand All @@ -35,104 +36,126 @@ class RouterName(StrEnum):
RouterName.AUTH: {
"default": [
Depends(get_session),
Depends(validate_organization_header),
],
"auth": [
Depends(get_session),
Depends(validate_organization_header),
],
},
RouterName.CHAT: {
"default": [
Depends(get_session),
Depends(validate_user_header),
Depends(validate_chat_request),
Depends(validate_organization_header),
],
"auth": [
Depends(get_session),
Depends(validate_chat_request),
Depends(validate_authorization),
Depends(validate_organization_header),
],
},
RouterName.CONVERSATION: {
"default": [
Depends(get_session),
Depends(validate_user_header),
Depends(validate_organization_header),
],
"auth": [
Depends(get_session),
Depends(validate_authorization),
Depends(validate_organization_header),
],
},
RouterName.DEPLOYMENT: {
"default": [
Depends(get_session),
Depends(validate_organization_header),
],
"auth": [
Depends(get_session),
Depends(validate_authorization),
Depends(validate_organization_header),
],
},
RouterName.EXPERIMENTAL_FEATURES: {
"default": [
Depends(get_session),
Depends(validate_organization_header),
],
"auth": [
Depends(get_session),
Depends(validate_authorization),
Depends(validate_organization_header),
],
},
RouterName.TOOL: {
"default": [
Depends(get_session),
Depends(validate_organization_header),
],
"auth": [
Depends(get_session),
Depends(validate_authorization),
Depends(validate_organization_header),
],
},
RouterName.USER: {
"default": [
Depends(get_session),
Depends(validate_organization_header),
],
"auth": [
# TODO: Remove auth only for create user endpoint
Depends(get_session),
Depends(validate_organization_header),
],
},
RouterName.AGENT: {
"default": [
Depends(get_session),
Depends(validate_organization_header),
],
"auth": [
Depends(get_session),
Depends(validate_authorization),
Depends(validate_organization_header),
],
},
RouterName.DEFAULT_AGENT: {
"default": [
Depends(get_session),
Depends(validate_organization_header),
],
"auth": [
Depends(get_session),
Depends(validate_authorization),
Depends(validate_organization_header),
],
},
RouterName.SNAPSHOT: {
"default": [
Depends(get_session),
Depends(validate_user_header),
Depends(validate_organization_header),
],
"auth": [
Depends(get_session),
Depends(validate_authorization),
Depends(validate_organization_header),
],
},
RouterName.MODEL: {
"default": [
Depends(get_session),
Depends(validate_organization_header),
],
"auth": [
Depends(get_session),
Depends(validate_authorization),
Depends(validate_organization_header),
],
},
RouterName.SCIM: {
Expand Down
2 changes: 1 addition & 1 deletion src/backend/crud/organization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from sqlalchemy.orm import Session

from backend.database_models import Agent
from backend.database_models.agent import Agent
from backend.database_models.organization import Organization
from backend.database_models.user import User, UserOrganizationAssociation
from backend.schemas.organization import UpdateOrganization
Expand Down
40 changes: 39 additions & 1 deletion src/backend/database_models/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,45 @@
from enum import StrEnum
from uuid import uuid4

from sqlalchemy import DateTime, String, func
from sqlalchemy.orm import DeclarativeBase, mapped_column
from sqlalchemy.orm import DeclarativeBase, Query, mapped_column


class FilterFields(StrEnum):
ORGANIZATION_ID = "organization_id"


class CustomFilterQuery(Query):
"""
Custom query class that filters by field if the entity has field
and the filter value is set.
"""

ALLOWED_FILTER_FIELDS = [FilterFields.ORGANIZATION_ID]

def __new__(cls, *args, **kwargs):
from backend.services.context import GLOBAL_REQUEST_CONTEXT

request_ctx = GLOBAL_REQUEST_CONTEXT.get()
if request_ctx and request_ctx.use_global_filtering:
query = None
for field in cls.ALLOWED_FILTER_FIELDS:
if (
args
and hasattr(args[0][0], field)
and hasattr(request_ctx, field)
and getattr(request_ctx, field)
):
if query:
query = query.filter_by(**{field: getattr(request_ctx, field)})
else:
query = Query(*args, **kwargs).filter_by(
**{field: getattr(request_ctx, field)}
)
if query:
return query

return object.__new__(cls)


class Base(DeclarativeBase):
Expand Down
3 changes: 2 additions & 1 deletion src/backend/database_models/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from sqlalchemy.orm import Session

from backend.config.settings import Settings
from backend.database_models.base import CustomFilterQuery

load_dotenv()

Expand All @@ -16,7 +17,7 @@


def get_session() -> Generator[Session, Any, None]:
with Session(engine) as session:
with Session(engine, query_cls=CustomFilterQuery) as session:
yield session


Expand Down
3 changes: 3 additions & 0 deletions src/backend/routers/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,9 @@ async def list_agents(
# TODO: get organization_id from user
user_id = ctx.get_user_id()
logger = ctx.get_logger()
# request organization_id is used for filtering agents instead of header Organization-Id if enabled
if organization_id:
ctx.without_global_filtering()

try:
agents = agent_crud.get_agents(
Expand Down
28 changes: 25 additions & 3 deletions src/backend/routers/organization.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
UpdateOrganization,
)
from backend.schemas.context import Context
from backend.schemas.user import User
from backend.services.context import get_context
from backend.services.request_validators import validate_organization_request

Expand Down Expand Up @@ -88,11 +89,11 @@ def get_organization(
session (DBSessionDep): Database session.
Returns:
ManagedTool: Tool with the given ID.
ManagedTool: Organization with the given ID.
"""
organization = organization_crud.get_organization(session, organization_id)
if not organization:
raise HTTPException(status_code=404, detail="Model not found")
raise HTTPException(status_code=404, detail="Organization not found")
return organization


Expand All @@ -114,7 +115,7 @@ def delete_organization(
"""
organization = organization_crud.get_organization(session, organization_id)
if not organization:
raise HTTPException(status_code=404, detail="Tool not found")
raise HTTPException(status_code=404, detail="Organization not found")
organization_crud.delete_organization(session, organization_id)

return DeleteOrganization()
Expand All @@ -138,3 +139,24 @@ def list_organizations(
"""
all_organizations = organization_crud.get_organizations(session)
return all_organizations


@router.get("/{organization_id}/users", response_model=list[User])
def get_organization_users(
organization_id: str, session: DBSessionDep, ctx: Context = Depends(get_context)
) -> list[User]:
"""
Get organization users by ID.
Args:
organization_id (str): Organization ID.
session (DBSessionDep): Database session.
Returns:
list[User]: List of users in the organization
"""
organization = organization_crud.get_organization(session, organization_id)
if not organization:
raise HTTPException(status_code=404, detail="Organization not found")

return organization.users
Loading

0 comments on commit 4aea9fe

Please sign in to comment.