Skip to content

Commit

Permalink
SQLAlchemy 2.0 Declarative Syntax (pypi#14266)
Browse files Browse the repository at this point in the history
* Step 1 of moving to Declarative syntax

Refs: https://docs.sqlalchemy.org/en/20/changelog/whatsnew_20.html#step-one-orm-declarative-base-is-superseded-by-orm-declarativebase

- Replace abstract concrete base with simpler pattern
  Since we don't actually care about the polymorphism, we can simplify and
  have the models map correctly. Follows pattern as shown in the examples:
  https://docs.sqlalchemy.org/en/20/_modules/examples/generic_associations/table_per_related.html

- Minor modification to Index definition so we don't generate new indices as a
  result.

- Update DBML tests to use the base model, override metadata

Signed-off-by: Mike Fiedler <[email protected]>

* Step 2 of moving to Declarative syntax

In this step, we find&replace instances of `Column()` with
`mapped_column()` which presents an identical API.

The columns are still `Mapped[Any]` at this stage, those should be added
in Step 3.

Refs: https://docs.sqlalchemy.org/en/20/changelog/whatsnew_20.html#step-two-replace-declarative-use-of-schema-column-with-orm-mapped-column

Two Tables have not been converted to the declarative syntax, as they do
not have a primary key set yet. That needs to be handled out of band, as
there's data quality/migrations that need to happen first.

Signed-off-by: Mike Fiedler <[email protected]>

* Step 3 - accounts models

Apply changes to the models in the Accounts segment.
https://docs.sqlalchemy.org/en/20/changelog/whatsnew_20.html#step-three-apply-exact-python-types-as-needed-using-orm-mapped

Adds verbosity that should be simplified in Step 4.

Signed-off-by: Mike Fiedler <[email protected]>

* fix: apply nullable changes to match database

Towards Step 4, apply these qualifiers so that autogenerated migrations
perform no changes.

Signed-off-by: Mike Fiedler <[email protected]>

* Step 4 - clean up mapped_column for account models

Now that we're past mapping the datatypes, we can simplify the
declarations by removing unnecessary details.

Refs: https://docs.sqlalchemy.org/en/20/changelog/whatsnew_20.html#step-four-remove-orm-mapped-column-directives-where-no-longer-needed

Include a rename for UUID vs PG_UUID.

Signed-off-by: Mike Fiedler <[email protected]>

* Step 5 - extract a common type

Since we use the pattern of a "bool, default false", we can extract it
to a custom type, and use it in the `Mapped` annotation, and slim down
the repeated definitions.

If this proves a useful type to other modules, it can be refactored out
to a utility location.

Signed-off-by: Mike Fiedler <[email protected]>

---------

Signed-off-by: Mike Fiedler <[email protected]>
  • Loading branch information
miketheman authored Aug 3, 2023
1 parent e0c6a58 commit f4f80c3
Show file tree
Hide file tree
Showing 21 changed files with 406 additions and 397 deletions.
51 changes: 21 additions & 30 deletions tests/unit/cli/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import sqlalchemy

from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import mapped_column

import warehouse.cli.db.dbml
import warehouse.db
Expand Down Expand Up @@ -337,38 +337,35 @@ def test_dbml_command(monkeypatch, cli):
def test_generate_dbml_file(tmp_path_factory):
class Muddle(warehouse.db.Model):
__abstract__ = True

metadata = sqlalchemy.MetaData()

Muddle = declarative_base(cls=Muddle, metadata=metadata) # noqa, type: ignore
metadata = sqlalchemy.MetaData()

class Clan(Muddle):
__tablename__ = "_clan"
__table_args__ = {"comment": "various clans"}

name = sqlalchemy.Column(sqlalchemy.Text, unique=True, nullable=False)
fetched = sqlalchemy.Column(
name = mapped_column(sqlalchemy.Text, unique=True, nullable=False)
fetched = mapped_column(
sqlalchemy.Text,
server_default=sqlalchemy.FetchedValue(),
comment="fetched value",
)
for_the_children = sqlalchemy.Column(sqlalchemy.Boolean, default=True)
nice = sqlalchemy.Column(sqlalchemy.String(length=69))
for_the_children = mapped_column(sqlalchemy.Boolean, default=True)
nice = mapped_column(sqlalchemy.String(length=69))

class ClanMember(Muddle):
__tablename__ = "_clan_member"

name = sqlalchemy.Column(sqlalchemy.Text, nullable=False)
clan_id = sqlalchemy.Column(
name = mapped_column(sqlalchemy.Text, nullable=False)
clan_id = mapped_column(
UUID(as_uuid=True),
sqlalchemy.ForeignKey("_clan.id", deferrable=True, initially="DEFERRED"),
)
joined = sqlalchemy.Column(
joined = mapped_column(
sqlalchemy.DateTime,
nullable=False,
server_default=sqlalchemy.sql.func.now(),
)
departed = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True)
departed = mapped_column(sqlalchemy.DateTime, nullable=True)

outpath = tmp_path_factory.mktemp("out") / "wutang.dbml"
warehouse.cli.db.dbml.generate_dbml_file(Muddle.metadata.tables.values(), outpath)
Expand All @@ -380,38 +377,35 @@ class ClanMember(Muddle):
def test_generate_dbml_console(capsys, monkeypatch):
class Muddle(warehouse.db.Model):
__abstract__ = True

metadata = sqlalchemy.MetaData()

Muddle = declarative_base(cls=Muddle, metadata=metadata) # noqa, type: ignore
metadata = sqlalchemy.MetaData()

class Clan(Muddle):
__tablename__ = "_clan"
__table_args__ = {"comment": "various clans"}

name = sqlalchemy.Column(sqlalchemy.Text, unique=True, nullable=False)
fetched = sqlalchemy.Column(
name = mapped_column(sqlalchemy.Text, unique=True, nullable=False)
fetched = mapped_column(
sqlalchemy.Text,
server_default=sqlalchemy.FetchedValue(),
comment="fetched value",
)
for_the_children = sqlalchemy.Column(sqlalchemy.Boolean, default=True)
nice = sqlalchemy.Column(sqlalchemy.String(length=69))
for_the_children = mapped_column(sqlalchemy.Boolean, default=True)
nice = mapped_column(sqlalchemy.String(length=69))

class ClanMember(Muddle):
__tablename__ = "_clan_member"

name = sqlalchemy.Column(sqlalchemy.Text, nullable=False)
clan_id = sqlalchemy.Column(
name = mapped_column(sqlalchemy.Text, nullable=False)
clan_id = mapped_column(
UUID(as_uuid=True),
sqlalchemy.ForeignKey("_clan.id", deferrable=True, initially="DEFERRED"),
)
joined = sqlalchemy.Column(
joined = mapped_column(
sqlalchemy.DateTime,
nullable=False,
server_default=sqlalchemy.sql.func.now(),
)
departed = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True)
departed = mapped_column(sqlalchemy.DateTime, nullable=True)

warehouse.cli.db.dbml.generate_dbml_file(Muddle.metadata.tables.values(), None)
captured = capsys.readouterr()
Expand All @@ -422,10 +416,7 @@ class ClanMember(Muddle):
def test_generate_dbml_bad_conversion():
class Muddle(warehouse.db.Model):
__abstract__ = True

metadata = sqlalchemy.MetaData()

Muddle = declarative_base(cls=Muddle, metadata=metadata) # noqa, type: ignore
metadata = sqlalchemy.MetaData()

class BadText(sqlalchemy.Text):
pass
Expand All @@ -434,7 +425,7 @@ class Puddle(Muddle):
__tablename__ = "puddle"
__table_args__ = {"comment": "various clans"}

name = sqlalchemy.Column(BadText, unique=True, nullable=False)
name = mapped_column(BadText, unique=True, nullable=False)

with pytest.raises(SystemExit):
warehouse.cli.db.dbml.generate_dbml_file(Muddle.metadata.tables.values(), None)
4 changes: 2 additions & 2 deletions tests/unit/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ def inspect(item):

original_repr = model.__repr__

assert repr(model) == "Base(foo={})".format(repr("bar"))
assert repr(model) == "ModelBase(foo={})".format(repr("bar"))
assert inspect.calls == [pretend.call(model)]
assert model.__repr__ is not original_repr
assert repr(model) == "Base(foo={})".format(repr("bar"))
assert repr(model) == "ModelBase(foo={})".format(repr("bar"))


def test_listens_for(monkeypatch):
Expand Down
129 changes: 75 additions & 54 deletions warehouse/accounts/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import datetime
import enum

from typing import TYPE_CHECKING, Annotated
from uuid import UUID

from pyramid.authorization import Allow, Authenticated
from sqlalchemy import (
Boolean,
CheckConstraint,
Column,
DateTime,
Enum,
ForeignKey,
Expand All @@ -31,16 +34,25 @@
select,
sql,
)
from sqlalchemy.dialects.postgresql import CITEXT, UUID
from sqlalchemy.dialects.postgresql import CITEXT, UUID as PG_UUID
from sqlalchemy.exc import NoResultFound
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import Mapped, mapped_column

from warehouse import db
from warehouse.events.models import HasEvents
from warehouse.sitemap.models import SitemapMixin
from warehouse.utils.attrs import make_repr
from warehouse.utils.db.types import TZDateTime

if TYPE_CHECKING:
from warehouse.macaroons.models import Macaroon
from warehouse.oidc.models import PendingOIDCPublisher


# Custom column types
bool_false = Annotated[bool, mapped_column(Boolean, server_default=sql.false())]


class UserFactory:
def __init__(self, request):
Expand Down Expand Up @@ -70,49 +82,53 @@ class User(SitemapMixin, HasEvents, db.Model):

__repr__ = make_repr("username")

username = Column(CITEXT, nullable=False, unique=True)
name = Column(String(length=100), nullable=False)
password = Column(String(length=128), nullable=False)
password_date = Column(TZDateTime, nullable=True, server_default=sql.func.now())
is_active = Column(Boolean, nullable=False, server_default=sql.false())
is_frozen = Column(Boolean, nullable=False, server_default=sql.false())
is_superuser = Column(Boolean, nullable=False, server_default=sql.false())
is_moderator = Column(Boolean, nullable=False, server_default=sql.false())
is_psf_staff = Column(Boolean, nullable=False, server_default=sql.false())
prohibit_password_reset = Column(
Boolean, nullable=False, server_default=sql.false()
username: Mapped[CITEXT] = mapped_column(CITEXT, unique=True)
name: Mapped[str] = mapped_column(String(length=100))
password: Mapped[str] = mapped_column(String(length=128))
password_date: Mapped[datetime.datetime | None] = mapped_column(
TZDateTime, server_default=sql.func.now()
)
is_active: Mapped[bool_false]
is_frozen: Mapped[bool_false]
is_superuser: Mapped[bool_false]
is_moderator: Mapped[bool_false]
is_psf_staff: Mapped[bool_false]
prohibit_password_reset: Mapped[bool_false]
hide_avatar: Mapped[bool_false]
date_joined: Mapped[datetime.datetime | None] = mapped_column(
DateTime,
server_default=sql.func.now(),
)
hide_avatar = Column(Boolean, nullable=False, server_default=sql.false())
date_joined = Column(DateTime, server_default=sql.func.now())
last_login = Column(TZDateTime, nullable=True, server_default=sql.func.now())
disabled_for = Column( # type: ignore[var-annotated]
last_login: Mapped[datetime.datetime | None] = mapped_column(
TZDateTime, server_default=sql.func.now()
)
disabled_for: Mapped[Enum | None] = mapped_column(
Enum(DisableReason, values_callable=lambda x: [e.value for e in x]),
nullable=True,
)

totp_secret = Column(LargeBinary(length=20), nullable=True)
last_totp_value = Column(String, nullable=True)
totp_secret: Mapped[int | None] = mapped_column(LargeBinary(length=20))
last_totp_value: Mapped[str | None]

webauthn = orm.relationship(
webauthn: Mapped[list[WebAuthn]] = orm.relationship(
"WebAuthn", backref="user", cascade="all, delete-orphan", lazy=True
)

recovery_codes = orm.relationship(
recovery_codes: Mapped[list[RecoveryCode]] = orm.relationship(
"RecoveryCode", backref="user", cascade="all, delete-orphan", lazy="dynamic"
)

emails = orm.relationship(
emails: Mapped[list[Email]] = orm.relationship(
"Email", backref="user", cascade="all, delete-orphan", lazy=False
)

macaroons = orm.relationship(
macaroons: Mapped[list[Macaroon]] = orm.relationship(
"Macaroon",
cascade="all, delete-orphan",
lazy=True,
order_by="Macaroon.created.desc()",
)

pending_oidc_publishers = orm.relationship(
pending_oidc_publishers: Mapped[list[PendingOIDCPublisher]] = orm.relationship(
"PendingOIDCPublisher",
backref="added_by",
cascade="all, delete-orphan",
Expand Down Expand Up @@ -225,30 +241,32 @@ class WebAuthn(db.Model):
UniqueConstraint("label", "user_id", name="_user_security_keys_label_uc"),
)

user_id = Column(
UUID(as_uuid=True),
user_id: Mapped[UUID] = mapped_column(
PG_UUID(as_uuid=True),
ForeignKey("users.id", deferrable=True, initially="DEFERRED"),
nullable=False,
index=True,
)
label = Column(String, nullable=False)
credential_id = Column(String, unique=True, nullable=False)
public_key = Column(String, unique=True, nullable=True)
sign_count = Column(Integer, default=0)
label: Mapped[str]
credential_id: Mapped[str] = mapped_column(String, unique=True)
public_key: Mapped[str | None] = mapped_column(String, unique=True)
sign_count: Mapped[int | None] = mapped_column(Integer, default=0)


class RecoveryCode(db.Model):
__tablename__ = "user_recovery_codes"

user_id = Column(
UUID(as_uuid=True),
user_id: Mapped[UUID] = mapped_column(
PG_UUID(as_uuid=True),
ForeignKey("users.id", deferrable=True, initially="DEFERRED"),
nullable=False,
index=True,
)
code = Column(String(length=128), nullable=False)
generated = Column(DateTime, nullable=False, server_default=sql.func.now())
burned = Column(DateTime, nullable=True)
code: Mapped[str] = mapped_column(String(length=128))
generated: Mapped[datetime.datetime] = mapped_column(
DateTime, server_default=sql.func.now()
)
burned: Mapped[datetime.datetime | None]


class UnverifyReasons(enum.Enum):
Expand All @@ -264,23 +282,23 @@ class Email(db.ModelBase):
Index("user_emails_user_id", "user_id"),
)

id = Column(Integer, primary_key=True, nullable=False)
user_id = Column(
UUID(as_uuid=True),
id: Mapped[int] = mapped_column(Integer, primary_key=True)
user_id: Mapped[UUID] = mapped_column(
PG_UUID(as_uuid=True),
ForeignKey("users.id", deferrable=True, initially="DEFERRED"),
nullable=False,
)
email = Column(String(length=254), nullable=False)
primary = Column(Boolean, nullable=False)
verified = Column(Boolean, nullable=False)
public = Column(Boolean, nullable=False, server_default=sql.false())
email: Mapped[str] = mapped_column(String(length=254))
primary: Mapped[bool]
verified: Mapped[bool]
public: Mapped[bool_false]

# Deliverability information
unverify_reason = Column( # type: ignore[var-annotated]
unverify_reason: Mapped[Enum | None] = mapped_column(
Enum(UnverifyReasons, values_callable=lambda x: [e.value for e in x]),
nullable=True,
)
transient_bounces = Column(Integer, nullable=False, server_default=sql.text("0"))
transient_bounces: Mapped[int] = mapped_column(
Integer, server_default=sql.text("0")
)


class ProhibitedUserName(db.Model):
Expand All @@ -297,12 +315,15 @@ class ProhibitedUserName(db.Model):

__repr__ = make_repr("name")

created = Column(
DateTime(timezone=False), nullable=False, server_default=sql.func.now()
created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=False), server_default=sql.func.now()
)
name = Column(Text, unique=True, nullable=False)
_prohibited_by = Column(
"prohibited_by", UUID(as_uuid=True), ForeignKey("users.id"), index=True
name: Mapped[str] = mapped_column(Text, unique=True)
_prohibited_by: Mapped[UUID | None] = mapped_column(
"prohibited_by",
PG_UUID(as_uuid=True),
ForeignKey("users.id"),
index=True,
)
prohibited_by = orm.relationship(User)
comment = Column(Text, nullable=False, server_default="")
prohibited_by: Mapped[User] = orm.relationship(User)
comment: Mapped[str] = mapped_column(Text, server_default="")
11 changes: 6 additions & 5 deletions warehouse/admin/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@

import enum

from sqlalchemy import Boolean, Column, Text, sql
from sqlalchemy import Boolean, Text, sql
from sqlalchemy.orm import mapped_column

from warehouse import db

Expand All @@ -32,10 +33,10 @@ class AdminFlagValue(enum.Enum):
class AdminFlag(db.ModelBase):
__tablename__ = "admin_flags"

id = Column(Text, primary_key=True, nullable=False)
description = Column(Text, nullable=False)
enabled = Column(Boolean, nullable=False)
notify = Column(Boolean, nullable=False, server_default=sql.false())
id = mapped_column(Text, primary_key=True, nullable=False)
description = mapped_column(Text, nullable=False)
enabled = mapped_column(Boolean, nullable=False)
notify = mapped_column(Boolean, nullable=False, server_default=sql.false())


class Flags:
Expand Down
Loading

0 comments on commit f4f80c3

Please sign in to comment.