diff --git a/.github/actions/python_build/action.yml b/.github/actions/python_build/action.yml index a1fa761..9466522 100644 --- a/.github/actions/python_build/action.yml +++ b/.github/actions/python_build/action.yml @@ -20,6 +20,10 @@ inputs: description: "Python version" required: true default: "3.10" + sqlalchemy-version: + description: SQLAlchemy version to run the CI checks on + required: true + default: "2.*" tags: description: "Optional dependencies (via available tags) to install, e.g. [cicd]" required: true @@ -36,5 +40,6 @@ runs: - name: Install dependencies run: | python -m pip install --upgrade pip + pip install sqlalchemy==${{ inputs.sqlalchemy-version }} pip install -e .${{ inputs.tags }} shell: bash diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8621569..9f79aa9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -39,3 +39,4 @@ jobs: uses: ./.github/workflows/python_ci.yml with: python-version: "3.10" + sqlalchemy-version: "2.*" diff --git a/.github/workflows/ci_dev.yml b/.github/workflows/ci_dev.yml index 4298a97..243411d 100644 --- a/.github/workflows/ci_dev.yml +++ b/.github/workflows/ci_dev.yml @@ -23,6 +23,10 @@ on: description: Python version required: true type: string + sqlalchemy-version: + description: SQLAlchemy version + required: true + type: string jobs: python_ci: @@ -30,3 +34,4 @@ jobs: uses: ./.github/workflows/python_ci.yml with: python-version: ${{ inputs.python-version }} + sqlalchemy-version: ${{ inputs.sqlalchemy-version }} diff --git a/.github/workflows/python_ci.yml b/.github/workflows/python_ci.yml index 8a95a37..4eaf86a 100644 --- a/.github/workflows/python_ci.yml +++ b/.github/workflows/python_ci.yml @@ -22,6 +22,10 @@ on: description: Python version to run the CI checks on required: true type: string + sqlalchemy-version: + description: SQLAlchemy version to run the CI checks on + required: true + type: string defaults: run: @@ -88,6 +92,7 @@ jobs: - uses: 
./.github/actions/python_build with: python-version: ${{ inputs.python-version }} + sqlalchemy-version: ${{ inputs.sqlalchemy-version }} tags: "[cicd]" - name: Run pytest with coverage diff --git a/.licenserc.yml b/.licenserc.yml index a983930..125d32d 100644 --- a/.licenserc.yml +++ b/.licenserc.yml @@ -18,9 +18,11 @@ header: paths-ignore: - '**/*.md' - - 'tests/**/test_*/*' + - 'tests/**/test_*/**/*' - '.gitignore' - 'LICENSE' - 'NOTICE' + # symlinks + - 'tests/database/test_dbconnection' comment: on-failure diff --git a/README.md b/README.md index dc35dab..b1a153d 100644 --- a/README.md +++ b/README.md @@ -16,3 +16,22 @@ This library is publicly available in [PyPI](https://pypi.org/project/ensembl-ut ```bash pip install ensembl-utils ``` + +### Quick usage + +Besides the standard `import ensembl.utils`, this library also provides some useful command line scripts: +- `extract_file` - to easily extract archive files in different formats + +Furthermore, `ensembl-utils` also has a [`pytest`](https://docs.pytest.org/) plugin with some useful functionalities to ease your unit testing. You can enable it by adding it explicitly when running pytest: +```bash +pytest -p ensembl.utils.plugin ... +``` + +Or adding the following line to your `conftest.py`: +```python +pytest_plugins = ("ensembl.utils.plugin",) +``` + +## Dependencies + +This repository has been developed to support [SQLAlchemy](https://www.sqlalchemy.org) version 1.4 (1.4.45 or later, to ensure "future-compatibility") as well as version 2.0+. diff --git a/docs/index.md b/docs/index.md index fb9eda0..e048c4e 100644 --- a/docs/index.md +++ b/docs/index.md @@ -6,9 +6,10 @@ Centralise generic Python utils used by other project within Ensembl design to f Check out [installation](install.md) section for further information on how to install the project. 1. [Install](install.md) -2. [Code of Conduct](code_of_conduct.md) -3. [Coverage report](coverage.md) -4. [Code reference](reference/) +2. [Usage](usage.md) +3. 
[Code of Conduct](code_of_conduct.md) +4. [Coverage report](coverage.md) +5. [Code reference](reference/) ## License Software as part of [Ensembl Python general-purpose utils](https://github.com/Ensembl/ensembl-utils) is distributed under the [Apache-2.0 License](https://www.apache.org/licenses/LICENSE-2.0.txt). diff --git a/docs/usage.md b/docs/usage.md new file mode 100644 index 0000000..cb03709 --- /dev/null +++ b/docs/usage.md @@ -0,0 +1,30 @@ +# Using these utils + +You can easily take advantage of the provided functionalities by importing this library in your code as usual: +```python +import ensembl.utils +``` + +This library also provides some scripts that can help you via the command line: +- `extract_file` - to easily extract archive files in different formats + +_Note:_ All of them include the `--help` option to provide further information about their purpose and how to use them. + +## `pytest` plugin + +This repository provides a [`pytest`](https://docs.pytest.org/) plugin with some useful functionalities to do unit testing. In particular, there is one fixture to access the test files in a folder with the same name as the test being run (`data_dir`) and a fixture to build and provide unit test databases (`test_dbs`). + +To use these elements you need to enable the plugin once you have installed the repository. There are two main ways to do this: +1. Explicitly indicating it when running `pytest`: + ```bash + pytest -p ensembl.utils.plugin ... + ``` + +2. Adding the following line to your `conftest.py` file at the root of where the unit tests are located: + ```python + pytest_plugins = ("ensembl.utils.plugin",) + ``` + +## Dependencies + +This repository has been developed to support [SQLAlchemy](https://www.sqlalchemy.org) version 1.4 (1.4.45 or later, to ensure "future-compatibility") as well as version 2.0+. 
diff --git a/mkdocs.yml b/mkdocs.yml index 19a213f..e674f45 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -74,6 +74,7 @@ nav: - Home: - Overview: index.md - Install: install.md + - Usage: usage.md - Development: - Code of Conduct: code_of_conduct.md - Coverage report: coverage.md diff --git a/pyproject.toml b/pyproject.toml index 3be4a6b..ecccb11 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ keywords = [ classifiers = [ "Development Status :: 4 - Beta", "Environment :: Console", + "Framework :: Pytest", "Intended Audience :: Science/Research", "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", @@ -45,10 +46,12 @@ classifiers = [ "Topic :: Software Development :: Libraries :: Python Modules", ] dependencies = [ + "pytest >= 8.2.0", "python-dotenv >= 0.19.2", "pyyaml ~= 6.0", "requests >= 2.22.0", - "sqlalchemy >= 1.4.0", + "sqlalchemy >= 1.4.45", + "sqlalchemy_utils >= 0.41.2", ] [project.optional-dependencies] @@ -64,9 +67,7 @@ cicd = [ "genbadge[coverage]", "mypy", "pylint", - "pytest", "pytest-dependency", - "requests-mock >= 1.8.0", "types-pyyaml", "types-requests", ] @@ -86,7 +87,8 @@ docs = [ [project.urls] Homepage = "https://www.ensembl.org" Documentation = "https://ensembl.github.io/ensembl-utils/" -Repository = "https://github.com/Ensembl/ensembl-utils" +Repository = "https://github.com/Ensembl/ensembl-utils.git" +Issues = "https://github.com/Ensembl/ensembl-utils/issues" [project.scripts] extract_file = "ensembl.utils.archive:extract_file_cli" diff --git a/src/ensembl/utils/__init__.py b/src/ensembl/utils/__init__.py index 68a3741..5eea5e8 100644 --- a/src/ensembl/utils/__init__.py +++ b/src/ensembl/utils/__init__.py @@ -14,14 +14,14 @@ # limitations under the License. 
"""Ensembl Python general-purpose utils library.""" -__version__ = "0.2.0" +__version__ = "0.3.0" __all__ = [ "StrPath", ] import os -from typing import Union +from typing import TypeVar -StrPath = Union[str, os.PathLike] +StrPath = TypeVar("StrPath", str, os.PathLike) diff --git a/src/ensembl/utils/archive.py b/src/ensembl/utils/archive.py index 677d5ab..61ce9ba 100644 --- a/src/ensembl/utils/archive.py +++ b/src/ensembl/utils/archive.py @@ -14,6 +14,8 @@ # limitations under the License. """Utils for common IO operations over archive files, e.g. tar or gzip.""" +from __future__ import annotations + __all__ = [ "SUPPORTED_ARCHIVE_FORMATS", "open_gz_file", diff --git a/src/ensembl/utils/argparse.py b/src/ensembl/utils/argparse.py index 3959eab..19507b3 100644 --- a/src/ensembl/utils/argparse.py +++ b/src/ensembl/utils/argparse.py @@ -28,6 +28,8 @@ """ +from __future__ import annotations + __all__ = [ "ArgumentParser", ] diff --git a/src/ensembl/utils/database/__init__.py b/src/ensembl/utils/database/__init__.py new file mode 100644 index 0000000..1493814 --- /dev/null +++ b/src/ensembl/utils/database/__init__.py @@ -0,0 +1,20 @@ +# See the NOTICE file distributed with this work for additional information +# regarding copyright ownership. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Database module.""" + +from __future__ import annotations + +from .dbconnection import * +from .unittestdb import * diff --git a/src/ensembl/utils/database/dbconnection.py b/src/ensembl/utils/database/dbconnection.py new file mode 100644 index 0000000..bbc2e97 --- /dev/null +++ b/src/ensembl/utils/database/dbconnection.py @@ -0,0 +1,231 @@ +# See the NOTICE file distributed with this work for additional information +# regarding copyright ownership. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Database connection handler. + +This module provides the main class to connect to and access databases. It will be an ORM-less +connection, that is, the data can only be accessed via SQL queries (see example below). 
class DBConnection:
    """Database connection handler, providing also the database's schema and properties.

    The connection is ORM-less by default: data is accessed through SQL queries via `execute()`,
    `connect()` or `begin()`. ORM sessions are available through `session_scope()` and
    `test_session_scope()`, each of which works on its own dedicated engine.

    Args:
        url: URL to the database, e.g. `mysql://user:passwd@host:port/my_db`.
        **kwargs: Additional keyword arguments forwarded to `sqlalchemy.create_engine()`.

    """

    def __init__(self, url: StrURL, **kwargs) -> None:
        self._engine = create_engine(url, future=True, **kwargs)
        self.load_metadata()

    def __repr__(self) -> str:
        """Returns a string representation of this object."""
        return f"{self.__class__.__name__}({self.url!r})"

    def load_metadata(self) -> None:
        """Loads the metadata information of the database."""
        # Note: Just reflect() is not enough as it would not delete tables that no longer exist
        self._metadata = sqlalchemy.MetaData()
        self._metadata.reflect(bind=self._engine)

    @property
    def url(self) -> str:
        """Returns the database URL (the password, if any, is not masked)."""
        return self._engine.url.render_as_string(hide_password=False)

    @property
    def db_name(self) -> Optional[str]:
        """Returns the database name."""
        return self._engine.url.database

    @property
    def host(self) -> Optional[str]:
        """Returns the database host."""
        return self._engine.url.host

    @property
    def port(self) -> Optional[int]:
        """Returns the port of the database host."""
        return self._engine.url.port

    @property
    def dialect(self) -> str:
        """Returns the SQLAlchemy database dialect name of the database host."""
        return self._engine.name

    @property
    def tables(self) -> dict[str, sqlalchemy.schema.Table]:
        """Returns the database tables keyed to their name."""
        return self._metadata.tables

    def get_primary_key_columns(self, table: str) -> list[str]:
        """Returns the primary key column names for the given table.

        Args:
            table: Table name.

        Raises:
            KeyError: If `table` is not among the loaded tables.

        """
        return [col.name for col in self.tables[table].primary_key.columns.values()]

    def get_columns(self, table: str) -> list[str]:
        """Returns the column names for the given table.

        Args:
            table: Table name.

        Raises:
            KeyError: If `table` is not among the loaded tables.

        """
        return [col.name for col in self.tables[table].columns]

    def connect(self) -> sqlalchemy.engine.Connection:
        """Returns a new database connection."""
        return self._engine.connect()

    def begin(self, *args) -> ContextManager[sqlalchemy.engine.Connection]:
        """Returns a context manager delivering a database connection with a transaction established.

        Args:
            *args: Positional arguments forwarded to `sqlalchemy.engine.Engine.begin()`.

        """
        return self._engine.begin(*args)

    def dispose(self) -> None:
        """Disposes of the connection pool."""
        self._engine.dispose()

    def execute(self, statement: Query, parameters=None, execution_options=None) -> sqlalchemy.engine.Result:
        """Executes the given SQL query and returns its result.

        See `sqlalchemy.engine.Connection.execute()` method for more information about the
        additional arguments.

        The statement runs on a new connection obtained from `connect()`. This method does not
        close that connection, so that the returned result can still be consumed by the caller.

        Args:
            statement: SQL query to execute. Plain strings are wrapped with `sqlalchemy.text()`.
            parameters: Parameters which will be bound into the statement.
            execution_options: Optional dictionary of execution options, which will be associated
                with the statement execution.

        """
        if isinstance(statement, str):
            statement = text(statement)  # type: ignore[assignment]
        return self.connect().execute(
            statement=statement, parameters=parameters, execution_options=execution_options
        )  # type: ignore[call-overload]

    def _enable_sqlite_savepoints(self, engine: sqlalchemy.engine.Engine) -> None:
        """Enables SQLite SAVEPOINTS to allow session rollbacks.

        Args:
            engine: Engine on which the "connect" and "begin" event listeners are registered.

        """

        @event.listens_for(engine, "connect")
        def do_connect(dbapi_connection, connection_record):  # pylint: disable=unused-argument
            """Disables emitting the BEGIN statement entirely, as well as COMMIT before any DDL."""
            dbapi_connection.isolation_level = None

        @event.listens_for(engine, "begin")
        def do_begin(conn):
            """Emits our own BEGIN."""
            conn.exec_driver_sql("BEGIN")

    @contextmanager
    def session_scope(self) -> Generator[sqlalchemy.orm.session.Session, None, None]:
        """Provides a transactional scope around a series of operations with rollback in case of failure.

        The session is committed when the scope exits normally and rolled back if an exception is
        raised inside it (the exception is re-raised).

        Bear in mind MySQL's storage engine MyISAM does not support rollback transactions, so all
        the modifications performed to the database will persist.

        """
        # Create a dedicated engine for this session
        engine = create_engine(self._engine.url)
        if self.dialect == "sqlite":
            self._enable_sqlite_savepoints(engine)
        Session = sessionmaker(future=True)
        session = Session(bind=engine, autoflush=False)
        try:
            yield session
            session.commit()
        except BaseException:
            # Rollback to ensure no changes are made to the database
            session.rollback()
            raise
        finally:
            # Whatever happens, make sure the session is closed
            session.close()
            # The engine only exists for this scope: release its connection pool as well,
            # otherwise every call would leave pooled connections behind
            engine.dispose()

    @contextmanager
    def test_session_scope(self) -> Generator[sqlalchemy.orm.session.Session, None, None]:
        """Provides a transactional scope around a series of operations that will be rolled back at the end.

        Bear in mind MySQL's storage engine MyISAM does not support rollback transactions, so all
        the modifications performed to the database will persist.

        """
        # Create a dedicated engine for this session
        engine = create_engine(self._engine.url)
        if self.dialect == "sqlite":
            self._enable_sqlite_savepoints(engine)
        # Connect to the database
        connection = engine.connect()
        # Begin a non-ORM transaction
        transaction = connection.begin()
        # Bind an individual session to the connection
        Session = sessionmaker(future=True)
        try:
            # Running on SQLAlchemy 2.0+
            session = Session(bind=connection, join_transaction_mode="create_savepoint")
        except TypeError:
            # Running on SQLAlchemy 1.4 ("join_transaction_mode" is not an accepted argument)
            session = Session(bind=connection)
            # If the database supports SAVEPOINT, starting a savepoint will allow to also use rollback
            connection.begin_nested()

            # Define a new transaction event that reopens the savepoint whenever it ends
            @event.listens_for(session, "after_transaction_end")
            def end_savepoint(session, transaction):  # pylint: disable=unused-argument
                if not connection.in_nested_transaction():
                    connection.begin_nested()

        try:
            yield session
        finally:
            # Whatever happens, make sure the session and connection are closed, rolling back
            # everything done with the session (including calls to commit())
            session.close()
            transaction.rollback()
            connection.close()
            # The engine only exists for this scope: release its connection pool as well
            engine.dispose()
class UnitTestDB:
    """Creates and connects to a new test database, applying the schema and importing the data.

    If a database with the same name already exists on the server, it is dropped and created anew.

    Args:
        server_url: URL of the server hosting the database.
        dump_dir: Directory path with the database schema in `table.sql` (mandatory) and one TSV data
            file (without headers) per table, named after the table with a `.txt` extension (optional).
        name: Name to give to the new database. If not provided, the last directory name of `dump_dir`
            will be used instead. In either case, the new database name will be prefixed by the username.

    Attributes:
        dbc: Database connection handler.

    Raises:
        FileNotFoundError: If `table.sql` is not found.

    """

    def __init__(self, server_url: StrURL, dump_dir: StrPath, name: Optional[str] = None) -> None:
        import getpass  # pylint: disable=import-outside-toplevel

        db_url = make_url(server_url)
        dump_dir_path = Path(dump_dir)
        # "USER" is not guaranteed to be set (e.g. containers, Windows): fall back to the login name
        username = os.environ.get("USER") or getpass.getuser()
        db_name = username + "_" + (name if name else dump_dir_path.name)
        dialect_name = db_url.get_dialect().name
        # Add the database name to the URL
        if dialect_name == "sqlite":
            db_url = db_url.set(database=f"{db_name}.db")
        else:
            db_url = db_url.set(database=db_name)
        # Enable "local_infile" variable for MySQL databases to allow importing data from files
        connect_args = {}
        if dialect_name == "mysql":
            connect_args["local_infile"] = 1
        # Create the database, dropping it beforehand if it already exists
        if database_exists(db_url):
            drop_database(db_url)
        create_database(db_url)
        # Establish the connection to the database, load the schema and import the data
        try:
            self.dbc = DBConnection(db_url, connect_args=connect_args)
            with self.dbc.begin() as conn:
                # Set InnoDB engine as default and disable foreign key checks for MySQL databases
                if self.dbc.dialect == "mysql":
                    conn.execute(text("SET default_storage_engine=InnoDB"))
                    conn.execute(text("SET FOREIGN_KEY_CHECKS=0"))
                # Load the schema, one statement at a time (statements are separated by ";")
                with open(dump_dir_path / "table.sql", "r") as schema:
                    for query in schema.read().split(";"):
                        if query.strip():
                            conn.execute(text(query))
                # And import any available data for each table (the file stem is the table name)
                for tsv_file in dump_dir_path.glob("*.txt"):
                    table = tsv_file.stem
                    self._load_data(conn, table, tsv_file)
                # Re-enable foreign key checks for MySQL databases
                if self.dbc.dialect == "mysql":
                    conn.execute(text("SET FOREIGN_KEY_CHECKS=1"))
        except BaseException:
            # Make sure the database is deleted before raising the exception
            drop_database(db_url)
            raise
        # Update the loaded metadata information of the database
        self.dbc.load_metadata()

    def __repr__(self) -> str:
        """Returns a string representation of this object."""
        return f"{self.__class__.__name__}({self.dbc.url!r})"

    def drop(self) -> None:
        """Drops the database."""
        drop_database(self.dbc.url)
        # Ensure the connection pool is properly closed and disposed
        self.dbc.dispose()

    def _load_data(self, conn: sqlalchemy.engine.Connection, table: str, src: StrPath) -> None:
        """Loads the table data from the given file.

        The table name and file path are interpolated verbatim into the statement (or `sqlite3`
        command), so they must come from a trusted source.

        Args:
            conn: Open connection to the database.
            table: Table name to load the data to.
            src: File path with the data in TSV format (without headers).

        Raises:
            subprocess.CalledProcessError: If the `sqlite3` command fails (SQLite databases only).

        """
        dialect = self.dbc.dialect
        if dialect == "sqlite":
            # SQLite does not have an equivalent to "LOAD DATA": use its ".import" command instead
            subprocess.run(["sqlite3", self.dbc.db_name, ".mode tabs", f".import {src} {table}"], check=True)
        elif dialect == "postgresql":
            conn.execute(text(f"COPY {table} FROM '{src}'"))
        elif dialect == "mssql":
            # "mssql" is SQLAlchemy's dialect name for Microsoft SQL Server
            conn.execute(text(f"BULK INSERT {table} FROM '{src}'"))
        else:
            conn.execute(text(f"LOAD DATA LOCAL INFILE '{src}' INTO TABLE {table}"))
def pytest_addoption(parser: Parser) -> None:
    """Registers argparse-style options for Ensembl's unit testing.

    This is an implementation of pytest's `pytest_addoption` initialisation hook.

    Args:
        parser: Parser for command line arguments and ini-file values.

    """
    # Add the Ensembl unit test parameters to pytest parser
    group = parser.getgroup("Ensembl unit testing")
    group.addoption(
        "--server",
        action="store",
        metavar="URL",
        dest="server",
        required=False,
        # The "DB_HOST" environment variable, when set, overrides the SQLite default
        default=os.getenv("DB_HOST", "sqlite:///"),
        help="Server URL where to create the test database(s)",
    )
    group.addoption(
        "--keep-dbs",
        action="store_true",
        dest="keep_dbs",
        required=False,
        help="Do not remove the test databases (default: False)",
    )


def pytest_report_header(config: Config) -> str:
    """Presents extra information in the report header.

    Args:
        config: Access to configuration values, pluginmanager and plugin hooks.

    Returns:
        Line with the server URL, where the password (if any) has been masked.

    """
    # Show server information, masking the password value
    server = config.getoption("server")
    server = re.sub(r"(//[^/]+:).*(@)", r"\1xxxxxx\2", server)
    return f"server: {server}"


@pytest.fixture(name="data_dir", scope="module")
def local_data_dir(request: FixtureRequest) -> Path:
    """Returns the path to the test data folder matching the test's name.

    The path is the test module's file path without its suffix, e.g. `tests/test_foo` for
    `tests/test_foo.py`.

    Args:
        request: Fixture that provides information of the requesting test function.

    """
    return Path(request.module.__file__).with_suffix("")


@pytest.fixture(name="assert_files")
def fixture_assert_files() -> Callable[[StrPath, StrPath], None]:
    """Returns a function that asserts if two text files are equal, or prints their differences."""

    def _assert_files(result_path: StrPath, expected_path: StrPath) -> None:
        """Asserts if two files are equal, or prints their differences.

        Args:
            result_path: Path to results (test-made) file.
            expected_path: Path to expected file.

        Raises:
            AssertionError: If the files differ; the message includes their unified diff.

        """
        with open(result_path, "r") as result_fh:
            results = result_fh.readlines()
        with open(expected_path, "r") as expected_fh:
            expected = expected_fh.readlines()
        files_diff = list(
            unified_diff(
                results,
                expected,
                fromfile=f"Test-made file {Path(result_path).name}",
                tofile=f"Expected file {Path(expected_path).name}",
            )
        )
        assert_message = f"Test-made and expected files differ\n{' '.join(files_diff)}"
        assert len(files_diff) == 0, assert_message

    return _assert_files


@pytest.fixture(name="db_factory", scope="module")
def fixture_db_factory(request: FixtureRequest, data_dir: Path) -> Generator[Callable, None, None]:
    """Yields a unit test database factory.

    Every database created through the factory is dropped at fixture teardown, unless the
    `--keep-dbs` option has been provided.

    Args:
        request: Fixture that provides information of the requesting test function.
        data_dir: Fixture that provides the path to the test data folder matching the test's name.

    """
    created: dict[str, UnitTestDB] = {}
    server_url = request.config.getoption("server")

    def _db_factory(src: StrPath, name: Optional[str] = None) -> UnitTestDB:
        """Returns a unit test database, creating it only the first time it is requested.

        Args:
            src: Directory path where the test database schema and content files are located.
                Relative paths are resolved against `data_dir`.
            name: Name to give to the new database. See `UnitTestDB` for more information.

        """
        src_path = Path(src)
        if not src_path.is_absolute():
            src_path = data_dir / src_path
        db_key = name if name else src_path.name
        # Build the database only on a cache miss: `UnitTestDB()` drops and recreates the database,
        # so constructing it eagerly (e.g. as the default argument of `dict.setdefault()`) would
        # wipe the database behind the cached object on every repeated request
        if db_key not in created:
            created[db_key] = UnitTestDB(server_url, src_path, name)
        return created[db_key]

    yield _db_factory
    # Drop all unit test databases unless the user has requested to keep them
    if not request.config.getoption("keep_dbs"):
        for test_db in created.values():
            test_db.drop()


@pytest.fixture(scope="module")
def test_dbs(request: FixtureRequest, db_factory: Callable) -> dict[str, UnitTestDB]:
    """Returns a dictionary of unit test databases with the database name as key.

    Requires a list of dictionaries, each with keys `src` (mandatory) and `name` (optional), passed via
    `request.param`. See `db_factory()` for details about each key's value. This fixture is a wrapper of
    `db_factory()` intended to be used via indirect parametrization, for example::

        @pytest.mark.parametrize(
            "test_dbs", [[{"src": "master"}, {"src": "master", "name": "master2"}]], indirect=True
        )
        def test_method(..., test_dbs: dict[str, UnitTestDB], ...):
            ...

    Args:
        request: Fixture that provides information of the requesting test function.
        db_factory: Fixture that provides a unit test database factory.

    """
    databases = {}
    for argument in request.param:
        src = Path(argument["src"])
        name = argument.get("name", None)
        # Same key convention as `db_factory()`: the given name or, by default, the source folder name
        key = name if name else src.name
        databases[key] = db_factory(src, name)
    return databases
"""Allow to seamlessly load / read the content of a remote file as if it was located locally.""" +from __future__ import annotations + __all__ = ["RemoteFileLoader"] import configparser diff --git a/tests/conftest.py b/tests/conftest.py index 2a86cdd..3a5f004 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,18 +14,4 @@ # limitations under the License. """Local directory-specific plugin imports, and hook and fixture implementations.""" -from pathlib import Path - -import pytest -from pytest import FixtureRequest - - -@pytest.fixture(name="data_dir", scope="module") -def local_data_dir(request: FixtureRequest) -> Path: - """Returns the path to the test data folder matching the test's name. - - Args: - request: Fixture providing information of the requesting test function. - - """ - return Path(request.module.__file__).with_suffix("") +pytest_plugins = ("ensembl.utils.plugin",) diff --git a/tests/database/test_dbconnection b/tests/database/test_dbconnection new file mode 120000 index 0000000..ec5a50f --- /dev/null +++ b/tests/database/test_dbconnection @@ -0,0 +1 @@ +test_unittestdb \ No newline at end of file diff --git a/tests/database/test_dbconnection.py b/tests/database/test_dbconnection.py new file mode 100644 index 0000000..c1bb52a --- /dev/null +++ b/tests/database/test_dbconnection.py @@ -0,0 +1,234 @@ +# See the NOTICE file distributed with this work for additional information +# regarding copyright ownership. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Unit testing of `ensembl.utils.database.dbconnection` module."""
+
+from contextlib import nullcontext as does_not_raise
+import os
+from typing import ContextManager
+
+import pytest
+from pytest import FixtureRequest, param, raises
+from sqlalchemy import text
+from sqlalchemy.engine.url import make_url
+from sqlalchemy.exc import IntegrityError, SQLAlchemyError
+from sqlalchemy.ext.automap import automap_base
+
+from ensembl.utils.database import DBConnection, Query, UnitTestDB
+
+
+@pytest.mark.parametrize("test_dbs", [[{"src": "mock_db"}]], indirect=True)
+class TestDBConnection:
+    """Tests `DBConnection` class.
+
+    Attributes:
+        dbc: Test database connection.
+        server: Server URL where the test database is hosted.
+
+    """
+
+    dbc: DBConnection = None
+    server: str = ""
+
+    @pytest.fixture(scope="class", autouse=True)
+    @classmethod
+    def setup(cls, request: FixtureRequest, test_dbs: dict[str, UnitTestDB]) -> None:
+        """Loads the required fixtures and values as class attributes.
+
+        Args:
+            request: Fixture that provides information of the requesting test function.
+            test_dbs: Fixture that provides the unit test databases.
+
+        """
+        cls.dbc = test_dbs["mock_db"].dbc
+        cls.server = request.config.getoption("server")
+
+    @pytest.mark.dependency(name="test_init", scope="class")
+    def test_init(self) -> None:
+        """Tests that the object `DBConnection` is initialised correctly."""
+        assert self.dbc, "DBConnection object should not be empty"
+
+    @pytest.mark.dependency(name="test_dialect", depends=["test_init"], scope="class")
+    def test_dialect(self) -> None:
+        """Tests `DBConnection.dialect` property."""
+        assert self.dbc.dialect == make_url(self.server).drivername
+
+    @pytest.mark.dependency(name="test_db_name", depends=["test_init", "test_dialect"], scope="class")
+    def test_db_name(self) -> None:
+        """Tests `DBConnection.db_name` property."""
+        expected_db_name = f"{os.environ['USER']}_mock_db"
+        if self.dbc.dialect == "sqlite":
+            expected_db_name += ".db"
+        assert self.dbc.db_name == expected_db_name
+
+    @pytest.mark.dependency(depends=["test_init", "test_dialect", "test_db_name"], scope="class")
+    def test_url(self) -> None:
+        """Tests `DBConnection.url` property."""
+        expected_url = make_url(self.server).set(database=self.dbc.db_name)
+        assert self.dbc.url == expected_url.render_as_string(hide_password=False)
+
+    @pytest.mark.dependency(depends=["test_init"], scope="class")
+    def test_host(self) -> None:
+        """Tests `DBConnection.host` property."""
+        assert self.dbc.host == make_url(self.server).host
+
+    @pytest.mark.dependency(depends=["test_init"], scope="class")
+    def test_port(self) -> None:
+        """Tests `DBConnection.port` property."""
+        assert self.dbc.port == make_url(self.server).port
+
+    @pytest.mark.dependency(depends=["test_init"], scope="class")
+    def test_tables(self) -> None:
+        """Tests `DBConnection.tables` property."""
+        assert set(self.dbc.tables.keys()) == {"gibberish"}
+
+    @pytest.mark.dependency(depends=["test_init"], scope="class")
+    def test_get_primary_key_columns(self) -> None:
+        """Tests `DBConnection.get_primary_key_columns()` method."""
+        table = "gibberish"
+        assert set(self.dbc.get_primary_key_columns(table)) == {
+            "id",
+            "grp",
+        }, f"Unexpected set of primary key columns found in table '{table}'"
+
+    @pytest.mark.dependency(depends=["test_init"], scope="class")
+    def test_get_columns(self) -> None:
+        """Tests `DBConnection.get_columns()` method."""
+        table = "gibberish"
+        assert set(self.dbc.get_columns(table)) == {
+            "id",
+            "grp",
+            "value",
+        }, f"Unexpected set of columns found in table '{table}'"
+
+    @pytest.mark.dependency(name="test_connect", depends=["test_init"], scope="class")
+    def test_connect(self) -> None:
+        """Tests `DBConnection.connect()` method."""
+        # The context manager closes the connection even if one of the assertions fails
+        with self.dbc.connect() as connection:
+            assert connection, "Connection object should not be empty"
+            result = connection.execute(text("SELECT * FROM gibberish"))
+            assert len(result.fetchall()) == 6, "Unexpected number of rows found in 'gibberish' table"
+
+    @pytest.mark.dependency(depends=["test_init"], scope="class")
+    def test_begin(self) -> None:
+        """Tests `DBConnection.begin()` method."""
+        with self.dbc.begin() as connection:
+            assert connection, "Connection object should not be empty"
+            result = connection.execute(text("SELECT * FROM gibberish"))
+            assert len(result.fetchall()) == 6, "Unexpected number of rows found in 'gibberish' table"
+
+    @pytest.mark.dependency(depends=["test_init"], scope="class")
+    def test_dispose(self) -> None:
+        """Tests `DBConnection.dispose()` method."""
+        self.dbc.dispose()
+        # SQLAlchemy uses a "pool-less" connection system for SQLite
+        if self.dbc.dialect != "sqlite":
+            num_conn = self.dbc._engine.pool.checkedin()  # pylint: disable=protected-access
+            assert num_conn == 0, "A new pool should have 0 checked-in connections"
+
+    @pytest.mark.dependency(name="test_exec", depends=["test_init"], scope="class")
+    @pytest.mark.parametrize(
+        "query, nrows, expectation",
+        [
+            param("SELECT * FROM gibberish", 6, does_not_raise(), id="Valid string query"),
+            param(text("SELECT * FROM gibberish"), 6, does_not_raise(), id="Valid text query"),
+            param(
+                "SELECT * FROM my_table",
+                0,
+                raises(SQLAlchemyError, match=r"(my_table.* doesn't exist|no such table: my_table)"),
+                id="Querying an unexistent table",
+            ),
+        ],
+    )
+    def test_execute(self, query: Query, nrows: int, expectation: ContextManager) -> None:
+        """Tests `DBConnection.execute()` method.
+
+        Args:
+            query: SQL query.
+            nrows: Number of rows expected to be returned from the query.
+            expectation: Context manager for the expected exception.
+
+        """
+        with expectation:
+            result = self.dbc.execute(query)
+            assert len(result.fetchall()) == nrows, "Unexpected number of rows returned"
+
+    @pytest.mark.dependency(depends=["test_init", "test_connect", "test_exec"], scope="class")
+    @pytest.mark.parametrize(
+        "identifier, rows_to_add, before, after",
+        [
+            param(7, [{"grp": "grp4", "value": 1}, {"grp": "grp5", "value": 1}], 0, 2, id="Add new data"),
+            param(
+                7, [{"grp": "grp6", "value": 1}, {"grp": "grp6", "value": 2}], 2, 2, id="Add existing data"
+            ),
+        ],
+    )
+    def test_session_scope(
+        self, identifier: int, rows_to_add: list[dict[str, object]], before: int, after: int
+    ) -> None:
+        """Tests `DBConnection.session_scope()` method.
+
+        Bear in mind that the second parameterization of this test will fail if the dialect/table engine
+        does not support rollback transactions.
+
+        Args:
+            identifier: ID of the rows to add.
+            rows_to_add: Rows to add to the `gibberish` table.
+            before: Number of rows in `gibberish` table for `id` before adding the rows.
+            after: Number of rows in `gibberish` table for `id` after adding the rows.
+
+        """
+        query = f"SELECT * FROM gibberish WHERE id = {identifier}"
+        results = self.dbc.execute(query)
+        assert len(results.fetchall()) == before
+        Base = automap_base()  # Session requires mapped classes to interact with the database
+        with self.dbc.connect() as connection:  # closed on exit so the connection is not leaked
+            Base.prepare(autoload_with=connection)
+        Gibberish = Base.classes.gibberish
+        # Ignore IntegrityError raised when committing the new tags as some parametrizations will force it
+        try:
+            with self.dbc.session_scope() as session:
+                rows = [Gibberish(id=identifier, **x) for x in rows_to_add]
+                session.add_all(rows)
+        except IntegrityError:
+            pass
+        results = self.dbc.execute(query)
+        assert len(results.fetchall()) == after
+
+    @pytest.mark.dependency(depends=["test_init", "test_connect", "test_exec"], scope="class")
+    def test_test_session_scope(self) -> None:
+        """Tests `DBConnection.test_session_scope()` method."""
+        Base = automap_base()  # Session requires mapped classes to interact with the database
+        with self.dbc.connect() as connection:  # closed on exit so the connection is not leaked
+            Base.prepare(autoload_with=connection)
+        Gibberish = Base.classes.gibberish
+        # Check that the tags added during the context manager are removed afterwards
+        identifier = 8
+        with self.dbc.test_session_scope() as session:
+            results = session.query(Gibberish).filter_by(id=identifier)
+            assert not results.all(), f"ID {identifier} should not have any entries"
+            session.add(Gibberish(id=identifier, grp="grp7", value=15))
+            session.add(Gibberish(id=identifier, grp="grp8", value=25))
+            session.commit()
+            results = session.query(Gibberish).filter_by(id=identifier)
+            assert len(results.all()) == 2, f"ID {identifier} should have two rows"
+        results = self.dbc.execute(f"SELECT * FROM gibberish WHERE id = {identifier}")
+        if (
+            self.dbc.dialect == "mysql"
+            and self.dbc.tables["gibberish"].dialect_options["mysql"]["engine"] == "MyISAM"
+        ):
+            assert len(results.all()) == 2, f"SQLite/MyISAM: 2 rows permanently added to ID {identifier}"
+        else:
+            assert not results.fetchall(), f"No entries should have been permanently added to ID {identifier}"
diff --git a/tests/database/test_unittestdb.py b/tests/database/test_unittestdb.py
new file mode 100644
index 0000000..b8b5883
--- /dev/null
+++ b/tests/database/test_unittestdb.py
@@ -0,0 +1,102 @@
+# See the NOTICE file distributed with this work for additional information
+# regarding copyright ownership.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Unit testing of `ensembl.utils.database.unittestdb` module."""
+
+from contextlib import nullcontext as does_not_raise
+from pathlib import Path
+from typing import ContextManager, Optional
+
+import pytest
+from pytest import FixtureRequest, param, raises
+from sqlalchemy_utils.functions import database_exists
+
+from ensembl.utils.database import UnitTestDB
+
+
+class TestUnitTestDB:
+    """Tests `UnitTestDB` class.
+
+    Attributes:
+        dbs: Dictionary of `UnitTestDB` objects with the database name as key.
+
+    """
+
+    dbs: dict[str, UnitTestDB] = {}
+
+    @pytest.mark.dependency(name="test_init", scope="class")
+    @pytest.mark.parametrize(
+        "src, name, expectation",
+        [
+            param(Path("mock_db"), None, does_not_raise(), id="Default test database creation"),
+            param(Path("mock_db"), "renamed_db", does_not_raise(), id="Rename test database"),
+            # Absolute path to this module's data folder, so the `src.is_absolute()` branch is exercised
+            param(
+                Path(__file__).resolve().with_suffix("") / "mock_db",
+                None,
+                does_not_raise(),
+                id="Re-create mock db with absolute path",
+            ),
+            param(Path("mock_dir"), None, raises(FileNotFoundError), id="Wrong dump folder"),
+        ],
+    )
+    def test_init(
+        self,
+        request: FixtureRequest,
+        data_dir: Path,
+        src: Path,
+        name: Optional[str],
+        expectation: ContextManager,
+    ) -> None:
+        """Tests that the object `UnitTestDB` is initialised correctly.
+
+        Args:
+            request: Fixture that provides information of the requesting test function.
+            data_dir: Fixture that provides the path to the test data folder matching the test's name.
+            src: Directory path with the database schema and one TSV data file per table.
+            name: Name to give to the new database.
+            expectation: Context manager for the expected exception.
+
+        """
+        with expectation:
+            server_url = request.config.getoption("server")
+            src_path = src if src.is_absolute() else data_dir / src
+            db_key = name if name else src.name
+            self.dbs[db_key] = UnitTestDB(server_url, src_path, name)
+            # Check that the database has been created correctly
+            assert self.dbs[db_key], "UnitTestDB should not be empty"
+            assert self.dbs[db_key].dbc, "UnitTestDB's database connection should not be empty"
+            # Check that the database has been loaded correctly from the dump files
+            result = self.dbs[db_key].dbc.execute("SELECT * FROM gibberish")
+            assert len(result.fetchall()) == 6, "Unexpected number of rows found in 'gibberish' table"
+
+    @pytest.mark.dependency(depends=["test_init"], scope="class")
+    @pytest.mark.parametrize(
+        "db_key",
+        [
+            param("mock_db"),
+            param("renamed_db"),
+        ],
+    )
+    def test_drop(self, db_key: str) -> None:
+        """Tests the `UnitTestDB.drop()` method.
+
+        Args:
+            db_key: Key assigned to the UnitTestDB created in `TestUnitTestDB.test_init()`.
+
+        """
+        db_url = self.dbs[db_key].dbc.url
+        assert database_exists(db_url)
+        self.dbs[db_key].drop()
+        assert not database_exists(db_url)
diff --git a/tests/database/test_unittestdb/mock_db/gibberish.txt b/tests/database/test_unittestdb/mock_db/gibberish.txt
new file mode 100644
index 0000000..cbc82f3
--- /dev/null
+++ b/tests/database/test_unittestdb/mock_db/gibberish.txt
@@ -0,0 +1,6 @@
+1	grp1	11
+2	grp1	12
+3	grp2	21
+4	grp2	22
+5	grp2	23
+6	grp3	31
diff --git a/tests/database/test_unittestdb/mock_db/table.sql b/tests/database/test_unittestdb/mock_db/table.sql
new file mode 100644
index 0000000..7b1e856
--- /dev/null
+++ b/tests/database/test_unittestdb/mock_db/table.sql
@@ -0,0 +1,7 @@
+CREATE TABLE `gibberish` (
+  `id` INTEGER NOT NULL,
+  `grp` VARCHAR(20) DEFAULT "",
+  `value` INT DEFAULT NULL,
+  PRIMARY KEY (`id`, `grp`)
+);
+CREATE INDEX `id_idx` ON `gibberish` (`id`);
diff --git a/tests/plugin/test_plugin.py b/tests/plugin/test_plugin.py
new file mode 100644
index 0000000..8e4d71e
--- /dev/null
+++ b/tests/plugin/test_plugin.py
@@ -0,0 +1,106 @@
+# See the NOTICE file distributed with this work for additional information
+# regarding copyright ownership.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Unit testing of `ensembl.utils.plugin` module.
+
+Since certain elements are embedded within pytest itself, only the fixtures are unit tested in this case.
+"""
+
+from contextlib import nullcontext as does_not_raise
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Callable, ContextManager
+from unittest.mock import patch
+
+import pytest
+from pytest import FixtureRequest, param, raises
+
+from ensembl.utils import StrPath
+from ensembl.utils.database import StrURL
+
+
+@dataclass
+class MockTestDB:
+    """Stand-in for `UnitTestDB` that only records the three arguments it is built with."""
+
+    server_url: StrURL
+    dump_dir: StrPath
+    name: str
+
+    def drop(self) -> None:
+        """No-op counterpart of the `UnitTestDB.drop()` method."""
+
+
+@pytest.mark.dependency(name="test_data_dir")
+def test_data_dir(request: FixtureRequest, data_dir: Path) -> None:
+    """Checks that the `data_dir` fixture is named after the requesting test file.
+
+    Args:
+        request: Fixture giving access to the requesting test function.
+        data_dir: Fixture with the path to the test data folder matching the test's name.
+
+    """
+    assert request.path.stem == data_dir.stem
+
+
+@pytest.mark.dependency(depends=["test_data_dir"])
+@pytest.mark.parametrize(
+    "left, right, expectation",
+    [
+        param("file1.txt", "file1.txt", does_not_raise(), id="Files are equal"),
+        param("file1.txt", "file2.txt", raises(AssertionError), id="Files differ"),
+    ],
+)
+def test_assert_files(
+    assert_files: Callable, data_dir: Path, left: str, right: str, expectation: ContextManager
+) -> None:
+    """Checks that the `assert_files` fixture passes on equal files and raises on different ones.
+
+    Args:
+        assert_files: Fixture with the assertion function that compares two files.
+        data_dir: Fixture with the path to the test data folder matching the test's name.
+        left: Name of the first file to compare, relative to `data_dir`.
+        right: Name of the second file to compare, relative to `data_dir`.
+        expectation: Context manager for the expected exception.
+
+    """
+    with expectation:
+        assert_files(data_dir.joinpath(left), data_dir.joinpath(right))
+
+
+@pytest.mark.parametrize(
+    "dump_dir, db_name",
+    [
+        (Path("dump_dir"), "dump_dir"),
+        (Path("dump_dir").resolve(), "test_db"),
+    ],
+)
+@patch("ensembl.utils.plugin.UnitTestDB", new=MockTestDB)
+def test_db_factory(request: FixtureRequest, db_factory: Callable, dump_dir: Path, db_name: str) -> None:
+    """Checks the server URL, dump directory and name that the `db_factory` fixture hands over.
+
+    Args:
+        request: Fixture giving access to the requesting test function.
+        db_factory: Fixture with the unit test database factory.
+        dump_dir: Directory path where the test database schema and content files are located.
+        db_name: Name to give to the new database.
+
+    """
+    database = db_factory(dump_dir, db_name)
+    assert database.name == db_name
+    assert database.server_url == request.config.getoption("server")
+    if not dump_dir.is_absolute():
+        assert database.dump_dir.stem == str(dump_dir)
+    else:
+        assert database.dump_dir == dump_dir
diff --git a/tests/plugin/test_plugin/file1.txt b/tests/plugin/test_plugin/file1.txt
new file mode 100644
index 0000000..b6fc4c6
--- /dev/null
+++ b/tests/plugin/test_plugin/file1.txt
@@ -0,0 +1 @@
+hello
\ No newline at end of file
diff --git a/tests/plugin/test_plugin/file2.txt b/tests/plugin/test_plugin/file2.txt
new file mode 100644
index 0000000..0abaeaa
--- /dev/null
+++ b/tests/plugin/test_plugin/file2.txt
@@ -0,0 +1 @@
+bye
\ No newline at end of file
diff --git a/tests/logging/test_logging.py b/tests/test_logging.py
similarity index 100%
rename from tests/logging/test_logging.py
rename to tests/test_logging.py