From a3dcc0d83057fb8486b3057e6f07d45369719206 Mon Sep 17 00:00:00 2001 From: David Lord Date: Fri, 14 Oct 2022 06:36:59 -0700 Subject: [PATCH] update compatibility with sqlalchemy 2 --- CHANGES.rst | 2 ++ src/flask_sqlalchemy/extension.py | 5 ++++- tests/test_engine.py | 2 +- tests/test_legacy_query.py | 13 +++++++++++++ tests/test_model_name.py | 2 +- 5 files changed, 21 insertions(+), 3 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index b7fb5a30..33f6453d 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -3,6 +3,8 @@ Version 3.0.2 Unreleased +- Update compatibility with SQLAlchemy 2. :issue:`1122` + Version 3.0.1 ------------- diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index 95df831e..85a60350 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -978,8 +978,11 @@ def __getattr__(self, name: str) -> t.Any: if name == "event": return sa.event + if name.startswith("_"): + raise AttributeError(name) + for mod in (sa, sa.orm): - if name in mod.__all__: + if hasattr(mod, name): return getattr(mod, name) raise AttributeError(name) diff --git a/tests/test_engine.py b/tests/test_engine.py index 40a3b4e4..37d9d2e9 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -92,7 +92,7 @@ def test_sqlite_relative_path(app: Flask) -> None: app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///test.db" db = SQLAlchemy(app) db.create_all() - assert isinstance(db.engine.pool, sa.pool.NullPool) + assert not isinstance(db.engine.pool, sa.pool.StaticPool) db_path = db.engine.url.database assert db_path.startswith(app.instance_path) # type: ignore[union-attr] assert os.path.exists(db_path) # type: ignore[arg-type] diff --git a/tests/test_legacy_query.py b/tests/test_legacy_query.py index f08246bc..37bc4290 100644 --- a/tests/test_legacy_query.py +++ b/tests/test_legacy_query.py @@ -1,9 +1,11 @@ from __future__ import annotations import typing as t +import warnings import pytest import sqlalchemy as sa +import sqlalchemy.exc from flask import Flask from werkzeug.exceptions import NotFound @@ -11,6 +13,17 @@ from flask_sqlalchemy.query import Query +@pytest.fixture(autouse=True) +def ignore_query_warning() -> t.Generator[None, None, None]: + if hasattr(sa.exc, "LegacyAPIWarning"): + with warnings.catch_warnings(): + exc = sa.exc.LegacyAPIWarning # type: ignore[attr-defined] + warnings.simplefilter("ignore", exc) + yield + else: + yield + + @pytest.mark.usefixtures("app_ctx") def test_get_or_404(db: SQLAlchemy, Todo: t.Any) -> None: item = Todo() diff --git a/tests/test_model_name.py b/tests/test_model_name.py index 8030a6e3..2c37572d 100644 --- a/tests/test_model_name.py +++ b/tests/test_model_name.py @@ -154,7 +154,7 @@ class Duck(db.Model): class IdMixin: @sa.orm.declared_attr - def id(cls) -> sa.Column[sa.Integer]: # noqa: B902 + def id(cls): # type: ignore[no-untyped-def] # noqa: B902 return sa.Column(sa.Integer, sa.ForeignKey(Duck.id), primary_key=True) class RubberDuck(IdMixin, Duck): # type: ignore[misc]