From e8c446b08404e8f6a608431f6aaec3e7aeb6700b Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Tue, 25 Jul 2023 09:48:43 +0200 Subject: [PATCH 01/17] feat: update CLI command to be async --- src/argilla/__main__.py | 7 ++- src/argilla/tasks/async_typer.py | 43 +++++++++++++++ src/argilla/tasks/training/__main__.py | 1 - src/argilla/tasks/users/__main__.py | 18 ++++--- src/argilla/tasks/users/create.py | 32 +++++------ src/argilla/tasks/users/create_default.py | 26 +++++---- src/argilla/tasks/users/migrate.py | 66 ++++++++++++----------- src/argilla/tasks/users/update.py | 12 ++--- src/argilla/tasks/users/utils.py | 9 ++-- 9 files changed, 129 insertions(+), 85 deletions(-) create mode 100644 src/argilla/tasks/async_typer.py diff --git a/src/argilla/__main__.py b/src/argilla/__main__.py index 56e68f12bd..d6bd2e0b42 100644 --- a/src/argilla/__main__.py +++ b/src/argilla/__main__.py @@ -14,11 +14,10 @@ # limitations under the License. -import typer +from argilla.tasks import database_app, server_app, training_app, users_app +from argilla.tasks.async_typer import AsyncTyper -from .tasks import database_app, server_app, training_app, users_app - -app = typer.Typer(rich_help_panel=True, help="Argilla CLI", no_args_is_help=True) +app = AsyncTyper(rich_help_panel=True, help="Argilla CLI", no_args_is_help=True) app.add_typer(users_app, name="users") app.add_typer(database_app, name="database") diff --git a/src/argilla/tasks/async_typer.py b/src/argilla/tasks/async_typer.py new file mode 100644 index 0000000000..ea5b78a4c2 --- /dev/null +++ b/src/argilla/tasks/async_typer.py @@ -0,0 +1,43 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from collections.abc import Callable, Coroutine +from functools import wraps +from typing import Any, ParamSpec, TypeVar + +import typer + +P = ParamSpec("P") +R = TypeVar("R") + + +# https://github.com/tiangolo/typer/issues/88#issuecomment-1613013597 +class AsyncTyper(typer.Typer): + def async_command( + self, *args: Any, **kwargs: Any + ) -> Callable[[Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]],]: + def decorator(async_func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]: + @wraps(async_func) + def sync_func(*_args: P.args, **_kwargs: P.kwargs) -> R: + return asyncio.run(async_func(*_args, **_kwargs)) + + self.command(*args, **kwargs)(sync_func) + return async_func + + return decorator + + +async def run(): + pass diff --git a/src/argilla/tasks/training/__main__.py b/src/argilla/tasks/training/__main__.py index 78f8ebc383..87ec93f1dd 100644 --- a/src/argilla/tasks/training/__main__.py +++ b/src/argilla/tasks/training/__main__.py @@ -45,7 +45,6 @@ def train( ): import json - import argilla as rg from argilla.client.api import init from argilla.training import ArgillaTrainer diff --git a/src/argilla/tasks/users/__main__.py b/src/argilla/tasks/users/__main__.py index 02b07b8370..5536039ae3 100644 --- a/src/argilla/tasks/users/__main__.py +++ b/src/argilla/tasks/users/__main__.py @@ -12,19 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -import typer +from argilla.tasks.async_typer import AsyncTyper from .create import create from .create_default import create_default from .migrate import migrate from .update import update -app = typer.Typer(help="Holds CLI commands for user and workspace management.", no_args_is_help=True) +app = AsyncTyper(help="Holds CLI commands for user and workspace management.", no_args_is_help=True) -app.command(name="create_default", help="Creates default users and workspaces in the Argilla database.")(create_default) -app.command(name="create", help="Creates a user and add it to the Argilla database.", no_args_is_help=True)(create) -app.command(name="update", help="Updates the user's role into the Argilla database.", no_args_is_help=True)(update) -app.command(name="migrate")(migrate) +app.async_command(name="create_default", help="Creates default users and workspaces in the Argilla database.")( + create_default +) +app.async_command(name="create", help="Creates a user and add it to the Argilla database.", no_args_is_help=True)( + create +) +app.async_command(name="update", help="Updates the user's role into the Argilla database.", no_args_is_help=True)( + update +) +app.async_command(name="migrate")(migrate) if __name__ == "__main__": diff --git a/src/argilla/tasks/users/create.py b/src/argilla/tasks/users/create.py index c575e2d836..c7058a01bc 100644 --- a/src/argilla/tasks/users/create.py +++ b/src/argilla/tasks/users/create.py @@ -16,12 +16,11 @@ import typer from pydantic import constr -from sqlalchemy.orm import Session from typer import Typer from argilla.server.contexts import accounts -from argilla.server.database import SessionLocal -from argilla.server.models import User, UserRole, Workspace +from argilla.server.database import AsyncSessionLocal +from argilla.server.models import User, UserRole from argilla.server.security.model import ( USER_PASSWORD_MIN_LENGTH, UserCreate, @@ -37,18 +36,14 @@ class UserCreateForTask(UserCreate): workspaces: Optional[List[WorkspaceCreate]] -def _get_or_new_workspace(session: Session, workspace_name: str): - return session.query(Workspace).filter_by(name=workspace_name).first() or Workspace(name=workspace_name) - - -def role_callback(value: str): +def role_callback(value: str) -> str: try: return UserRole(value).value except ValueError: raise typer.BadParameter("Only Camila is allowed") -def password_callback(password: str = None): +def password_callback(password: str = None) -> str: # if password is None: # raise typer.BadParameter("Password must be specified.") # if len(password) str: # if api_key and len(api_key) UserCreate: return UserCreate( first_name=user.get("full_name", ""), username=user["username"], @@ -73,34 +89,24 @@ def _build_user_create(self, user: dict): workspaces=[WorkspaceCreate(name=workspace_name) for workspace_name in self._user_workspace_names(user)], ) - def _build_user(self, session: Session, user_create: UserCreate): - return User( - first_name=user_create.first_name, - username=user_create.username, - role=user_create.role, - api_key=user_create.api_key, - password_hash=user_create.password_hash, - workspaces=[get_or_new_workspace(session, workspace.name) for workspace in user_create.workspaces], - ) - - def _user_role(self, user: dict): + def _user_role(self, user: dict) -> UserRole: if user.get("workspaces") is None: return UserRole.owner - else: - return UserRole.annotator - def _user_workspace_names(self, user: dict): + return UserRole.annotator + + def _user_workspace_names(self, user: dict) -> List[str]: workspace_names = [workspace_name for workspace_name in user.get("workspaces", [])] if user["username"] in workspace_names: return workspace_names - else: - return [user["username"]] + workspace_names + + return [user["username"]] + workspace_names -def migrate(): +async def migrate(): """Migrate users defined in YAML file to database.""" - UsersMigrator(settings.users_db_file).migrate() + await UsersMigrator(settings.users_db_file).migrate() if __name__ == "__main__": diff --git a/src/argilla/tasks/users/update.py b/src/argilla/tasks/users/update.py index 2af934b62a..ac89f6443d 100644 --- a/src/argilla/tasks/users/update.py +++ b/src/argilla/tasks/users/update.py @@ -15,11 +15,11 @@ import typer from argilla.server.contexts import accounts -from argilla.server.database import SessionLocal +from argilla.server.database import AsyncSessionLocal from argilla.server.models import UserRole -def update( +async def update( username: str = typer.Argument( default=None, help="Username as a lowercase string without spaces allowing letters, numbers, dashes and underscores.", @@ -31,8 +31,8 @@ def update( help="New role for the user.", ), ): - with SessionLocal() as session: - user = accounts.get_user_by_username_sync(session, username) + async with AsyncSessionLocal() as session: + user = await accounts.get_user_by_username(session, username) if not user: typer.echo(f"User with username {username!r} does not exists in database. Skipping...") @@ -43,10 +43,8 @@ def update( return old_role = user.role - user.role = role - session.add(user) - session.commit() + user = await user.update(session, role=role) typer.echo(f"User {username!r} successfully updated:") typer.echo(f"• role: {old_role.value!r} -> {user.role.value!r}") diff --git a/src/argilla/tasks/users/utils.py b/src/argilla/tasks/users/utils.py index 435f4b812d..4580ea095e 100644 --- a/src/argilla/tasks/users/utils.py +++ b/src/argilla/tasks/users/utils.py @@ -14,12 +14,15 @@ from typing import TYPE_CHECKING +from sqlalchemy import select + from argilla.server.models import Workspace if TYPE_CHECKING: - from sqlalchemy.orm import Session + from sqlalchemy.ext.asyncio import AsyncSession -def get_or_new_workspace(session: "Session", workspace_name: str) -> Workspace: - workspace = session.query(Workspace).filter_by(name=workspace_name).first() +async def get_or_new_workspace(session: "AsyncSession", workspace_name: str) -> Workspace: + result = await session.execute(select(Workspace).filter_by(name=workspace_name)) + workspace = result.scalar_one_or_none() return workspace or Workspace(name=workspace_name) From eb451c3405e134c7531368aa1faca9b7cdefe0ba Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Tue, 25 Jul 2023 09:49:03 +0200 Subject: [PATCH 02/17] feat: update CLI commands after async update --- tests/factories.py | 7 ++++++ tests/tasks/conftest.py | 9 +++++-- tests/tasks/users/conftest.py | 31 +++++++++++++++++++----- tests/tasks/users/test_create.py | 10 ++++---- tests/tasks/users/test_create_default.py | 14 +++++------ tests/tasks/users/test_update.py | 2 +- 6 files changed, 52 insertions(+), 21 deletions(-) diff --git a/tests/factories.py b/tests/factories.py index 3023d7ff74..6b6f1a15bc 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -155,6 +155,13 @@ class Meta: name = factory.Sequence(lambda n: f"workspace-{n}") +class WorkspaceSyncFactory(BaseSyncFactory): + class Meta: + model = Workspace + + name = factory.Sequence(lambda n: f"workspace-{n}") + + class UserFactory(BaseFactory): class Meta: model = User diff --git a/tests/tasks/conftest.py b/tests/tasks/conftest.py index 634f31e1ee..bdb01eff26 100644 --- a/tests/tasks/conftest.py +++ b/tests/tasks/conftest.py @@ -12,16 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING + import pytest from argilla.__main__ import app from typer.testing import CliRunner +if TYPE_CHECKING: + from argilla.tasks.async_typer import AsyncTyper + @pytest.fixture(scope="session") -def cli_runner(): +def cli_runner() -> CliRunner: return CliRunner() @pytest.fixture(scope="session") -def cli(): +def cli() -> "AsyncTyper": return app diff --git a/tests/tasks/users/conftest.py b/tests/tasks/users/conftest.py index f57472e0fd..51679ea23d 100644 --- a/tests/tasks/users/conftest.py +++ b/tests/tasks/users/conftest.py @@ -15,16 +15,35 @@ from typing import TYPE_CHECKING import pytest +from sqlalchemy.ext.asyncio import AsyncSession if TYPE_CHECKING: from pytest_mock import MockerFixture from sqlalchemy.orm import Session +@pytest.fixture +def async_db_proxy(mocker: "MockerFixture", sync_db: "Session") -> "AsyncSession": + """Create a mocked `AsyncSession` that proxies to the sync session. This will allow us to execute the async CLI commands + and then in the unit test function use the sync session to assert the changes. + + Args: + mocker: pytest-mock fixture. + sync_db: Sync session. + + Returns: + Mocked `AsyncSession` that proxies to the sync session. + """ + async_session = AsyncSession() + async_session.sync_session = sync_db + async_session._proxied = sync_db + async_session.close = mocker.AsyncMock() + return async_session + + @pytest.fixture(autouse=True) -def mock_session_local(mocker: "MockerFixture", sync_db: "Session") -> None: - sync_db.close = mocker.MagicMock() - mocker.patch("argilla.tasks.users.create.SessionLocal", return_value=sync_db) - mocker.patch("argilla.tasks.users.update.SessionLocal", return_value=sync_db) - mocker.patch("argilla.tasks.users.create_default.SessionLocal", return_value=sync_db) - mocker.patch("argilla.tasks.users.migrate.SessionLocal", return_value=sync_db) +def mock_session_local(mocker: "MockerFixture", async_db_proxy: "AsyncSession") -> None: + mocker.patch("argilla.tasks.users.create.AsyncSessionLocal", return_value=async_db_proxy) + mocker.patch("argilla.tasks.users.update.AsyncSessionLocal", return_value=async_db_proxy) + mocker.patch("argilla.tasks.users.create_default.AsyncSessionLocal", return_value=async_db_proxy) + mocker.patch("argilla.tasks.users.migrate.AsyncSessionLocal", return_value=async_db_proxy) diff --git a/tests/tasks/users/test_create.py b/tests/tasks/users/test_create.py index d45769fa09..03f13dbd40 100644 --- a/tests/tasks/users/test_create.py +++ b/tests/tasks/users/test_create.py @@ -20,7 +20,7 @@ from click.testing import CliRunner from typer import Typer -from tests.factories import UserFactory, WorkspaceFactory +from tests.factories import UserSyncFactory, WorkspaceSyncFactory if TYPE_CHECKING: from sqlalchemy.orm import Session @@ -140,7 +140,7 @@ def test_create_with_invalid_username(sync_db: "Session", cli_runner: CliRunner, def test_create_with_existing_username(sync_db: "Session", cli_runner: CliRunner, cli: Typer): - UserFactory.create(username="username") + UserSyncFactory.create(username="username") result = cli_runner.invoke( cli, "users create --first-name first-name --username username --role owner --password 12345678" @@ -193,7 +193,7 @@ def test_create_with_invalid_api_key(sync_db: "Session", cli_runner: CliRunner, def test_create_with_existing_api_key(sync_db: "Session", cli_runner: CliRunner, cli: Typer): - UserFactory.create(api_key="abcdefgh") + UserSyncFactory.create(api_key="abcdefgh") result = cli_runner.invoke( cli, @@ -223,8 +223,8 @@ def test_create_with_multiple_workspaces(sync_db: "Session", cli_runner: CliRunn def test_create_with_existent_workspaces(sync_db: "Session", cli_runner: CliRunner, cli: Typer): - WorkspaceFactory.create(name="workspace-a") - WorkspaceFactory.create(name="workspace-b") + WorkspaceSyncFactory.create(name="workspace-a") + WorkspaceSyncFactory.create(name="workspace-b") result = cli_runner.invoke( cli, diff --git a/tests/tasks/users/test_create_default.py b/tests/tasks/users/test_create_default.py index 582228d8ef..d2cf45238d 100644 --- a/tests/tasks/users/test_create_default.py +++ b/tests/tasks/users/test_create_default.py @@ -16,14 +16,14 @@ from argilla._constants import DEFAULT_API_KEY, DEFAULT_PASSWORD, DEFAULT_USERNAME from argilla.server.contexts import accounts from argilla.server.models import User, UserRole -from click.testing import CliRunner -from typer import Typer if TYPE_CHECKING: from sqlalchemy.orm import Session + from typer import Typer + from typer.testing import CliRunner -def test_create_default(sync_db: "Session", cli_runner: CliRunner, cli: Typer): +def test_create_default(sync_db: "Session", cli_runner: "CliRunner", cli: "Typer"): result = cli_runner.invoke(cli, "users create_default") assert result.exit_code == 0 @@ -38,7 +38,7 @@ def test_create_default(sync_db: "Session", cli_runner: CliRunner, cli: Typer): assert [ws.name for ws in default_user.workspaces] == [DEFAULT_USERNAME] -def test_create_default_with_specific_api_key_and_password(sync_db: "Session", cli_runner: CliRunner, cli: Typer): +def test_create_default_with_specific_api_key_and_password(sync_db: "Session", cli_runner: "CliRunner", cli: "Typer"): result = cli_runner.invoke(cli, "users create_default --api-key my-api-key --password my-password") assert result.exit_code == 0 @@ -53,7 +53,7 @@ def test_create_default_with_specific_api_key_and_password(sync_db: "Session", c assert [ws.name for ws in default_user.workspaces] == [DEFAULT_USERNAME] -def test_create_default_quiet(sync_db: "Session", cli_runner: CliRunner, cli: Typer): +def test_create_default_quiet(sync_db: "Session", cli_runner: "CliRunner", cli: "Typer"): result = cli_runner.invoke(cli, "users create_default --quiet") assert result.exit_code == 0 @@ -61,7 +61,7 @@ def test_create_default_quiet(sync_db: "Session", cli_runner: CliRunner, cli: Ty assert sync_db.query(User).count() == 1 -def test_create_default_with_existent_default_user(sync_db: "Session", cli_runner: CliRunner, cli: Typer): +def test_create_default_with_existent_default_user(sync_db: "Session", cli_runner: "CliRunner", cli: "Typer"): result = cli_runner.invoke(cli, "users create_default") assert result.exit_code == 0 @@ -75,7 +75,7 @@ def test_create_default_with_existent_default_user(sync_db: "Session", cli_runne assert sync_db.query(User).count() == 1 -def test_create_default_with_existent_default_user_and_quiet(sync_db: "Session", cli_runner: CliRunner, cli: Typer): +def test_create_default_with_existent_default_user_and_quiet(sync_db: "Session", cli_runner: "CliRunner", cli: "Typer"): result = cli_runner.invoke(cli, "users create_default") assert result.exit_code == 0 diff --git a/tests/tasks/users/test_update.py b/tests/tasks/users/test_update.py index 83b4880095..c9e5b43c56 100644 --- a/tests/tasks/users/test_update.py +++ b/tests/tasks/users/test_update.py @@ -55,7 +55,7 @@ def test_update_with_missing_username(cli_runner: CliRunner, cli: Typer): @pytest.mark.parametrize("role_string", ["owner", "admin", "annotator"]) -def test_update_with_same_user_role(cli_runner: CliRunner, cli: Typer, role_string): +def test_update_with_same_user_role(cli_runner: CliRunner, cli: Typer, role_string: str): username = "username" UserSyncFactory.create(username=username, role=UserRole(role_string)) From dc94c39bf9f32fc002e6a825b745b82fc3b9becc Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Tue, 25 Jul 2023 10:15:09 +0200 Subject: [PATCH 03/17] fix: `ParamSpec` import --- src/argilla/tasks/async_typer.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/argilla/tasks/async_typer.py b/src/argilla/tasks/async_typer.py index ea5b78a4c2..0f3e733771 100644 --- a/src/argilla/tasks/async_typer.py +++ b/src/argilla/tasks/async_typer.py @@ -13,9 +13,16 @@ # limitations under the License. import asyncio +import sys from collections.abc import Callable, Coroutine from functools import wraps -from typing import Any, ParamSpec, TypeVar +from typing import Any, TypeVar + +if sys.version_info < (3, 10): + from typing_extensions import ParamSpec +else: + from typing import ParamSpec + import typer From fbd056579e14507f6e63031264912214d8975230 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Tue, 25 Jul 2023 10:21:46 +0200 Subject: [PATCH 04/17] feat: add `run` function for `async` commands --- src/argilla/tasks/async_typer.py | 6 ++++-- src/argilla/tasks/users/create.py | 6 ++---- src/argilla/tasks/users/create_default.py | 3 ++- src/argilla/tasks/users/migrate.py | 3 ++- src/argilla/tasks/users/update.py | 5 +++++ 5 files changed, 15 insertions(+), 8 deletions(-) diff --git a/src/argilla/tasks/async_typer.py b/src/argilla/tasks/async_typer.py index 0f3e733771..daedcb2c2c 100644 --- a/src/argilla/tasks/async_typer.py +++ b/src/argilla/tasks/async_typer.py @@ -46,5 +46,7 @@ def sync_func(*_args: P.args, **_kwargs: P.kwargs) -> R: return decorator -async def run(): - pass +def run(function: Callable[..., Coroutine[Any, Any, Any]]) -> None: + app = AsyncTyper(add_completion=False) + app.async_command()(function) + app() diff --git a/src/argilla/tasks/users/create.py b/src/argilla/tasks/users/create.py index c7058a01bc..d50273ecee 100644 --- a/src/argilla/tasks/users/create.py +++ b/src/argilla/tasks/users/create.py @@ -16,7 +16,6 @@ import typer from pydantic import constr -from typer import Typer from argilla.server.contexts import accounts from argilla.server.database import AsyncSessionLocal @@ -26,6 +25,7 @@ UserCreate, WorkspaceCreate, ) +from argilla.tasks import async_typer from argilla.tasks.users.utils import get_or_new_workspace USER_API_KEY_MIN_LENGTH = 8 @@ -128,6 +128,4 @@ async def create( if __name__ == "__main__": - app = Typer(add_completion=False) - app.command(no_args_is_help=True)(create) - app() + async_typer.run(create) diff --git a/src/argilla/tasks/users/create_default.py b/src/argilla/tasks/users/create_default.py index 95af236e5f..88c0caff0f 100644 --- a/src/argilla/tasks/users/create_default.py +++ b/src/argilla/tasks/users/create_default.py @@ -18,6 +18,7 @@ from argilla.server.contexts import accounts from argilla.server.database import AsyncSessionLocal from argilla.server.models import User, UserRole, Workspace +from argilla.tasks import async_typer async def create_default( @@ -52,4 +53,4 @@ async def create_default( if __name__ == "__main__": - typer.run(create_default) + async_typer.run(create_default) diff --git a/src/argilla/tasks/users/migrate.py b/src/argilla/tasks/users/migrate.py index 6f4ed48b5a..0467d17c42 100644 --- a/src/argilla/tasks/users/migrate.py +++ b/src/argilla/tasks/users/migrate.py @@ -22,6 +22,7 @@ from argilla.server.models import User, UserRole from argilla.server.security.auth_provider.local.settings import settings from argilla.server.security.model import USER_USERNAME_REGEX, WORKSPACE_NAME_REGEX +from argilla.tasks import async_typer from argilla.tasks.users.utils import get_or_new_workspace if TYPE_CHECKING: @@ -110,4 +111,4 @@ async def migrate(): if __name__ == "__main__": - typer.run(migrate()) + async_typer.run(migrate) diff --git a/src/argilla/tasks/users/update.py b/src/argilla/tasks/users/update.py index ac89f6443d..8882779ab4 100644 --- a/src/argilla/tasks/users/update.py +++ b/src/argilla/tasks/users/update.py @@ -17,6 +17,7 @@ from argilla.server.contexts import accounts from argilla.server.database import AsyncSessionLocal from argilla.server.models import UserRole +from argilla.tasks import async_typer async def update( @@ -48,3 +49,7 @@ async def update( typer.echo(f"User {username!r} successfully updated:") typer.echo(f"• role: {old_role.value!r} -> {user.role.value!r}") + + +if __name__ == "__main__": + async_typer.run(update) From 0a7e82508c9c9ccdb1e6a852ad6f3e517ec7f135 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Tue, 25 Jul 2023 10:38:34 +0200 Subject: [PATCH 05/17] fix: `async_command` return type hint --- src/argilla/tasks/async_typer.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/argilla/tasks/async_typer.py b/src/argilla/tasks/async_typer.py index daedcb2c2c..1d1faeeabf 100644 --- a/src/argilla/tasks/async_typer.py +++ b/src/argilla/tasks/async_typer.py @@ -14,9 +14,10 @@ import asyncio import sys -from collections.abc import Callable, Coroutine from functools import wraps -from typing import Any, TypeVar +from typing import Any, Callable, Coroutine, TypeVar + +import typer if sys.version_info < (3, 10): from typing_extensions import ParamSpec @@ -24,8 +25,6 @@ from typing import ParamSpec -import typer - P = ParamSpec("P") R = TypeVar("R") @@ -34,7 +33,7 @@ class AsyncTyper(typer.Typer): def async_command( self, *args: Any, **kwargs: Any - ) -> Callable[[Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]],]: + ) -> Callable[[Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]]]: def decorator(async_func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]: @wraps(async_func) def sync_func(*_args: P.args, **_kwargs: P.kwargs) -> R: From 9ae14251fad54a96393c6d59165fd91a2c76f7c1 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Tue, 25 Jul 2023 11:58:07 +0200 Subject: [PATCH 06/17] feat: remove `seeds` module --- src/argilla/server/seeds.py | 62 ------------------------------------- 1 file changed, 62 deletions(-) delete mode 100644 src/argilla/server/seeds.py diff --git a/src/argilla/server/seeds.py b/src/argilla/server/seeds.py deleted file mode 100644 index e620751154..0000000000 --- a/src/argilla/server/seeds.py +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from argilla._constants import DEFAULT_API_KEY -from argilla.server.database import SessionLocal -from argilla.server.models import User, UserRole, Workspace - - -def development_seeds(): - with SessionLocal() as session, session.begin(): - session.add_all( - [ - Workspace(name="workspace-1"), - Workspace(name="workspace-2"), - User( - first_name="John", - last_name="Doe", - username="argilla", - role=UserRole.owner, - password_hash="$2y$05$eaw.j2Kaw8s8vpscVIZMfuqSIX3OLmxA21WjtWicDdn0losQ91Hw.", - api_key="1234", - ), - User( - first_name="Tanya", - last_name="Franklin", - username="tanya", - password_hash="$2y$05$eaw.j2Kaw8s8vpscVIZMfuqSIX3OLmxA21WjtWicDdn0losQ91Hw.", - api_key="123456", - ), - ] - ) - - -def test_seeds(db: SessionLocal): - db.add_all( - [ - Workspace( - name="argilla", - users=[ - User( - first_name="Argilla", - username="argilla", - role=UserRole.owner, - password_hash="$2y$05$eaw.j2Kaw8s8vpscVIZMfuqSIX3OLmxA21WjtWicDdn0losQ91Hw.", - api_key=DEFAULT_API_KEY, - ), - ], - ) - ] - ) - db.commit() From 9d1230ad24b244b026aa8a48e999ba17992758e1 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Tue, 25 Jul 2023 12:00:03 +0200 Subject: [PATCH 07/17] feat: remove sync engine, sessiond and `get_db` function --- .../server/apis/v1/handlers/responses.py | 3 +-- src/argilla/server/database.py | 19 ++++--------------- 2 files changed, 5 insertions(+), 17 deletions(-) diff --git a/src/argilla/server/apis/v1/handlers/responses.py b/src/argilla/server/apis/v1/handlers/responses.py index cce444a71c..588c67cff7 100644 --- a/src/argilla/server/apis/v1/handlers/responses.py +++ b/src/argilla/server/apis/v1/handlers/responses.py @@ -16,10 +16,9 @@ from fastapi import APIRouter, Depends, HTTPException, Security, status from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import Session from argilla.server.contexts import datasets -from argilla.server.database import get_async_db, get_db +from argilla.server.database import get_async_db from argilla.server.models import User from argilla.server.policies import ResponsePolicyV1, authorize from argilla.server.schemas.v1.responses import Response, ResponseUpdate diff --git a/src/argilla/server/database.py b/src/argilla/server/database.py index 33dff80f95..588512474c 100644 --- a/src/argilla/server/database.py +++ b/src/argilla/server/database.py @@ -16,16 +16,16 @@ from sqlite3 import Connection as SQLite3Connection from typing import TYPE_CHECKING, Generator -from sqlalchemy import create_engine, event +from sqlalchemy import event from sqlalchemy.engine import Engine -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine -from sqlalchemy.orm import sessionmaker +from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine import argilla from argilla.server.settings import settings if TYPE_CHECKING: - from sqlalchemy.orm import Session + from sqlalchemy.ext.asyncio import AsyncSession + ALEMBIC_CONFIG_FILE = os.path.normpath(os.path.join(os.path.dirname(argilla.__file__), "alembic.ini")) TAGGED_REVISIONS = OrderedDict( @@ -46,21 +46,10 @@ def set_sqlite_pragma(dbapi_connection, connection_record): cursor.close() -engine = create_engine(settings.database_url) -SessionLocal = sessionmaker(autocommit=False, bind=engine) - async_engine = create_async_engine(settings.database_url_async) AsyncSessionLocal = async_sessionmaker(autocommit=False, expire_on_commit=False, bind=async_engine) -def get_db() -> Generator["Session", None, None]: - try: - db = SessionLocal() - yield db - finally: - db.close() - - async def get_async_db() -> Generator["AsyncSession", None, None]: try: db: "AsyncSession" = AsyncSessionLocal() From fc4ff0405e91cd8532f239b04a2870687542f786 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Tue, 25 Jul 2023 12:56:11 +0200 Subject: [PATCH 08/17] feat: remove `database_url_async` property --- src/argilla/server/database.py | 2 +- src/argilla/server/settings.py | 17 +++-------------- 2 files changed, 4 insertions(+), 15 deletions(-) diff --git a/src/argilla/server/database.py b/src/argilla/server/database.py index 588512474c..f7b67be9d7 100644 --- a/src/argilla/server/database.py +++ b/src/argilla/server/database.py @@ -46,7 +46,7 @@ def set_sqlite_pragma(dbapi_connection, connection_record): cursor.close() -async_engine = create_async_engine(settings.database_url_async) +async_engine = create_async_engine(settings.database_url) AsyncSessionLocal = async_sessionmaker(autocommit=False, expire_on_commit=False, bind=async_engine) diff --git a/src/argilla/server/settings.py b/src/argilla/server/settings.py index 4f7b208c3c..90f8fbe2bd 100644 --- a/src/argilla/server/settings.py +++ b/src/argilla/server/settings.py @@ -127,7 +127,9 @@ def normalize_base_url(cls, base_url: str): @validator("database_url", always=True) def set_database_url_default(cls, database_url: str, values: dict) -> str: - return database_url or f"sqlite:///{os.path.join(values['home_path'], 'argilla.db')}?check_same_thread=False" + home_path = values.get("home_path") + sqlite_file = os.path.join(home_path, "argilla.db") + return database_url or f"sqlite+aiosqlite:///{sqlite_file}?check_same_thread=False" @root_validator(skip_on_failure=True) def create_home_path(cls, values): @@ -165,19 +167,6 @@ def old_dataset_records_index_name(self) -> str: return index_name.replace("", "") return index_name.replace("", f".{ns}") - @property - def database_url_async(self) -> str: - if self.database_url.startswith("sqlite:///"): - return self.database_url.replace("sqlite:///", "sqlite+aiosqlite:///") - - if self.database_url.startswith("postgresql://"): - return self.database_url.replace("postgresql://", "postgresql+asyncpg://") - - if self.database_url.startswith("mysql://"): - return self.database_url.replace("mysql://", "mysql+aiomysql://") - - raise ValueError(f"Unsupported database url: '{self.database_url}'") - def obfuscated_elasticsearch(self) -> str: """Returns configured elasticsearch url obfuscating the provided password, if any""" parsed = urlparse(self.elasticsearch) From f2d24bc1e96c5d183ad160086c91d30323660946 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Tue, 25 Jul 2023 14:17:19 +0200 Subject: [PATCH 09/17] feat: use async engine to run migrations --- src/argilla/server/alembic/env.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/argilla/server/alembic/env.py b/src/argilla/server/alembic/env.py index b0c317cdfd..67dc9f62a1 100644 --- a/src/argilla/server/alembic/env.py +++ b/src/argilla/server/alembic/env.py @@ -12,13 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio from logging.config import fileConfig from alembic import context from argilla.server.models.base import DatabaseModel from argilla.server.models.models import * # noqa from argilla.server.settings import settings -from sqlalchemy import engine_from_config, pool +from sqlalchemy import pool +from sqlalchemy.ext.asyncio import async_engine_from_config # this is the Alembic Config object, which provides # access to the values within the .ini file in use. @@ -68,27 +70,31 @@ def run_migrations_offline() -> None: context.run_migrations() -def run_migrations_online() -> None: +def apply_migrations(connection) -> None: + context.configure(connection=connection, target_metadata=target_metadata) + + with context.begin_transaction(): + context.run_migrations() + + +async def run_migrations_online() -> None: """Run migrations in 'online' mode. In this scenario we need to create an Engine and associate a connection with the context. """ - connectable = engine_from_config( + connectable = async_engine_from_config( config.get_section(config.config_ini_section), prefix="sqlalchemy.", poolclass=pool.NullPool, ) - with connectable.connect() as connection: - context.configure(connection=connection, target_metadata=target_metadata) - - with context.begin_transaction(): - context.run_migrations() + async with connectable.connect() as connection: + await connection.run_sync(apply_migrations) if context.is_offline_mode(): run_migrations_offline() else: - run_migrations_online() + asyncio.run(run_migrations_online()) From e200d8e4cf60b34d8a16bc97ebb87e0bff186f03 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Tue, 25 Jul 2023 14:38:07 +0200 Subject: [PATCH 10/17] docs: update changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4634d26659..03e8ed495d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,8 @@ These are the section headers that we use: - Improved efficiency of weak labeling when dataset contains vectors ([#3444](https://github.com/argilla-io/argilla/pull/3444)). - Added `ArgillaDatasetMixin` to detach the Argilla-related functionality from the `FeedbackDataset` ([#3427](https://github.com/argilla-io/argilla/pull/3427)) - Moved `FeedbackDataset`-related `pydantic.BaseModel` schemas to `argilla.client.feedback.schemas` instead, to be better structured and more scalable and maintainable ([#3427](https://github.com/argilla-io/argilla/pull/3427)) +- Update CLI to use database async connection ([#3450](https://github.com/argilla-io/argilla/pull/3450)). +- Update alembic code to apply migrations to use database async engine ([#3450](https://github.com/argilla-io/argilla/pull/3450)). ## [1.13.2](https://github.com/argilla-io/argilla/compare/v1.13.1...v1.13.2) From 882c29a22dcb309b71ab37be10f2e2594a1d12df Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Tue, 25 Jul 2023 15:00:00 +0200 Subject: [PATCH 11/17] feat: add `connection` type hint --- src/argilla/server/alembic/env.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/argilla/server/alembic/env.py b/src/argilla/server/alembic/env.py index 67dc9f62a1..649357c35d 100644 --- a/src/argilla/server/alembic/env.py +++ b/src/argilla/server/alembic/env.py @@ -14,6 +14,7 @@ import asyncio from logging.config import fileConfig +from typing import TYPE_CHECKING from alembic import context from argilla.server.models.base import DatabaseModel @@ -22,6 +23,9 @@ from sqlalchemy import pool from sqlalchemy.ext.asyncio import async_engine_from_config +if TYPE_CHECKING: + from sqlalchemy import Connection + # this is the Alembic Config object, which provides # access to the values within the .ini file in use. config = context.config @@ -70,7 +74,7 @@ def run_migrations_offline() -> None: context.run_migrations() -def apply_migrations(connection) -> None: +def apply_migrations(connection: "Connection") -> None: context.configure(connection=connection, target_metadata=target_metadata) with context.begin_transaction(): From b92669a409258d7a25f1a741c0de9431400101fc Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Tue, 25 Jul 2023 17:08:16 +0200 Subject: [PATCH 12/17] feat: rename to `command` --- src/argilla/tasks/async_typer.py | 20 +++++++++++++------- src/argilla/tasks/database/migrate.py | 3 ++- src/argilla/tasks/database/revisions.py | 3 ++- src/argilla/tasks/users/__main__.py | 14 ++++---------- 4 files changed, 21 insertions(+), 19 deletions(-) diff --git a/src/argilla/tasks/async_typer.py b/src/argilla/tasks/async_typer.py index 1d1faeeabf..cac3855047 100644 --- a/src/argilla/tasks/async_typer.py +++ b/src/argilla/tasks/async_typer.py @@ -31,21 +31,27 @@ # https://github.com/tiangolo/typer/issues/88#issuecomment-1613013597 class AsyncTyper(typer.Typer): - def async_command( + def command( self, *args: Any, **kwargs: Any ) -> Callable[[Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]]]: - def decorator(async_func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]: - @wraps(async_func) + super_command = super().command(*args, **kwargs) + + def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]: + @wraps(func) def sync_func(*_args: P.args, **_kwargs: P.kwargs) -> R: - return asyncio.run(async_func(*_args, **_kwargs)) + return asyncio.run(func(*_args, **_kwargs)) + + if asyncio.iscoroutinefunction(func): + super_command(sync_func) + else: + super_command(func) - self.command(*args, **kwargs)(sync_func) - return async_func + return func return decorator def run(function: Callable[..., Coroutine[Any, Any, Any]]) -> None: app = AsyncTyper(add_completion=False) - app.async_command()(function) + app.command()(function) app() diff --git a/src/argilla/tasks/database/migrate.py b/src/argilla/tasks/database/migrate.py index 9218a51789..4d590d1d52 100644 --- a/src/argilla/tasks/database/migrate.py +++ b/src/argilla/tasks/database/migrate.py @@ -20,6 +20,7 @@ from alembic.util import CommandError from argilla.server.database import ALEMBIC_CONFIG_FILE, TAGGED_REVISIONS +from argilla.tasks import async_typer from argilla.tasks.database import utils @@ -47,4 +48,4 @@ def migrate_db(revision: Optional[str] = typer.Option(default="head", help="DB R if __name__ == "__main__": - typer.run(migrate_db) + async_typer.run(migrate_db) diff --git a/src/argilla/tasks/database/revisions.py b/src/argilla/tasks/database/revisions.py index 59309a38cf..9c5544547f 100644 --- a/src/argilla/tasks/database/revisions.py +++ b/src/argilla/tasks/database/revisions.py @@ -15,6 +15,7 @@ import typer from argilla.server.database import ALEMBIC_CONFIG_FILE, TAGGED_REVISIONS +from argilla.tasks import async_typer from argilla.tasks.database import utils @@ -39,4 +40,4 @@ def revisions(): if __name__ == "__main__": - typer.run(revisions) + async_typer.run(revisions) diff --git a/src/argilla/tasks/users/__main__.py b/src/argilla/tasks/users/__main__.py index 5536039ae3..8df083d91f 100644 --- a/src/argilla/tasks/users/__main__.py +++ b/src/argilla/tasks/users/__main__.py @@ -21,16 +21,10 @@ app = AsyncTyper(help="Holds CLI commands for user and workspace management.", no_args_is_help=True) -app.async_command(name="create_default", help="Creates default users and workspaces in the Argilla database.")( - create_default -) -app.async_command(name="create", help="Creates a user and add it to the Argilla database.", no_args_is_help=True)( - create -) -app.async_command(name="update", help="Updates the user's role into the Argilla database.", no_args_is_help=True)( - update -) -app.async_command(name="migrate")(migrate) +app.command(name="create_default", help="Creates default users and workspaces in the Argilla database.")(create_default) +app.command(name="create", help="Creates a user and add it to the Argilla database.", no_args_is_help=True)(create) +app.command(name="update", help="Updates the user's role into the Argilla database.", no_args_is_help=True)(update) +app.command(name="migrate")(migrate) if __name__ == "__main__": From 0b6b88fa30449ec68c5221096b0a937192fa9410 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Tue, 25 Jul 2023 17:28:36 +0200 Subject: [PATCH 13/17] feat: add backward compatibility for `database_url` --- src/argilla/server/settings.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/src/argilla/server/settings.py b/src/argilla/server/settings.py index 90f8fbe2bd..a294899119 100644 --- a/src/argilla/server/settings.py +++ b/src/argilla/server/settings.py @@ -18,6 +18,8 @@ """ import logging import os +import re +import warnings from pathlib import Path from typing import List, Optional from urllib.parse import urlparse @@ -131,6 +133,35 @@ def set_database_url_default(cls, database_url: str, values: dict) -> str: sqlite_file = os.path.join(home_path, "argilla.db") return database_url or f"sqlite+aiosqlite:///{sqlite_file}?check_same_thread=False" + @validator("database_url", pre=True) + def set_database_url(cls, database_url: str, values: dict) -> str: + if not database_url: + home_path = values.get("home_path") + sqlite_file = os.path.join(home_path, "argilla.db") + return f"sqlite+aiosqlite:///{sqlite_file}?check_same_thread=False" + + if "sqlite" in database_url: + regex = re.compile(r"sqlite(?!\+aiosqlite)") + if regex.match(database_url): + warnings.warn( + "From version 1.14.0, Argilla will use `aiosqlite` as default SQLite driver. The protocol in the" + " provided database URL has been automatically replaced from `sqlite` to `sqlite+aiosqlite`." + " Please, update your database URL to use `sqlite+aiosqlite` protocol." + ) + return re.sub(regex, "sqlite+aiosqlite", database_url) + + if "postgresql" in database_url: + regex = re.compile(r"postgresql(?!\+asyncpg)(\+psycopg2)?") + if regex.match(database_url): + warnings.warn( + "From version 1.14.0, Argilla will use `asyncpg` as default PostgreSQL driver. The protocol in the" + " provided database URL has been automatically replaced from `postgresql` to `postgresql+asyncpg`." + " Please, update your database URL to use `postgresql+asyncpg` protocol." + ) + return re.sub(regex, "postgresql+asyncpg", database_url) + + return database_url + @root_validator(skip_on_failure=True) def create_home_path(cls, values): Path(values["home_path"]).mkdir(parents=True, exist_ok=True) From 7b00f468b7c116aad04323ca1ca382e702931ec4 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Tue, 25 Jul 2023 17:33:16 +0200 Subject: [PATCH 14/17] feat: add `Settings.database_url` unit tests --- tests/server/commons/test_settings.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/server/commons/test_settings.py b/tests/server/commons/test_settings.py index 26ff53c0aa..0de6883b95 100644 --- a/tests/server/commons/test_settings.py +++ b/tests/server/commons/test_settings.py @@ -47,3 +47,23 @@ def test_settings_default_index_replicas_with_shards_defined(monkeypatch): assert settings.es_records_index_shards == 100 assert settings.es_records_index_replicas == 0 + + +def test_settings_default_database_url(): + settings = Settings() + assert settings.database_url == f"sqlite+aiosqlite:///{settings.home_path}/argilla.db?check_same_thread=False" + + +@pytest.mark.parametrize( + "url, expected_url", + [ + ("sqlite:///test.db", "sqlite+aiosqlite:///test.db"), + ("sqlite:///:memory:", "sqlite+aiosqlite:///:memory:"), + ("postgresql://user:pass@localhost:5432/db", "postgresql+asyncpg://user:pass@localhost:5432/db"), + ("postgresql+psycopg2://user:pass@localhost:5432/db", "postgresql+asyncpg://user:pass@localhost:5432/db"), + ], +) +def test_settings_database_url(url: str, expected_url: str, monkeypatch): + monkeypatch.setenv("ARGILLA_DATABASE_URL", url) + settings = Settings() + assert settings.database_url == expected_url From a4cca14dd27cb9e45a4824f7a10b97e05c39c72d Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Tue, 25 Jul 2023 18:24:37 +0200 Subject: [PATCH 15/17] feat: remove `set_database_url_default` validator --- src/argilla/server/settings.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/argilla/server/settings.py b/src/argilla/server/settings.py index a294899119..c8870005f0 100644 --- a/src/argilla/server/settings.py +++ b/src/argilla/server/settings.py @@ -127,12 +127,6 @@ def normalize_base_url(cls, base_url: str): return base_url - @validator("database_url", always=True) - def set_database_url_default(cls, database_url: str, values: dict) -> str: - home_path = values.get("home_path") - sqlite_file = os.path.join(home_path, "argilla.db") - return database_url or f"sqlite+aiosqlite:///{sqlite_file}?check_same_thread=False" - @validator("database_url", pre=True) def set_database_url(cls, database_url: str, values: dict) -> str: if not database_url: From 2dbf40472de7707696e25bad58c2457c2488934f Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Wed, 26 Jul 2023 09:41:55 +0200 Subject: [PATCH 16/17] feat: move `async_db_proxy` fixture to root `conftest` --- tests/conftest.py | 23 +++++++++++++++++++++-- tests/tasks/users/conftest.py | 22 +--------------------- 2 files changed, 22 insertions(+), 23 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 933a391dd0..da4e7c5970 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -38,7 +38,7 @@ from fastapi.testclient import TestClient from opensearchpy import OpenSearch from sqlalchemy import create_engine -from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from tests.database import SyncTestSession, TestSession, set_task from tests.factories import ( @@ -54,7 +54,7 @@ from pytest_mock import MockerFixture from sqlalchemy import Connection - from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession + from sqlalchemy.ext.asyncio import AsyncConnection from sqlalchemy.orm import Session @@ -122,6 +122,25 @@ def sync_db(sync_connection: "Connection") -> Generator["Session", None, None]: sync_connection.rollback() +@pytest.fixture +def async_db_proxy(mocker: "MockerFixture", sync_db: "Session") -> "AsyncSession": + """Create a mocked `AsyncSession` that proxies to the sync session. This will allow us to execute the async CLI commands + and then in the unit test function use the sync session to assert the changes. + + Args: + mocker: pytest-mock fixture. + sync_db: Sync session. + + Returns: + Mocked `AsyncSession` that proxies to the sync session. + """ + async_session = AsyncSession() + async_session.sync_session = sync_db + async_session._proxied = sync_db + async_session.close = mocker.AsyncMock() + return async_session + + @pytest.fixture(scope="function") def mock_search_engine(mocker) -> Generator["SearchEngine", None, None]: return mocker.AsyncMock(SearchEngine) diff --git a/tests/tasks/users/conftest.py b/tests/tasks/users/conftest.py index 51679ea23d..a4bad7b5b6 100644 --- a/tests/tasks/users/conftest.py +++ b/tests/tasks/users/conftest.py @@ -15,30 +15,10 @@ from typing import TYPE_CHECKING import pytest -from sqlalchemy.ext.asyncio import AsyncSession if TYPE_CHECKING: from pytest_mock import MockerFixture - from sqlalchemy.orm import Session - - -@pytest.fixture -def async_db_proxy(mocker: "MockerFixture", sync_db: "Session") -> "AsyncSession": - """Create a mocked `AsyncSession` that proxies to the sync session. This will allow us to execute the async CLI commands - and then in the unit test function use the sync session to assert the changes. - - Args: - mocker: pytest-mock fixture. - sync_db: Sync session. - - Returns: - Mocked `AsyncSession` that proxies to the sync session. - """ - async_session = AsyncSession() - async_session.sync_session = sync_db - async_session._proxied = sync_db - async_session.close = mocker.AsyncMock() - return async_session + from sqlalchemy.ext.asyncio import AsyncSession @pytest.fixture(autouse=True) From 569a0bc529715b49373548193b39821bc2cec364 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Wed, 26 Jul 2023 11:06:08 +0200 Subject: [PATCH 17/17] feat: validate `database_url` always --- src/argilla/server/settings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/argilla/server/settings.py b/src/argilla/server/settings.py index c8870005f0..82ad432f15 100644 --- a/src/argilla/server/settings.py +++ b/src/argilla/server/settings.py @@ -127,7 +127,7 @@ def normalize_base_url(cls, base_url: str): return base_url - @validator("database_url", pre=True) + @validator("database_url", pre=True, always=True) def set_database_url(cls, database_url: str, values: dict) -> str: if not database_url: home_path = values.get("home_path")