diff --git a/CHANGELOG.md b/CHANGELOG.md index 9a6491a0bc..22e1f34e7c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,8 @@ These are the section headers that we use: - Improved efficiency of weak labeling when dataset contains vectors ([#3444](https://github.com/argilla-io/argilla/pull/3444)). - Added `ArgillaDatasetMixin` to detach the Argilla-related functionality from the `FeedbackDataset` ([#3427](https://github.com/argilla-io/argilla/pull/3427)) - Moved `FeedbackDataset`-related `pydantic.BaseModel` schemas to `argilla.client.feedback.schemas` instead, to be better structured and more scalable and maintainable ([#3427](https://github.com/argilla-io/argilla/pull/3427)) +- Update CLI to use database async connection ([#3450](https://github.com/argilla-io/argilla/pull/3450)). +- Update alembic code to apply migrations to use database async engine ([#3450](https://github.com/argilla-io/argilla/pull/3450)). - Limit rating questions values to the positive range [1, 10] (Closes [#3451](https://github.com/argilla-io/argilla/issues/3451)). ## [1.13.2](https://github.com/argilla-io/argilla/compare/v1.13.1...v1.13.2) diff --git a/src/argilla/__main__.py b/src/argilla/__main__.py index 56e68f12bd..d6bd2e0b42 100644 --- a/src/argilla/__main__.py +++ b/src/argilla/__main__.py @@ -14,11 +14,10 @@ # limitations under the License. -import typer +from argilla.tasks import database_app, server_app, training_app, users_app +from argilla.tasks.async_typer import AsyncTyper -from .tasks import database_app, server_app, training_app, users_app - -app = typer.Typer(rich_help_panel=True, help="Argilla CLI", no_args_is_help=True) +app = AsyncTyper(rich_help_panel=True, help="Argilla CLI", no_args_is_help=True) app.add_typer(users_app, name="users") app.add_typer(database_app, name="database") diff --git a/src/argilla/server/alembic/env.py b/src/argilla/server/alembic/env.py index b0c317cdfd..649357c35d 100644 --- a/src/argilla/server/alembic/env.py +++ b/src/argilla/server/alembic/env.py @@ -12,13 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio from logging.config import fileConfig +from typing import TYPE_CHECKING from alembic import context from argilla.server.models.base import DatabaseModel from argilla.server.models.models import * # noqa from argilla.server.settings import settings -from sqlalchemy import engine_from_config, pool +from sqlalchemy import pool +from sqlalchemy.ext.asyncio import async_engine_from_config + +if TYPE_CHECKING: + from sqlalchemy import Connection # this is the Alembic Config object, which provides # access to the values within the .ini file in use. @@ -68,27 +74,31 @@ def run_migrations_offline() -> None: context.run_migrations() -def run_migrations_online() -> None: +def apply_migrations(connection: "Connection") -> None: + context.configure(connection=connection, target_metadata=target_metadata) + + with context.begin_transaction(): + context.run_migrations() + + +async def run_migrations_online() -> None: """Run migrations in 'online' mode. In this scenario we need to create an Engine and associate a connection with the context. """ - connectable = engine_from_config( + connectable = async_engine_from_config( config.get_section(config.config_ini_section), prefix="sqlalchemy.", poolclass=pool.NullPool, ) - with connectable.connect() as connection: - context.configure(connection=connection, target_metadata=target_metadata) - - with context.begin_transaction(): - context.run_migrations() + async with connectable.connect() as connection: + await connection.run_sync(apply_migrations) if context.is_offline_mode(): run_migrations_offline() else: - run_migrations_online() + asyncio.run(run_migrations_online()) diff --git a/src/argilla/server/apis/v1/handlers/responses.py b/src/argilla/server/apis/v1/handlers/responses.py index cce444a71c..588c67cff7 100644 --- a/src/argilla/server/apis/v1/handlers/responses.py +++ b/src/argilla/server/apis/v1/handlers/responses.py @@ -16,10 +16,9 @@ from fastapi import APIRouter, Depends, HTTPException, Security, status from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import Session from argilla.server.contexts import datasets -from argilla.server.database import get_async_db, get_db +from argilla.server.database import get_async_db from argilla.server.models import User from argilla.server.policies import ResponsePolicyV1, authorize from argilla.server.schemas.v1.responses import Response, ResponseUpdate diff --git a/src/argilla/server/database.py b/src/argilla/server/database.py index 5efd9d5273..3031736b66 100644 --- a/src/argilla/server/database.py +++ b/src/argilla/server/database.py @@ -16,16 +16,16 @@ from sqlite3 import Connection as SQLite3Connection from typing import TYPE_CHECKING, Generator -from sqlalchemy import create_engine, event +from sqlalchemy import event from sqlalchemy.engine import Engine -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine -from sqlalchemy.orm import sessionmaker +from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine import argilla from argilla.server.settings import settings if TYPE_CHECKING: - from sqlalchemy.orm import Session + from sqlalchemy.ext.asyncio import AsyncSession + ALEMBIC_CONFIG_FILE = os.path.normpath(os.path.join(os.path.dirname(argilla.__file__), "alembic.ini")) TAGGED_REVISIONS = OrderedDict( @@ -46,21 +46,10 @@ def set_sqlite_pragma(dbapi_connection, connection_record): cursor.close() -engine = create_engine(settings.database_url) -SessionLocal = sessionmaker(autocommit=False, bind=engine) - -async_engine = create_async_engine(settings.database_url_async) +async_engine = create_async_engine(settings.database_url) AsyncSessionLocal = async_sessionmaker(autocommit=False, expire_on_commit=False, bind=async_engine) -def get_db() -> Generator["Session", None, None]: - try: - db = SessionLocal() - yield db - finally: - db.close() - - async def get_async_db() -> Generator["AsyncSession", None, None]: try: db: "AsyncSession" = AsyncSessionLocal() diff --git a/src/argilla/server/seeds.py b/src/argilla/server/seeds.py deleted file mode 100644 index e620751154..0000000000 --- a/src/argilla/server/seeds.py +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from argilla._constants import DEFAULT_API_KEY -from argilla.server.database import SessionLocal -from argilla.server.models import User, UserRole, Workspace - - -def development_seeds(): - with SessionLocal() as session, session.begin(): - session.add_all( - [ - Workspace(name="workspace-1"), - Workspace(name="workspace-2"), - User( - first_name="John", - last_name="Doe", - username="argilla", - role=UserRole.owner, - password_hash="$2y$05$eaw.j2Kaw8s8vpscVIZMfuqSIX3OLmxA21WjtWicDdn0losQ91Hw.", - api_key="1234", - ), - User( - first_name="Tanya", - last_name="Franklin", - username="tanya", - password_hash="$2y$05$eaw.j2Kaw8s8vpscVIZMfuqSIX3OLmxA21WjtWicDdn0losQ91Hw.", - api_key="123456", - ), - ] - ) - - -def test_seeds(db: SessionLocal): - db.add_all( - [ - Workspace( - name="argilla", - users=[ - User( - first_name="Argilla", - username="argilla", - role=UserRole.owner, - password_hash="$2y$05$eaw.j2Kaw8s8vpscVIZMfuqSIX3OLmxA21WjtWicDdn0losQ91Hw.", - api_key=DEFAULT_API_KEY, - ), - ], - ) - ] - ) - db.commit() diff --git a/src/argilla/server/settings.py b/src/argilla/server/settings.py index 4f7b208c3c..82ad432f15 100644 --- a/src/argilla/server/settings.py +++ b/src/argilla/server/settings.py @@ -18,6 +18,8 @@ """ import logging import os +import re +import warnings from pathlib import Path from typing import List, Optional from urllib.parse import urlparse @@ -125,9 +127,34 @@ def normalize_base_url(cls, base_url: str): return base_url - @validator("database_url", always=True) - def set_database_url_default(cls, database_url: str, values: dict) -> str: - return database_url or f"sqlite:///{os.path.join(values['home_path'], 'argilla.db')}?check_same_thread=False" + @validator("database_url", pre=True, always=True) + def set_database_url(cls, database_url: str, values: dict) -> str: + if not database_url: + home_path = values.get("home_path") + sqlite_file = os.path.join(home_path, "argilla.db") + return f"sqlite+aiosqlite:///{sqlite_file}?check_same_thread=False" + + if "sqlite" in database_url: + regex = re.compile(r"sqlite(?!\+aiosqlite)") + if regex.match(database_url): + warnings.warn( + "From version 1.14.0, Argilla will use `aiosqlite` as default SQLite driver. The protocol in the" + " provided database URL has been automatically replaced from `sqlite` to `sqlite+aiosqlite`." + " Please, update your database URL to use `sqlite+aiosqlite` protocol." + ) + return re.sub(regex, "sqlite+aiosqlite", database_url) + + if "postgresql" in database_url: + regex = re.compile(r"postgresql(?!\+asyncpg)(\+psycopg2)?") + if regex.match(database_url): + warnings.warn( + "From version 1.14.0, Argilla will use `asyncpg` as default PostgreSQL driver. The protocol in the" + " provided database URL has been automatically replaced from `postgresql` to `postgresql+asyncpg`." + " Please, update your database URL to use `postgresql+asyncpg` protocol." + ) + return re.sub(regex, "postgresql+asyncpg", database_url) + + return database_url @root_validator(skip_on_failure=True) def create_home_path(cls, values): @@ -165,19 +192,6 @@ def old_dataset_records_index_name(self) -> str: return index_name.replace("", "") return index_name.replace("", f".{ns}") - @property - def database_url_async(self) -> str: - if self.database_url.startswith("sqlite:///"): - return self.database_url.replace("sqlite:///", "sqlite+aiosqlite:///") - - if self.database_url.startswith("postgresql://"): - return self.database_url.replace("postgresql://", "postgresql+asyncpg://") - - if self.database_url.startswith("mysql://"): - return self.database_url.replace("mysql://", "mysql+aiomysql://") - - raise ValueError(f"Unsupported database url: '{self.database_url}'") - def obfuscated_elasticsearch(self) -> str: """Returns configured elasticsearch url obfuscating the provided password, if any""" parsed = urlparse(self.elasticsearch) diff --git a/src/argilla/tasks/async_typer.py b/src/argilla/tasks/async_typer.py new file mode 100644 index 0000000000..cac3855047 --- /dev/null +++ b/src/argilla/tasks/async_typer.py @@ -0,0 +1,57 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import sys +from functools import wraps +from typing import Any, Callable, Coroutine, TypeVar + +import typer + +if sys.version_info < (3, 10): + from typing_extensions import ParamSpec +else: + from typing import ParamSpec + + +P = ParamSpec("P") +R = TypeVar("R") + + +# https://github.com/tiangolo/typer/issues/88#issuecomment-1613013597 +class AsyncTyper(typer.Typer): + def command( + self, *args: Any, **kwargs: Any + ) -> Callable[[Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]]]: + super_command = super().command(*args, **kwargs) + + def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]: + @wraps(func) + def sync_func(*_args: P.args, **_kwargs: P.kwargs) -> R: + return asyncio.run(func(*_args, **_kwargs)) + + if asyncio.iscoroutinefunction(func): + super_command(sync_func) + else: + super_command(func) + + return func + + return decorator + + +def run(function: Callable[..., Coroutine[Any, Any, Any]]) -> None: + app = AsyncTyper(add_completion=False) + app.command()(function) + app() diff --git a/src/argilla/tasks/database/migrate.py b/src/argilla/tasks/database/migrate.py index 9218a51789..4d590d1d52 100644 --- a/src/argilla/tasks/database/migrate.py +++ b/src/argilla/tasks/database/migrate.py @@ -20,6 +20,7 @@ from alembic.util import CommandError from argilla.server.database import ALEMBIC_CONFIG_FILE, TAGGED_REVISIONS +from argilla.tasks import async_typer from argilla.tasks.database import utils @@ -47,4 +48,4 @@ def migrate_db(revision: Optional[str] = typer.Option(default="head", help="DB R if __name__ == "__main__": - typer.run(migrate_db) + async_typer.run(migrate_db) diff --git a/src/argilla/tasks/database/revisions.py b/src/argilla/tasks/database/revisions.py index 59309a38cf..9c5544547f 100644 --- a/src/argilla/tasks/database/revisions.py +++ b/src/argilla/tasks/database/revisions.py @@ -15,6 +15,7 @@ import typer from argilla.server.database import ALEMBIC_CONFIG_FILE, TAGGED_REVISIONS +from argilla.tasks import async_typer from argilla.tasks.database import utils @@ -39,4 +40,4 @@ def revisions(): if __name__ == "__main__": - typer.run(revisions) + async_typer.run(revisions) diff --git a/src/argilla/tasks/training/__main__.py b/src/argilla/tasks/training/__main__.py index 78f8ebc383..87ec93f1dd 100644 --- a/src/argilla/tasks/training/__main__.py +++ b/src/argilla/tasks/training/__main__.py @@ -45,7 +45,6 @@ def train( ): import json - import argilla as rg from argilla.client.api import init from argilla.training import ArgillaTrainer diff --git a/src/argilla/tasks/users/__main__.py b/src/argilla/tasks/users/__main__.py index 02b07b8370..8df083d91f 100644 --- a/src/argilla/tasks/users/__main__.py +++ b/src/argilla/tasks/users/__main__.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import typer +from argilla.tasks.async_typer import AsyncTyper from .create import create from .create_default import create_default from .migrate import migrate from .update import update -app = typer.Typer(help="Holds CLI commands for user and workspace management.", no_args_is_help=True) +app = AsyncTyper(help="Holds CLI commands for user and workspace management.", no_args_is_help=True) app.command(name="create_default", help="Creates default users and workspaces in the Argilla database.")(create_default) app.command(name="create", help="Creates a user and add it to the Argilla database.", no_args_is_help=True)(create) diff --git a/src/argilla/tasks/users/create.py b/src/argilla/tasks/users/create.py index f184b14c4e..f5da258bbd 100644 --- a/src/argilla/tasks/users/create.py +++ b/src/argilla/tasks/users/create.py @@ -16,13 +16,16 @@ import typer from pydantic import constr -from sqlalchemy.orm import Session -from typer import Typer from argilla.server.contexts import accounts -from argilla.server.database import SessionLocal +from argilla.server.database import AsyncSessionLocal from argilla.server.models import User, UserRole, Workspace -from argilla.server.security.model import USER_PASSWORD_MIN_LENGTH, UserCreate, WorkspaceCreate +from argilla.server.security.model import ( + USER_PASSWORD_MIN_LENGTH, + UserCreate, + WorkspaceCreate, +) +from argilla.tasks import async_typer from argilla.tasks.users.utils import get_or_new_workspace USER_API_KEY_MIN_LENGTH = 8 @@ -33,18 +36,14 @@ class UserCreateForTask(UserCreate): workspaces: Optional[List[WorkspaceCreate]] -def _get_or_new_workspace(session: Session, workspace_name: str): - return session.query(Workspace).filter_by(name=workspace_name).first() or Workspace(name=workspace_name) - - -def role_callback(value: str): +def role_callback(value: str) -> str: try: return UserRole(value).value except ValueError: raise typer.BadParameter("Only Camila is allowed") -def password_callback(password: str = None): +def password_callback(password: str = None) -> str: # if password is None: # raise typer.BadParameter("Password must be specified.") # if len(password) str: # if api_key and len(api_key) UserCreate: return UserCreate( first_name=user.get("full_name", ""), username=user["username"], @@ -73,35 +90,25 @@ def _build_user_create(self, user: dict): workspaces=[WorkspaceCreate(name=workspace_name) for workspace_name in self._user_workspace_names(user)], ) - def _build_user(self, session: Session, user_create: UserCreate): - return User( - first_name=user_create.first_name, - username=user_create.username, - role=user_create.role, - api_key=user_create.api_key, - password_hash=user_create.password_hash, - workspaces=[get_or_new_workspace(session, workspace.name) for workspace in user_create.workspaces], - ) - - def _user_role(self, user: dict): + def _user_role(self, user: dict) -> UserRole: if user.get("workspaces") is None: return UserRole.owner - else: - return UserRole.annotator - def _user_workspace_names(self, user: dict): + return UserRole.annotator + + def _user_workspace_names(self, user: dict) -> List[str]: workspace_names = [workspace_name for workspace_name in user.get("workspaces", [])] if user["username"] in workspace_names: return workspace_names - else: - return [user["username"]] + workspace_names + + return [user["username"]] + workspace_names -def migrate(): +async def migrate(): """Migrate users defined in YAML file to database.""" - UsersMigrator(settings.users_db_file).migrate() + await UsersMigrator(settings.users_db_file).migrate() if __name__ == "__main__": - typer.run(migrate()) + async_typer.run(migrate) diff --git a/src/argilla/tasks/users/update.py b/src/argilla/tasks/users/update.py index 2af934b62a..8882779ab4 100644 --- a/src/argilla/tasks/users/update.py +++ b/src/argilla/tasks/users/update.py @@ -15,11 +15,12 @@ import typer from argilla.server.contexts import accounts -from argilla.server.database import SessionLocal +from argilla.server.database import AsyncSessionLocal from argilla.server.models import UserRole +from argilla.tasks import async_typer -def update( +async def update( username: str = typer.Argument( default=None, help="Username as a lowercase string without spaces allowing letters, numbers, dashes and underscores.", @@ -31,8 +32,8 @@ def update( help="New role for the user.", ), ): - with SessionLocal() as session: - user = accounts.get_user_by_username_sync(session, username) + async with AsyncSessionLocal() as session: + user = await accounts.get_user_by_username(session, username) if not user: typer.echo(f"User with username {username!r} does not exists in database. Skipping...") @@ -43,10 +44,12 @@ def update( return old_role = user.role - user.role = role - session.add(user) - session.commit() + user = await user.update(session, role=role) typer.echo(f"User {username!r} successfully updated:") typer.echo(f"• role: {old_role.value!r} -> {user.role.value!r}") + + +if __name__ == "__main__": + async_typer.run(update) diff --git a/src/argilla/tasks/users/utils.py b/src/argilla/tasks/users/utils.py index 435f4b812d..4580ea095e 100644 --- a/src/argilla/tasks/users/utils.py +++ b/src/argilla/tasks/users/utils.py @@ -14,12 +14,15 @@ from typing import TYPE_CHECKING +from sqlalchemy import select + from argilla.server.models import Workspace if TYPE_CHECKING: - from sqlalchemy.orm import Session + from sqlalchemy.ext.asyncio import AsyncSession -def get_or_new_workspace(session: "Session", workspace_name: str) -> Workspace: - workspace = session.query(Workspace).filter_by(name=workspace_name).first() +async def get_or_new_workspace(session: "AsyncSession", workspace_name: str) -> Workspace: + result = await session.execute(select(Workspace).filter_by(name=workspace_name)) + workspace = result.scalar_one_or_none() return workspace or Workspace(name=workspace_name) diff --git a/tests/conftest.py b/tests/conftest.py index 933a391dd0..da4e7c5970 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -38,7 +38,7 @@ from fastapi.testclient import TestClient from opensearchpy import OpenSearch from sqlalchemy import create_engine -from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from tests.database import SyncTestSession, TestSession, set_task from tests.factories import ( @@ -54,7 +54,7 @@ from pytest_mock import MockerFixture from sqlalchemy import Connection - from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession + from sqlalchemy.ext.asyncio import AsyncConnection from sqlalchemy.orm import Session @@ -122,6 +122,25 @@ def sync_db(sync_connection: "Connection") -> Generator["Session", None, None]: sync_connection.rollback() +@pytest.fixture +def async_db_proxy(mocker: "MockerFixture", sync_db: "Session") -> "AsyncSession": + """Create a mocked `AsyncSession` that proxies to the sync session. This will allow us to execute the async CLI commands + and then in the unit test function use the sync session to assert the changes. + + Args: + mocker: pytest-mock fixture. + sync_db: Sync session. + + Returns: + Mocked `AsyncSession` that proxies to the sync session. + """ + async_session = AsyncSession() + async_session.sync_session = sync_db + async_session._proxied = sync_db + async_session.close = mocker.AsyncMock() + return async_session + + @pytest.fixture(scope="function") def mock_search_engine(mocker) -> Generator["SearchEngine", None, None]: return mocker.AsyncMock(SearchEngine) diff --git a/tests/factories.py b/tests/factories.py index 3023d7ff74..6b6f1a15bc 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -155,6 +155,13 @@ class Meta: name = factory.Sequence(lambda n: f"workspace-{n}") +class WorkspaceSyncFactory(BaseSyncFactory): + class Meta: + model = Workspace + + name = factory.Sequence(lambda n: f"workspace-{n}") + + class UserFactory(BaseFactory): class Meta: model = User diff --git a/tests/server/commons/test_settings.py b/tests/server/commons/test_settings.py index 26ff53c0aa..0de6883b95 100644 --- a/tests/server/commons/test_settings.py +++ b/tests/server/commons/test_settings.py @@ -47,3 +47,23 @@ def test_settings_default_index_replicas_with_shards_defined(monkeypatch): assert settings.es_records_index_shards == 100 assert settings.es_records_index_replicas == 0 + + +def test_settings_default_database_url(): + settings = Settings() + assert settings.database_url == f"sqlite+aiosqlite:///{settings.home_path}/argilla.db?check_same_thread=False" + + +@pytest.mark.parametrize( + "url, expected_url", + [ + ("sqlite:///test.db", "sqlite+aiosqlite:///test.db"), + ("sqlite:///:memory:", "sqlite+aiosqlite:///:memory:"), + ("postgresql://user:pass@localhost:5432/db", "postgresql+asyncpg://user:pass@localhost:5432/db"), + ("postgresql+psycopg2://user:pass@localhost:5432/db", "postgresql+asyncpg://user:pass@localhost:5432/db"), + ], +) +def test_settings_database_url(url: str, expected_url: str, monkeypatch): + monkeypatch.setenv("ARGILLA_DATABASE_URL", url) + settings = Settings() + assert settings.database_url == expected_url diff --git a/tests/tasks/conftest.py b/tests/tasks/conftest.py index 634f31e1ee..bdb01eff26 100644 --- a/tests/tasks/conftest.py +++ b/tests/tasks/conftest.py @@ -12,16 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING + import pytest from argilla.__main__ import app from typer.testing import CliRunner +if TYPE_CHECKING: + from argilla.tasks.async_typer import AsyncTyper + @pytest.fixture(scope="session") -def cli_runner(): +def cli_runner() -> CliRunner: return CliRunner() @pytest.fixture(scope="session") -def cli(): +def cli() -> "AsyncTyper": return app diff --git a/tests/tasks/users/conftest.py b/tests/tasks/users/conftest.py index f57472e0fd..a4bad7b5b6 100644 --- a/tests/tasks/users/conftest.py +++ b/tests/tasks/users/conftest.py @@ -18,13 +18,12 @@ if TYPE_CHECKING: from pytest_mock import MockerFixture - from sqlalchemy.orm import Session + from sqlalchemy.ext.asyncio import AsyncSession @pytest.fixture(autouse=True) -def mock_session_local(mocker: "MockerFixture", sync_db: "Session") -> None: - sync_db.close = mocker.MagicMock() - mocker.patch("argilla.tasks.users.create.SessionLocal", return_value=sync_db) - mocker.patch("argilla.tasks.users.update.SessionLocal", return_value=sync_db) - mocker.patch("argilla.tasks.users.create_default.SessionLocal", return_value=sync_db) - mocker.patch("argilla.tasks.users.migrate.SessionLocal", return_value=sync_db) +def mock_session_local(mocker: "MockerFixture", async_db_proxy: "AsyncSession") -> None: + mocker.patch("argilla.tasks.users.create.AsyncSessionLocal", return_value=async_db_proxy) + mocker.patch("argilla.tasks.users.update.AsyncSessionLocal", return_value=async_db_proxy) + mocker.patch("argilla.tasks.users.create_default.AsyncSessionLocal", return_value=async_db_proxy) + mocker.patch("argilla.tasks.users.migrate.AsyncSessionLocal", return_value=async_db_proxy) diff --git a/tests/tasks/users/test_create.py b/tests/tasks/users/test_create.py index d45769fa09..03f13dbd40 100644 --- a/tests/tasks/users/test_create.py +++ b/tests/tasks/users/test_create.py @@ -20,7 +20,7 @@ from click.testing import CliRunner from typer import Typer -from tests.factories import UserFactory, WorkspaceFactory +from tests.factories import UserSyncFactory, WorkspaceSyncFactory if TYPE_CHECKING: from sqlalchemy.orm import Session @@ -140,7 +140,7 @@ def test_create_with_invalid_username(sync_db: "Session", cli_runner: CliRunner, def test_create_with_existing_username(sync_db: "Session", cli_runner: CliRunner, cli: Typer): - UserFactory.create(username="username") + UserSyncFactory.create(username="username") result = cli_runner.invoke( cli, "users create --first-name first-name --username username --role owner --password 12345678" @@ -193,7 +193,7 @@ def test_create_with_invalid_api_key(sync_db: "Session", cli_runner: CliRunner, def test_create_with_existing_api_key(sync_db: "Session", cli_runner: CliRunner, cli: Typer): - UserFactory.create(api_key="abcdefgh") + UserSyncFactory.create(api_key="abcdefgh") result = cli_runner.invoke( cli, @@ -223,8 +223,8 @@ def test_create_with_multiple_workspaces(sync_db: "Session", cli_runner: CliRunn def test_create_with_existent_workspaces(sync_db: "Session", cli_runner: CliRunner, cli: Typer): - WorkspaceFactory.create(name="workspace-a") - WorkspaceFactory.create(name="workspace-b") + WorkspaceSyncFactory.create(name="workspace-a") + WorkspaceSyncFactory.create(name="workspace-b") result = cli_runner.invoke( cli, diff --git a/tests/tasks/users/test_create_default.py b/tests/tasks/users/test_create_default.py index 582228d8ef..d2cf45238d 100644 --- a/tests/tasks/users/test_create_default.py +++ b/tests/tasks/users/test_create_default.py @@ -16,14 +16,14 @@ from argilla._constants import DEFAULT_API_KEY, DEFAULT_PASSWORD, DEFAULT_USERNAME from argilla.server.contexts import accounts from argilla.server.models import User, UserRole -from click.testing import CliRunner -from typer import Typer if TYPE_CHECKING: from sqlalchemy.orm import Session + from typer import Typer + from typer.testing import CliRunner -def test_create_default(sync_db: "Session", cli_runner: CliRunner, cli: Typer): +def test_create_default(sync_db: "Session", cli_runner: "CliRunner", cli: "Typer"): result = cli_runner.invoke(cli, "users create_default") assert result.exit_code == 0 @@ -38,7 +38,7 @@ def test_create_default(sync_db: "Session", cli_runner: CliRunner, cli: Typer): assert [ws.name for ws in default_user.workspaces] == [DEFAULT_USERNAME] -def test_create_default_with_specific_api_key_and_password(sync_db: "Session", cli_runner: CliRunner, cli: Typer): +def test_create_default_with_specific_api_key_and_password(sync_db: "Session", cli_runner: "CliRunner", cli: "Typer"): result = cli_runner.invoke(cli, "users create_default --api-key my-api-key --password my-password") assert result.exit_code == 0 @@ -53,7 +53,7 @@ def test_create_default_with_specific_api_key_and_password(sync_db: "Session", c assert [ws.name for ws in default_user.workspaces] == [DEFAULT_USERNAME] -def test_create_default_quiet(sync_db: "Session", cli_runner: CliRunner, cli: Typer): +def test_create_default_quiet(sync_db: "Session", cli_runner: "CliRunner", cli: "Typer"): result = cli_runner.invoke(cli, "users create_default --quiet") assert result.exit_code == 0 @@ -61,7 +61,7 @@ def test_create_default_quiet(sync_db: "Session", cli_runner: CliRunner, cli: Ty assert sync_db.query(User).count() == 1 -def test_create_default_with_existent_default_user(sync_db: "Session", cli_runner: CliRunner, cli: Typer): +def test_create_default_with_existent_default_user(sync_db: "Session", cli_runner: "CliRunner", cli: "Typer"): result = cli_runner.invoke(cli, "users create_default") assert result.exit_code == 0 @@ -75,7 +75,7 @@ def test_create_default_with_existent_default_user(sync_db: "Session", cli_runne assert sync_db.query(User).count() == 1 -def test_create_default_with_existent_default_user_and_quiet(sync_db: "Session", cli_runner: CliRunner, cli: Typer): +def test_create_default_with_existent_default_user_and_quiet(sync_db: "Session", cli_runner: "CliRunner", cli: "Typer"): result = cli_runner.invoke(cli, "users create_default") assert result.exit_code == 0 diff --git a/tests/tasks/users/test_update.py b/tests/tasks/users/test_update.py index 83b4880095..c9e5b43c56 100644 --- a/tests/tasks/users/test_update.py +++ b/tests/tasks/users/test_update.py @@ -55,7 +55,7 @@ def test_update_with_missing_username(cli_runner: CliRunner, cli: Typer): @pytest.mark.parametrize("role_string", ["owner", "admin", "annotator"]) -def test_update_with_same_user_role(cli_runner: CliRunner, cli: Typer, role_string): +def test_update_with_same_user_role(cli_runner: CliRunner, cli: Typer, role_string: str): username = "username" UserSyncFactory.create(username=username, role=UserRole(role_string))