diff --git a/src/backend/alembic/versions/2024_08_19_c301506b3676_.py b/src/backend/alembic/versions/2024_08_19_c301506b3676_.py index afb1a1b1e7..aaa7fc9513 100644 --- a/src/backend/alembic/versions/2024_08_19_c301506b3676_.py +++ b/src/backend/alembic/versions/2024_08_19_c301506b3676_.py @@ -5,52 +5,55 @@ Create Date: 2024-08-19 12:36:34.118536 """ + from typing import Sequence, Union import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. -revision: str = 'c301506b3676' -down_revision: Union[str, None] = 'a76ebb869eb8' +revision: str = "c301506b3676" +down_revision: Union[str, None] = "a76ebb869eb8" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.create_table('groups', - sa.Column('display_name', sa.String(), nullable=False), - sa.Column('id', sa.String(), nullable=False), - sa.Column('created_at', sa.DateTime(), nullable=True), - sa.Column('updated_at', sa.DateTime(), nullable=True), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('display_name', name='unique_display_name') + op.create_table( + "groups", + sa.Column("display_name", sa.String(), nullable=False), + sa.Column("id", sa.String(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=True), + sa.Column("updated_at", sa.DateTime(), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("display_name", name="unique_display_name"), ) - op.create_table('user_group', - sa.Column('user_id', sa.String(), nullable=False), - sa.Column('group_id', sa.String(), nullable=False), - sa.Column('display', sa.String(), nullable=False), - sa.Column('id', sa.String(), nullable=False), - sa.Column('created_at', sa.DateTime(), nullable=True), - sa.Column('updated_at', sa.DateTime(), nullable=True), - sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ondelete='CASCADE'), - sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'), - sa.PrimaryKeyConstraint('user_id', 'group_id', 'id') + op.create_table( + "user_group", + sa.Column("user_id", sa.String(), nullable=False), + sa.Column("group_id", sa.String(), nullable=False), + sa.Column("display", sa.String(), nullable=False), + sa.Column("id", sa.String(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=True), + sa.Column("updated_at", sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(["group_id"], ["groups.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("user_id", "group_id", "id"), ) - op.add_column('users', sa.Column('user_name', sa.String(), nullable=True)) - op.add_column('users', sa.Column('external_id', sa.String(), nullable=True)) - op.add_column('users', sa.Column('active', sa.Boolean(), nullable=True)) - op.create_unique_constraint('unique_user_name', 'users', ['user_name']) + op.add_column("users", sa.Column("user_name", sa.String(), nullable=True)) + op.add_column("users", sa.Column("external_id", sa.String(), nullable=True)) + op.add_column("users", sa.Column("active", sa.Boolean(), nullable=True)) + op.create_unique_constraint("unique_user_name", "users", ["user_name"]) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_constraint('unique_user_name', 'users', type_='unique') - op.drop_column('users', 'active') - op.drop_column('users', 'external_id') - op.drop_column('users', 'user_name') - op.drop_table('user_group') - op.drop_table('groups') + op.drop_constraint("unique_user_name", "users", type_="unique") + op.drop_column("users", "active") + op.drop_column("users", "external_id") + op.drop_column("users", "user_name") + op.drop_table("user_group") + op.drop_table("groups") # ### end Alembic commands ### diff --git a/src/backend/config/routers.py b/src/backend/config/routers.py index 03c7c06fa5..79f988373a 100644 --- a/src/backend/config/routers.py +++ b/src/backend/config/routers.py @@ -9,6 +9,7 @@ ) from backend.services.request_validators import ( validate_chat_request, + validate_organization_header, validate_user_header, ) @@ -35,9 +36,11 @@ class RouterName(StrEnum): RouterName.AUTH: { "default": [ Depends(get_session), + Depends(validate_organization_header), ], "auth": [ Depends(get_session), + Depends(validate_organization_header), ], }, RouterName.CHAT: { @@ -45,94 +48,114 @@ class RouterName(StrEnum): Depends(get_session), Depends(validate_user_header), Depends(validate_chat_request), + Depends(validate_organization_header), ], "auth": [ Depends(get_session), Depends(validate_chat_request), Depends(validate_authorization), + Depends(validate_organization_header), ], }, RouterName.CONVERSATION: { "default": [ Depends(get_session), Depends(validate_user_header), + Depends(validate_organization_header), ], "auth": [ Depends(get_session), Depends(validate_authorization), + Depends(validate_organization_header), ], }, RouterName.DEPLOYMENT: { "default": [ Depends(get_session), + Depends(validate_organization_header), ], "auth": [ Depends(get_session), Depends(validate_authorization), + Depends(validate_organization_header), ], }, RouterName.EXPERIMENTAL_FEATURES: { "default": [ Depends(get_session), + Depends(validate_organization_header), ], "auth": [ Depends(get_session), Depends(validate_authorization), + Depends(validate_organization_header), ], }, RouterName.TOOL: { "default": [ Depends(get_session), + Depends(validate_organization_header), ], "auth": [ Depends(get_session), Depends(validate_authorization), + Depends(validate_organization_header), ], }, RouterName.USER: { "default": [ Depends(get_session), + Depends(validate_organization_header), ], "auth": [ # TODO: Remove auth only for create user endpoint Depends(get_session), + Depends(validate_organization_header), ], }, RouterName.AGENT: { "default": [ Depends(get_session), + Depends(validate_organization_header), ], "auth": [ Depends(get_session), Depends(validate_authorization), + Depends(validate_organization_header), ], }, RouterName.DEFAULT_AGENT: { "default": [ Depends(get_session), + Depends(validate_organization_header), ], "auth": [ Depends(get_session), Depends(validate_authorization), + Depends(validate_organization_header), ], }, RouterName.SNAPSHOT: { "default": [ Depends(get_session), Depends(validate_user_header), + Depends(validate_organization_header), ], "auth": [ Depends(get_session), Depends(validate_authorization), + Depends(validate_organization_header), ], }, RouterName.MODEL: { "default": [ Depends(get_session), + Depends(validate_organization_header), ], "auth": [ Depends(get_session), Depends(validate_authorization), + Depends(validate_organization_header), ], }, RouterName.SCIM: { diff --git a/src/backend/crud/organization.py b/src/backend/crud/organization.py index 3403c88d89..b974f6780b 100644 --- a/src/backend/crud/organization.py +++ b/src/backend/crud/organization.py @@ -1,6 +1,6 @@ from sqlalchemy.orm import Session -from backend.database_models import Agent +from backend.database_models.agent import Agent from backend.database_models.organization import Organization from backend.database_models.user import User, UserOrganizationAssociation from backend.schemas.organization import UpdateOrganization diff --git a/src/backend/database_models/base.py b/src/backend/database_models/base.py index 32223851d9..3245f14da1 100644 --- a/src/backend/database_models/base.py +++ b/src/backend/database_models/base.py @@ -1,7 +1,45 @@ +from enum import StrEnum from uuid import uuid4 from sqlalchemy import DateTime, String, func -from sqlalchemy.orm import DeclarativeBase, mapped_column +from sqlalchemy.orm import DeclarativeBase, Query, mapped_column + + +class FilterFields(StrEnum): + ORGANIZATION_ID = "organization_id" + + +class CustomFilterQuery(Query): + """ + Custom query class that filters by field if the entity has field + and the filter value is set. + """ + + ALLOWED_FILTER_FIELDS = [FilterFields.ORGANIZATION_ID] + + def __new__(cls, *args, **kwargs): + from backend.services.context import GLOBAL_REQUEST_CONTEXT + + request_ctx = GLOBAL_REQUEST_CONTEXT.get() + if request_ctx and request_ctx.use_global_filtering: + query = None + for field in cls.ALLOWED_FILTER_FIELDS: + if ( + args + and hasattr(args[0][0], field) + and hasattr(request_ctx, field) + and getattr(request_ctx, field) + ): + if query: + query = query.filter_by(**{field: getattr(request_ctx, field)}) + else: + query = Query(*args, **kwargs).filter_by( + **{field: getattr(request_ctx, field)} + ) + if query: + return query + + return object.__new__(cls) class Base(DeclarativeBase): diff --git a/src/backend/database_models/database.py b/src/backend/database_models/database.py index 93ef2122f8..7f6ec597f0 100644 --- a/src/backend/database_models/database.py +++ b/src/backend/database_models/database.py @@ -6,6 +6,7 @@ from sqlalchemy.orm import Session from backend.config.settings import Settings +from backend.database_models.base import CustomFilterQuery load_dotenv() @@ -16,7 +17,7 @@ def get_session() -> Generator[Session, Any, None]: - with Session(engine) as session: + with Session(engine, query_cls=CustomFilterQuery) as session: yield session diff --git a/src/backend/routers/agent.py b/src/backend/routers/agent.py index 8b57be96ff..97080077cc 100644 --- a/src/backend/routers/agent.py +++ b/src/backend/routers/agent.py @@ -193,6 +193,9 @@ async def list_agents( # TODO: get organization_id from user user_id = ctx.get_user_id() logger = ctx.get_logger() + # request organization_id is used for filtering agents instead of header Organization-Id if enabled + if organization_id: + ctx.without_global_filtering() try: agents = agent_crud.get_agents( diff --git a/src/backend/routers/organization.py b/src/backend/routers/organization.py index 274e992447..f1a14c2512 100644 --- a/src/backend/routers/organization.py +++ b/src/backend/routers/organization.py @@ -11,6 +11,7 @@ UpdateOrganization, ) from backend.schemas.context import Context +from backend.schemas.user import User from backend.services.context import get_context from backend.services.request_validators import validate_organization_request @@ -88,11 +89,11 @@ def get_organization( session (DBSessionDep): Database session. Returns: - ManagedTool: Tool with the given ID. + ManagedTool: Organization with the given ID. """ organization = organization_crud.get_organization(session, organization_id) if not organization: - raise HTTPException(status_code=404, detail="Model not found") + raise HTTPException(status_code=404, detail="Organization not found") return organization @@ -114,7 +115,7 @@ def delete_organization( """ organization = organization_crud.get_organization(session, organization_id) if not organization: - raise HTTPException(status_code=404, detail="Tool not found") + raise HTTPException(status_code=404, detail="Organization not found") organization_crud.delete_organization(session, organization_id) return DeleteOrganization() @@ -138,3 +139,24 @@ def list_organizations( """ all_organizations = organization_crud.get_organizations(session) return all_organizations + + +@router.get("/{organization_id}/users", response_model=list[User]) +def get_organization_users( + organization_id: str, session: DBSessionDep, ctx: Context = Depends(get_context) +) -> list[User]: + """ + Get organization users by ID. + + Args: + organization_id (str): Organization ID. + session (DBSessionDep): Database session. + + Returns: + list[User]: List of users in the organization + """ + organization = organization_crud.get_organization(session, organization_id) + if not organization: + raise HTTPException(status_code=404, detail="Organization not found") + + return organization.users diff --git a/src/backend/schemas/context.py b/src/backend/schemas/context.py index 8a7e0c9c30..eefde4142a 100644 --- a/src/backend/schemas/context.py +++ b/src/backend/schemas/context.py @@ -2,8 +2,10 @@ from pydantic import BaseModel +from backend.crud import organization as organization_crud from backend.crud import user as user_crud from backend.database_models.database import DBSessionDep +from backend.schemas import Organization from backend.schemas.agent import Agent, AgentToolMetadata from backend.schemas.metrics import MetricsAgent, MetricsMessageType, MetricsUser from backend.schemas.user import User @@ -28,6 +30,9 @@ class Context(BaseModel): agent_id: Optional[str] = None stream_start_ms: Optional[float] = None logger: Optional[Any] = None + organization_id: Optional[str] = None + organization: Optional[Organization] = None + use_global_filtering: Optional[bool] = False # Metrics metrics_user: Optional[MetricsUser] = None @@ -125,6 +130,42 @@ def with_agent_id(self, agent_id: str) -> "Context": self.agent_id = agent_id return self + def with_organization_id(self, organization_id: str) -> "Context": + self.organization_id = organization_id + return self + + def with_organization( + self, + session: DBSessionDep | None = None, + organization: Organization | None = None, + ) -> "Context": + if not organization and not session: + return self + + if not organization: + organization = organization_crud.get_organization( + session, self.organization_id + ) + organization = ( + Organization.model_validate(organization) if organization else None + ) + + if organization: + self.organization = organization + + return self + + def with_global_filtering(self) -> "Context": + self.use_global_filtering = True + return self + + def without_global_filtering(self) -> "Context": + self.use_global_filtering = False + return self + + def get_organization(self): + return self.organization + def get_stream_start_ms(self): return self.stream_start_ms diff --git a/src/backend/services/context.py b/src/backend/services/context.py index 48bcf99c78..75759f4f9f 100644 --- a/src/backend/services/context.py +++ b/src/backend/services/context.py @@ -1,3 +1,4 @@ +import contextvars import uuid from fastapi import Request @@ -5,6 +6,8 @@ from backend.schemas.context import Context +GLOBAL_REQUEST_CONTEXT = contextvars.ContextVar("GLOBAL_REQUEST_CONTEXT", default=None) + class ContextMiddleware: def __init__(self, app: ASGIApp): @@ -41,13 +44,23 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send): agent_id = request.headers.get("Agent-Id") context.with_agent_id(agent_id) + organization_id = request.headers.get("Organization-Id", None) + context.with_organization_id(organization_id) + + context.without_global_filtering() + if organization_id: + context.with_global_filtering() + context.with_logger() # Set the context on the scope scope["context"] = context + GLOBAL_REQUEST_CONTEXT.set(context) await self.app(scope, receive, send) + # Clear the organization ID from the global context + GLOBAL_REQUEST_CONTEXT.set(None) # Clear the context after the request is complete del scope["context"] diff --git a/src/backend/services/request_validators.py b/src/backend/services/request_validators.py index 754e56c54c..c55c27b114 100644 --- a/src/backend/services/request_validators.py +++ b/src/backend/services/request_validators.py @@ -101,6 +101,26 @@ def validate_user_header(session: DBSessionDep, request: Request): raise HTTPException(status_code=401, detail="User not found.") +def validate_organization_header(session: DBSessionDep, request: Request): + """ + Validate that the request has the `Organization-Id` header, used for requests + that require an Organization. + + Args: + request (Request): The request to validate + + Raises: + HTTPException: If no `Organization-Id` header. + + """ + + organization_id = request.headers.get("Organization-Id", None) + if organization_id: + organization = organization_crud.get_organization(session, organization_id) + if not organization: + raise HTTPException(status_code=404, detail=f"Organization ID {organization_id} not found.") + + def validate_deployment_header(request: Request, session: DBSessionDep): """ Validate that the request has the `Deployment-Name` header, used for chat requests diff --git a/src/backend/tests/unit/conftest.py b/src/backend/tests/unit/conftest.py index 7486e7cd67..218fb19777 100644 --- a/src/backend/tests/unit/conftest.py +++ b/src/backend/tests/unit/conftest.py @@ -11,6 +11,7 @@ from backend.config.deployments import AVAILABLE_MODEL_DEPLOYMENTS, ModelDeploymentName from backend.database_models import get_session +from backend.database_models.base import CustomFilterQuery from backend.main import app, create_app from backend.schemas.deployment import Deployment from backend.schemas.organization import Organization @@ -47,7 +48,7 @@ def session(engine: Any) -> Generator[Session, None, None]: # Begin the nested transaction transaction = connection.begin() # Use connection within the started transaction - session = Session(bind=connection) + session = Session(bind=connection, query_cls=CustomFilterQuery) # Run Alembic migrations alembic_cfg = Config("src/backend/alembic.ini") upgrade(alembic_cfg, "head") @@ -107,7 +108,7 @@ def session_chat(engine_chat: Any) -> Generator[Session, None, None]: # Begin the nested transaction transaction = connection.begin() # Use connection within the started transaction - session = Session(bind=connection) + session = Session(bind=connection, query_cls=CustomFilterQuery) # Run Alembic migrations alembic_cfg = Config("src/backend/alembic.ini") upgrade(alembic_cfg, "head") diff --git a/src/backend/tests/unit/factories/agent.py b/src/backend/tests/unit/factories/agent.py index 022105dfa7..7d2ae9c277 100644 --- a/src/backend/tests/unit/factories/agent.py +++ b/src/backend/tests/unit/factories/agent.py @@ -12,6 +12,7 @@ class Meta: user = factory.SubFactory(UserFactory) user_id = factory.SelfAttribute("user.id") + organization_id = None name = factory.Faker("sentence") description = factory.Faker("sentence") preamble = factory.Faker("sentence") diff --git a/src/backend/tests/unit/factories/conversation.py b/src/backend/tests/unit/factories/conversation.py index 0471f2c8c4..6f064426a8 100644 --- a/src/backend/tests/unit/factories/conversation.py +++ b/src/backend/tests/unit/factories/conversation.py @@ -16,6 +16,7 @@ class Meta: updated_at = factory.Faker("date_time") text_messages = [] agent_id = None + organization_id = None class ConversationFileAssociationFactory(BaseFactory): diff --git a/src/backend/tests/unit/routers/test_agent.py b/src/backend/tests/unit/routers/test_agent.py index ad2773662d..86652eb9b4 100644 --- a/src/backend/tests/unit/routers/test_agent.py +++ b/src/backend/tests/unit/routers/test_agent.py @@ -181,6 +181,75 @@ def test_list_agents(session_client: TestClient, session: Session, user) -> None assert len(response_agents) == 3 +def test_list_organization_agents( + session_client: TestClient, + session: Session, + user, +) -> None: + session.query(Agent).delete() + organization = get_factory("Organization", session).create() + organization1 = get_factory("Organization", session).create() + for i in range(3): + _ = get_factory("Agent", session).create( + user=user, + organization_id=organization.id, + name=f"agent-{i}-{organization.id}", + ) + _ = get_factory("Agent", session).create( + user=user, organization_id=organization1.id + ) + + response = session_client.get( + "/v1/agents", headers={"User-Id": user.id, "Organization-Id": organization.id} + ) + assert response.status_code == 200 + response_agents = response.json() + agents = sorted(response_agents, key=lambda x: x["name"]) + for i in range(3): + assert agents[i]["name"] == f"agent-{i}-{organization.id}" + + +def test_list_organization_agents_query_param( + session_client: TestClient, + session: Session, + user, +) -> None: + session.query(Agent).delete() + organization = get_factory("Organization", session).create() + organization1 = get_factory("Organization", session).create() + for i in range(3): + _ = get_factory("Agent", session).create( + user=user, organization_id=organization.id + ) + _ = get_factory("Agent", session).create( + user=user, + organization_id=organization1.id, + name=f"agent-{i}-{organization1.id}", + ) + + response = session_client.get( + f"/v1/agents?organization_id={organization1.id}", + headers={"User-Id": user.id, "Organization-Id": organization.id}, + ) + assert response.status_code == 200 + response_agents = response.json() + agents = sorted(response_agents, key=lambda x: x["name"]) + for i in range(3): + assert agents[i]["name"] == f"agent-{i}-{organization1.id}" + + +def test_list_organization_agents_nonexistent_organization( + session_client: TestClient, + session: Session, + user, +) -> None: + response = session_client.get( + "/v1/agents", headers={"User-Id": user.id, "Organization-Id": "123"} + ) + assert response.status_code == 404 + assert response.json() == {"detail": "Organization ID 123 not found."} + + def test_list_private_agents( session_client: TestClient, session: Session, user ) -> None: diff --git a/src/backend/tests/unit/routers/test_chat.py b/src/backend/tests/unit/routers/test_chat.py index 6b74fefe54..fdcead9310 100644 --- a/src/backend/tests/unit/routers/test_chat.py +++ b/src/backend/tests/unit/routers/test_chat.py @@ -136,7 +136,7 @@ def test_streaming_new_chat_metrics_with_agent( def test_streaming_new_chat_with_agent( session_client_chat: TestClient, session_chat: Session, user: User ): - agent = get_factory("Agent", session_chat).create(user=user) + agent = get_factory("Agent", session_chat).create(user=user, tools=[]) deployment = get_factory("Deployment", session_chat).create() model = get_factory("Model", session_chat).create(deployment=deployment) get_factory("AgentDeploymentModel", session_chat).create( @@ -153,7 +153,7 @@ def test_streaming_new_chat_with_agent( "Deployment-Name": agent.deployment, }, params={"agent_id": agent.id}, - json={"message": "Hello", "max_tokens": 10}, + json={"message": "Hello", "max_tokens": 10, "agent_id": agent.id}, ) assert response.status_code == 200 validate_chat_streaming_response( @@ -165,7 +165,7 @@ def test_streaming_new_chat_with_agent( def test_streaming_new_chat_with_agent_existing_conversation( session_client_chat: TestClient, session_chat: Session, user: User ): - agent = get_factory("Agent", session_chat).create(user=user) + agent = get_factory("Agent", session_chat).create(user=user, tools=[]) deployment = get_factory("Deployment", session_chat).create() model = get_factory("Model", session_chat).create(deployment=deployment) get_factory("AgentDeploymentModel", session_chat).create( @@ -208,7 +208,7 @@ def test_streaming_new_chat_with_agent_existing_conversation( "Deployment-Name": agent.deployment, }, params={"agent_id": agent.id}, - json={"message": "Hello", "max_tokens": 10, "conversation_id": conversation.id}, + json={"message": "Hello", "max_tokens": 10, "conversation_id": conversation.id, "agent_id": agent.id}, ) assert response.status_code == 200 @@ -253,7 +253,7 @@ def test_streaming_chat_with_existing_conversation_from_other_agent( "Deployment-Name": ModelDeploymentName.CoherePlatform, }, params={"agent_id": agent.id}, - json={"message": "Hello", "max_tokens": 10, "conversation_id": conversation.id}, + json={"message": "Hello", "max_tokens": 10, "conversation_id": conversation.id, "agent_id": agent.id}, ) assert response.status_code == 404 diff --git a/src/backend/tests/unit/routers/test_conversation.py b/src/backend/tests/unit/routers/test_conversation.py index 692fba0367..359d01ad41 100644 --- a/src/backend/tests/unit/routers/test_conversation.py +++ b/src/backend/tests/unit/routers/test_conversation.py @@ -161,6 +161,35 @@ def test_get_conversation_lists_message_files( assert response_conversation["messages"][0]["files"][0]["id"] == file.id +def test_get_organization_conversation_list( + session_client: TestClient, + session: Session, + user: User, +) -> None: + organization = get_factory("Organization", session).create(name="test org") + organization1 = get_factory("Organization", session).create(name="test org1") + for i in range(3): + get_factory("Conversation", session).create( + user_id=user.id, + organization_id=organization.id, + description=organization.id, + ) + get_factory("Conversation", session).create( + user_id=user.id, organization_id=organization1.id + ) + + response = session_client.get( + "/v1/conversations", + headers={"User-Id": user.id, "Organization-Id": organization.id}, + ) + response_conversation = response.json() + + assert response.status_code == 200 + assert len(response_conversation) == 3 + for conversation in response_conversation: + assert conversation["description"] == organization.id + + def test_fail_get_nonexistent_conversation( session_client: TestClient, session: Session, diff --git a/src/backend/tests/unit/routers/test_organization.py b/src/backend/tests/unit/routers/test_organization.py index 70fa4d310f..658321c992 100644 --- a/src/backend/tests/unit/routers/test_organization.py +++ b/src/backend/tests/unit/routers/test_organization.py @@ -67,3 +67,18 @@ def test_delete_organization(session_client: TestClient, session: Session) -> No response = session_client.delete(f"/v1/organizations/{organization.id}") assert response.status_code == 200 assert response.json() == {} + + +def test_list_organization_users(session_client: TestClient, session: Session) -> None: + organization = get_factory("Organization", session).create(name="test organization") + for i in range(5): + user = get_factory("User", session).create(fullname=f"test user {i}") + organization.users.append(user) + + response = session_client.get(f"/v1/organizations/{organization.id}/users") + results = response.json() + assert response.status_code == 200 + assert len(results) == 5 + results = sorted(results, key=lambda x: x["fullname"]) + for i, result in enumerate(results): + assert result["fullname"] == f"test user {i}" diff --git a/src/backend/tools/files.py b/src/backend/tools/files.py index c153bdbc00..86f181ce91 100644 --- a/src/backend/tools/files.py +++ b/src/backend/tools/files.py @@ -126,7 +126,7 @@ async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: { "text": retrieved_file.file_content, "title": retrieved_file.file_name, - "url": retrieved_file.file_path, + "url": retrieved_file.file_name, } ] @@ -181,7 +181,7 @@ async def call( { "text": file.file_content, "title": file.file_name, - "url": file.file_path, + "url": file.file_name, } ) return results diff --git a/src/backend/tools/google_drive/tool.py b/src/backend/tools/google_drive/tool.py index 32e3de5372..74212bd208 100644 --- a/src/backend/tools/google_drive/tool.py +++ b/src/backend/tools/google_drive/tool.py @@ -74,7 +74,9 @@ async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: }, ) if documents.error: - raise Exception(f"Error getting documents for {query} with {documents.error}") + raise Exception( + f"Error getting documents for {query} with {documents.error}" + ) hits = documents.result["hits"] chunks = sorted( diff --git a/src/interfaces/assistants_web/src/app/(main)/(chat)/Chat.tsx b/src/interfaces/assistants_web/src/app/(main)/(chat)/Chat.tsx index 32a350336e..68fc43bbee 100644 --- a/src/interfaces/assistants_web/src/app/(main)/(chat)/Chat.tsx +++ b/src/interfaces/assistants_web/src/app/(main)/(chat)/Chat.tsx @@ -5,7 +5,7 @@ import { useEffect } from 'react'; import { Document, ManagedTool } from '@/cohere-client'; import { Conversation, ConversationError } from '@/components/Conversation'; import { TOOL_PYTHON_INTERPRETER_ID } from '@/constants'; -import { useAgent, useConversation, useListTools } from '@/hooks'; +import { useAgent, useAvailableTools, useConversation, useListTools } from '@/hooks'; import { useCitationsStore, useConversationStore, useParamsStore } from '@/stores'; import { OutputFiles } from '@/stores/slices/citationsSlice'; import { @@ -24,6 +24,7 @@ const Chat: React.FC<{ agentId?: string; conversationId?: string }> = ({ const { setConversation } = useConversationStore(); const { addCitation, saveOutputFiles } = useCitationsStore(); const { setParams, resetFileParams } = useParamsStore(); + const { availableTools } = useAvailableTools({ agent, managedTools: tools }); const { data: conversation, @@ -39,9 +40,11 @@ const Chat: React.FC<{ agentId?: string; conversationId?: string }> = ({ const agentTools = agent?.tools && - ((agent.tools + (agent.tools .map((name) => (tools ?? [])?.find((t) => t.name === name)) - .filter((t) => t !== undefined) ?? []) as ManagedTool[]); + .filter( + (t) => t !== undefined && availableTools.some((at) => at.name === t?.name) + ) as ManagedTool[]); const fileIds = conversation?.files.map((file) => file.id); @@ -53,7 +56,16 @@ const Chat: React.FC<{ agentId?: string; conversationId?: string }> = ({ if (conversationId) { setConversation({ id: conversationId }); } - }, [agent, tools, conversation]); + }, [ + agent, + tools, + conversation, + availableTools, + setParams, + resetFileParams, + setConversation, + conversationId, + ]); useEffect(() => { if (!conversation) return; diff --git a/src/interfaces/assistants_web/src/app/(main)/edit/[agentId]/page.tsx b/src/interfaces/assistants_web/src/app/(main)/edit/[agentId]/page.tsx index a41be96c39..d8445addfe 100644 --- a/src/interfaces/assistants_web/src/app/(main)/edit/[agentId]/page.tsx +++ b/src/interfaces/assistants_web/src/app/(main)/edit/[agentId]/page.tsx @@ -1,8 +1,9 @@ import { NextPage } from 'next'; -import { UpdateAgent } from './UpdateAgent'; import { getCohereServerClient } from '@/server/cohereServerClient'; +import { UpdateAgent } from './UpdateAgent'; + type Props = { params: { agentId: string; diff --git a/src/interfaces/assistants_web/src/components/AgentSettingsForm/ToolsStep.tsx b/src/interfaces/assistants_web/src/components/AgentSettingsForm/ToolsStep.tsx index 7033aa1df3..1f65fd1c0a 100644 --- a/src/interfaces/assistants_web/src/components/AgentSettingsForm/ToolsStep.tsx +++ b/src/interfaces/assistants_web/src/components/AgentSettingsForm/ToolsStep.tsx @@ -57,7 +57,7 @@ const ToolRow: React.FC<{ handleSwitch: (checked: boolean) => void; }> = ({ name, description, icon, checked, handleSwitch }) => { return ( -