feat: update CLI to use async connection to DB (#3450)
# Description

This PR updates the CLI so that its commands use an async connection to
the DB instead of a sync one. Since the rest of the application already
uses the async connection, the code for the sync connection has been
removed.

The changes in this PR are:

- Add `AsyncTyper` class, which allows registering `async` commands
through its `command` method. It then executes the command using
`asyncio.run`.
- Add `run` function that allows executing an `async` command.
- Remove `database_url_async` property from `Settings` class. From now
on, `database_url` should be a URL with an async driver
(`sqlite+aiosqlite://` or `postgresql+asyncpg://`).
- Update Alembic connection to create an async engine to run the
migrations.
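
The wrapping idea behind `AsyncTyper` can be illustrated without Typer. The following is a minimal standard-library sketch, not the code added by this PR: `syncify`, `create_user`, and `list_revisions` are hypothetical names used only for illustration, and `inspect.iscoroutinefunction` is used here where the PR calls `asyncio.iscoroutinefunction`. Coroutine functions are wrapped in a sync callable that drives them with `asyncio.run`, while plain functions are passed through unchanged.

```python
import asyncio
import inspect
from functools import wraps
from typing import Any, Callable


def syncify(func: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical helper: make a coroutine function callable from sync code."""
    if not inspect.iscoroutinefunction(func):
        # Plain functions need no wrapping and are returned as they are.
        return func

    @wraps(func)
    def sync_func(*args: Any, **kwargs: Any) -> Any:
        # asyncio.run creates an event loop, runs the coroutine to
        # completion, and closes the loop afterwards.
        return asyncio.run(func(*args, **kwargs))

    return sync_func


async def create_user(name: str) -> str:
    await asyncio.sleep(0)  # stand-in for an awaited database call
    return f"created {name}"


def list_revisions() -> str:
    return "revisions"


print(syncify(create_user)("alice"))
print(syncify(list_revisions)())
```

Because the wrapper calls `asyncio.run`, it can only be used from code that is not already running inside an event loop, which is the case for a CLI entry point.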

**Type of change**

- [x] New feature (non-breaking change which adds functionality)
- [x] Breaking change (fix or feature that would cause existing
functionality to not work as expected)

**How Has This Been Tested**

- [x] All the unit tests regarding the CLI are working as expected
- [x] All the commands have been executed locally and work as
expected (using both `python -m argilla users create ...` and
`python -m argilla.tasks.users.create ...`)
- [x] Migrations have been applied without errors in a local environment
to both SQLite and PostgreSQL instances.

**Checklist**

- [ ] I added relevant documentation
- [x] follows the style guidelines of this project
- [x] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [x] I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)

---------

Co-authored-by: Francisco Aranda <[email protected]>
gabrielmbmb and frascuchon authored Jul 26, 2023
1 parent af60012 commit ef1eb16
Showing 25 changed files with 283 additions and 218 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -21,6 +21,8 @@ These are the section headers that we use:
 - Improved efficiency of weak labeling when dataset contains vectors ([#3444](https://github.com/argilla-io/argilla/pull/3444)).
 - Added `ArgillaDatasetMixin` to detach the Argilla-related functionality from the `FeedbackDataset` ([#3427](https://github.com/argilla-io/argilla/pull/3427))
 - Moved `FeedbackDataset`-related `pydantic.BaseModel` schemas to `argilla.client.feedback.schemas` instead, to be better structured and more scalable and maintainable ([#3427](https://github.com/argilla-io/argilla/pull/3427))
+- Update CLI to use database async connection ([#3450](https://github.com/argilla-io/argilla/pull/3450)).
+- Update alembic code to apply migrations to use database async engine ([#3450](https://github.com/argilla-io/argilla/pull/3450)).
 - Limit rating questions values to the positive range [1, 10] (Closes [#3451](https://github.com/argilla-io/argilla/issues/3451)).

 ## [1.13.2](https://github.com/argilla-io/argilla/compare/v1.13.1...v1.13.2)
7 changes: 3 additions & 4 deletions src/argilla/__main__.py
@@ -14,11 +14,10 @@
 # limitations under the License.


-import typer
+from argilla.tasks import database_app, server_app, training_app, users_app
+from argilla.tasks.async_typer import AsyncTyper

-from .tasks import database_app, server_app, training_app, users_app
-
-app = typer.Typer(rich_help_panel=True, help="Argilla CLI", no_args_is_help=True)
+app = AsyncTyper(rich_help_panel=True, help="Argilla CLI", no_args_is_help=True)

 app.add_typer(users_app, name="users")
 app.add_typer(database_app, name="database")
28 changes: 19 additions & 9 deletions src/argilla/server/alembic/env.py
@@ -12,13 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import asyncio
 from logging.config import fileConfig
+from typing import TYPE_CHECKING

 from alembic import context
 from argilla.server.models.base import DatabaseModel
 from argilla.server.models.models import * # noqa
 from argilla.server.settings import settings
-from sqlalchemy import engine_from_config, pool
+from sqlalchemy import pool
+from sqlalchemy.ext.asyncio import async_engine_from_config
+
+if TYPE_CHECKING:
+    from sqlalchemy import Connection

 # this is the Alembic Config object, which provides
 # access to the values within the .ini file in use.
@@ -68,27 +74,31 @@ def run_migrations_offline() -> None:
         context.run_migrations()


-def run_migrations_online() -> None:
+def apply_migrations(connection: "Connection") -> None:
+    context.configure(connection=connection, target_metadata=target_metadata)
+
+    with context.begin_transaction():
+        context.run_migrations()
+
+
+async def run_migrations_online() -> None:
     """Run migrations in 'online' mode.
     In this scenario we need to create an Engine
     and associate a connection with the context.
     """
-    connectable = engine_from_config(
+    connectable = async_engine_from_config(
         config.get_section(config.config_ini_section),
         prefix="sqlalchemy.",
         poolclass=pool.NullPool,
     )

-    with connectable.connect() as connection:
-        context.configure(connection=connection, target_metadata=target_metadata)
-
-        with context.begin_transaction():
-            context.run_migrations()
+    async with connectable.connect() as connection:
+        await connection.run_sync(apply_migrations)


 if context.is_offline_mode():
     run_migrations_offline()
 else:
-    run_migrations_online()
+    asyncio.run(run_migrations_online())
3 changes: 1 addition & 2 deletions src/argilla/server/apis/v1/handlers/responses.py
@@ -16,10 +16,9 @@

 from fastapi import APIRouter, Depends, HTTPException, Security, status
 from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.orm import Session

 from argilla.server.contexts import datasets
-from argilla.server.database import get_async_db, get_db
+from argilla.server.database import get_async_db
 from argilla.server.models import User
 from argilla.server.policies import ResponsePolicyV1, authorize
 from argilla.server.schemas.v1.responses import Response, ResponseUpdate
21 changes: 5 additions & 16 deletions src/argilla/server/database.py
@@ -16,16 +16,16 @@
 from sqlite3 import Connection as SQLite3Connection
 from typing import TYPE_CHECKING, Generator

-from sqlalchemy import create_engine, event
+from sqlalchemy import event
 from sqlalchemy.engine import Engine
-from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
-from sqlalchemy.orm import sessionmaker
+from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

 import argilla
 from argilla.server.settings import settings

 if TYPE_CHECKING:
-    from sqlalchemy.orm import Session
+    from sqlalchemy.ext.asyncio import AsyncSession
+

 ALEMBIC_CONFIG_FILE = os.path.normpath(os.path.join(os.path.dirname(argilla.__file__), "alembic.ini"))
 TAGGED_REVISIONS = OrderedDict(
@@ -46,21 +46,10 @@ def set_sqlite_pragma(dbapi_connection, connection_record):
         cursor.close()


-engine = create_engine(settings.database_url)
-SessionLocal = sessionmaker(autocommit=False, bind=engine)
-
-async_engine = create_async_engine(settings.database_url_async)
+async_engine = create_async_engine(settings.database_url)
 AsyncSessionLocal = async_sessionmaker(autocommit=False, expire_on_commit=False, bind=async_engine)


-def get_db() -> Generator["Session", None, None]:
-    try:
-        db = SessionLocal()
-        yield db
-    finally:
-        db.close()
-
-
 async def get_async_db() -> Generator["AsyncSession", None, None]:
     try:
         db: "AsyncSession" = AsyncSessionLocal()
62 changes: 0 additions & 62 deletions src/argilla/server/seeds.py

This file was deleted.

46 changes: 30 additions & 16 deletions src/argilla/server/settings.py
@@ -18,6 +18,8 @@
 """
 import logging
 import os
+import re
+import warnings
 from pathlib import Path
 from typing import List, Optional
 from urllib.parse import urlparse
@@ -125,9 +127,34 @@ def normalize_base_url(cls, base_url: str):

         return base_url

-    @validator("database_url", always=True)
-    def set_database_url_default(cls, database_url: str, values: dict) -> str:
-        return database_url or f"sqlite:///{os.path.join(values['home_path'], 'argilla.db')}?check_same_thread=False"
+    @validator("database_url", pre=True, always=True)
+    def set_database_url(cls, database_url: str, values: dict) -> str:
+        if not database_url:
+            home_path = values.get("home_path")
+            sqlite_file = os.path.join(home_path, "argilla.db")
+            return f"sqlite+aiosqlite:///{sqlite_file}?check_same_thread=False"
+
+        if "sqlite" in database_url:
+            regex = re.compile(r"sqlite(?!\+aiosqlite)")
+            if regex.match(database_url):
+                warnings.warn(
+                    "From version 1.14.0, Argilla will use `aiosqlite` as default SQLite driver. The protocol in the"
+                    " provided database URL has been automatically replaced from `sqlite` to `sqlite+aiosqlite`."
+                    " Please, update your database URL to use `sqlite+aiosqlite` protocol."
+                )
+                return re.sub(regex, "sqlite+aiosqlite", database_url)
+
+        if "postgresql" in database_url:
+            regex = re.compile(r"postgresql(?!\+asyncpg)(\+psycopg2)?")
+            if regex.match(database_url):
+                warnings.warn(
+                    "From version 1.14.0, Argilla will use `asyncpg` as default PostgreSQL driver. The protocol in the"
+                    " provided database URL has been automatically replaced from `postgresql` to `postgresql+asyncpg`."
+                    " Please, update your database URL to use `postgresql+asyncpg` protocol."
+                )
+                return re.sub(regex, "postgresql+asyncpg", database_url)
+
+        return database_url

     @root_validator(skip_on_failure=True)
     def create_home_path(cls, values):
@@ -165,19 +192,6 @@ def old_dataset_records_index_name(self) -> str:
             return index_name.replace("<NAMESPACE>", "")
         return index_name.replace("<NAMESPACE>", f".{ns}")

-    @property
-    def database_url_async(self) -> str:
-        if self.database_url.startswith("sqlite:///"):
-            return self.database_url.replace("sqlite:///", "sqlite+aiosqlite:///")
-
-        if self.database_url.startswith("postgresql://"):
-            return self.database_url.replace("postgresql://", "postgresql+asyncpg://")
-
-        if self.database_url.startswith("mysql://"):
-            return self.database_url.replace("mysql://", "mysql+aiomysql://")
-
-        raise ValueError(f"Unsupported database url: '{self.database_url}'")
-
     def obfuscated_elasticsearch(self) -> str:
         """Returns configured elasticsearch url obfuscating the provided password, if any"""
         parsed = urlparse(self.elasticsearch)
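
The scheme rewrite done by the `set_database_url` validator in this file can be tried in isolation. The following is a rough standalone sketch using only the standard library, assuming the same two regular expressions as the diff: `to_async_url` and `ASYNC_SCHEMES` are hypothetical names, the patterns are anchored with `^` for clarity, and the warning text is shortened. It is not the validator itself, which also builds the default SQLite URL from `home_path`.

```python
import re
import warnings

# Sync scheme patterns mapped to their async replacements. The negative
# lookahead leaves URLs that already name the async driver untouched.
ASYNC_SCHEMES = [
    (re.compile(r"^sqlite(?!\+aiosqlite)"), "sqlite+aiosqlite"),
    (re.compile(r"^postgresql(?!\+asyncpg)(\+psycopg2)?"), "postgresql+asyncpg"),
]


def to_async_url(database_url: str) -> str:
    """Hypothetical helper: swap a sync driver scheme for its async one."""
    for pattern, replacement in ASYNC_SCHEMES:
        if pattern.match(database_url):
            warnings.warn(f"Database URL scheme replaced with `{replacement}`.")
            # Only the scheme at the start of the URL is rewritten.
            return pattern.sub(replacement, database_url, count=1)
    # Already async, or a database this sketch does not know about.
    return database_url


for url in (
    "sqlite:///argilla.db",
    "postgresql+psycopg2://user:secret@localhost/argilla",
    "postgresql+asyncpg://user:secret@localhost/argilla",
):
    print(to_async_url(url))
```

URLs that already use `sqlite+aiosqlite` or `postgresql+asyncpg` pass through without a warning, so updating the configured URL is enough to silence it.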
57 changes: 57 additions & 0 deletions src/argilla/tasks/async_typer.py
@@ -0,0 +1,57 @@
+# Copyright 2021-present, the Recognai S.L. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import sys
+from functools import wraps
+from typing import Any, Callable, Coroutine, TypeVar
+
+import typer
+
+if sys.version_info < (3, 10):
+    from typing_extensions import ParamSpec
+else:
+    from typing import ParamSpec
+
+
+P = ParamSpec("P")
+R = TypeVar("R")
+
+
+# https://github.com/tiangolo/typer/issues/88#issuecomment-1613013597
+class AsyncTyper(typer.Typer):
+    def command(
+        self, *args: Any, **kwargs: Any
+    ) -> Callable[[Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]]]:
+        super_command = super().command(*args, **kwargs)
+
+        def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
+            @wraps(func)
+            def sync_func(*_args: P.args, **_kwargs: P.kwargs) -> R:
+                return asyncio.run(func(*_args, **_kwargs))
+
+            if asyncio.iscoroutinefunction(func):
+                super_command(sync_func)
+            else:
+                super_command(func)
+
+            return func
+
+        return decorator
+
+
+def run(function: Callable[..., Coroutine[Any, Any, Any]]) -> None:
+    app = AsyncTyper(add_completion=False)
+    app.command()(function)
+    app()
3 changes: 2 additions & 1 deletion src/argilla/tasks/database/migrate.py
@@ -20,6 +20,7 @@
 from alembic.util import CommandError

 from argilla.server.database import ALEMBIC_CONFIG_FILE, TAGGED_REVISIONS
+from argilla.tasks import async_typer
 from argilla.tasks.database import utils


@@ -47,4 +48,4 @@ def migrate_db(revision: Optional[str] = typer.Option(default="head", help="DB R


 if __name__ == "__main__":
-    typer.run(migrate_db)
+    async_typer.run(migrate_db)
3 changes: 2 additions & 1 deletion src/argilla/tasks/database/revisions.py
@@ -15,6 +15,7 @@
 import typer

 from argilla.server.database import ALEMBIC_CONFIG_FILE, TAGGED_REVISIONS
+from argilla.tasks import async_typer
 from argilla.tasks.database import utils


@@ -39,4 +40,4 @@ def revisions():


 if __name__ == "__main__":
-    typer.run(revisions)
+    async_typer.run(revisions)
1 change: 0 additions & 1 deletion src/argilla/tasks/training/__main__.py
@@ -45,7 +45,6 @@ def train(
 ):
     import json

-    import argilla as rg
     from argilla.client.api import init
     from argilla.training import ArgillaTrainer

4 changes: 2 additions & 2 deletions src/argilla/tasks/users/__main__.py
@@ -12,14 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import typer
+from argilla.tasks.async_typer import AsyncTyper

 from .create import create
 from .create_default import create_default
 from .migrate import migrate
 from .update import update

-app = typer.Typer(help="Holds CLI commands for user and workspace management.", no_args_is_help=True)
+app = AsyncTyper(help="Holds CLI commands for user and workspace management.", no_args_is_help=True)

 app.command(name="create_default", help="Creates default users and workspaces in the Argilla database.")(create_default)
 app.command(name="create", help="Creates a user and add it to the Argilla database.", no_args_is_help=True)(create)