From f4f80c3d84f636c60390b9a1d8c71c6c75da719e Mon Sep 17 00:00:00 2001 From: Mike Fiedler Date: Thu, 3 Aug 2023 08:52:34 -0400 Subject: [PATCH] SQLAlchemy 2.0 Declarative Syntax (#14266) * Step 1 of moving to Declarative syntax Refs: https://docs.sqlalchemy.org/en/20/changelog/whatsnew_20.html#step-one-orm-declarative-base-is-superseded-by-orm-declarativebase - Replace abstract concrete base with simpler pattern Since we don't actually care about the polymorphism, we can simplify and have the models map correctly. Follows pattern as shown in the examples: https://docs.sqlalchemy.org/en/20/_modules/examples/generic_associations/table_per_related.html - Minor modification to Index definition so we don't generate new indices as a result. - Update DBML tests to use the base model, override metadata Signed-off-by: Mike Fiedler * Step 2 of moving to Declarative syntax In this step, we find&replace instances of `Column()` with `mapped_column()` which presents an identical API. The columns are still `Mapped[Any]` at this stage, those should be added in Step 3. Refs: https://docs.sqlalchemy.org/en/20/changelog/whatsnew_20.html#step-two-replace-declarative-use-of-schema-column-with-orm-mapped-column Two Tables have not been converted to the declarative syntax, as they do not have a primary key set yet. That needs to be handled out of band, as there's data quality/migrations that need to happen first. Signed-off-by: Mike Fiedler * Step 3 - accounts models Apply changes to the models in the Accounts segment. https://docs.sqlalchemy.org/en/20/changelog/whatsnew_20.html#step-three-apply-exact-python-types-as-needed-using-orm-mapped Adds verbosity that should be simplified in Step 4. Signed-off-by: Mike Fiedler * fix: apply nullable changes to match database Towards Step 4, apply these qualifiers so that autogenerated migrations perform no changes. Signed-off-by: Mike Fiedler * Step 4 - clean up mapped_column for account models Now that we're past mapping the datatypes, we can simplify the declarations by removing unnecessary details. Refs: https://docs.sqlalchemy.org/en/20/changelog/whatsnew_20.html#step-four-remove-orm-mapped-column-directives-where-no-longer-needed Include a rename for UUID vs PG_UUID. Signed-off-by: Mike Fiedler * Step 5 - extract a common type Since we use the pattern of a "bool, default false", we can extract it to a custom type, and use it in the `Mapped` annotation, and slim down the repeated definitions. If this proves a useful type to other modules, it can be refactored out to a utility location. Signed-off-by: Mike Fiedler --------- Signed-off-by: Mike Fiedler --- tests/unit/cli/test_db.py | 51 +++--- tests/unit/test_db.py | 4 +- warehouse/accounts/models.py | 129 ++++++++------ warehouse/admin/flags.py | 11 +- warehouse/banners/models.py | 17 +- warehouse/classifiers/models.py | 9 +- warehouse/db.py | 22 +-- warehouse/email/ses/models.py | 27 +-- warehouse/events/models.py | 80 +++------ .../integrations/vulnerabilities/models.py | 18 +- warehouse/ip_addresses/models.py | 25 +-- warehouse/macaroons/models.py | 20 ++- warehouse/oidc/models/_core.py | 13 +- warehouse/oidc/models/github.py | 20 ++- warehouse/oidc/models/google.py | 14 +- warehouse/organizations/models.py | 77 ++++---- warehouse/packaging/models.py | 167 +++++++++--------- warehouse/sitemap/models.py | 5 +- warehouse/sponsors/models.py | 35 ++-- warehouse/subscriptions/models.py | 52 +++--- warehouse/utils/row_counter.py | 7 +- 21 files changed, 406 insertions(+), 397 deletions(-) diff --git a/tests/unit/cli/test_db.py b/tests/unit/cli/test_db.py index 5f27db71dc89..1a80f15590c8 100644 --- a/tests/unit/cli/test_db.py +++ b/tests/unit/cli/test_db.py @@ -16,7 +16,7 @@ import sqlalchemy from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.orm import declarative_base +from sqlalchemy.orm import mapped_column import warehouse.cli.db.dbml import warehouse.db @@ -337,38 +337,35 @@ def test_dbml_command(monkeypatch, cli): def test_generate_dbml_file(tmp_path_factory): class Muddle(warehouse.db.Model): __abstract__ = True - - metadata = sqlalchemy.MetaData() - - Muddle = declarative_base(cls=Muddle, metadata=metadata) # noqa, type: ignore + metadata = sqlalchemy.MetaData() class Clan(Muddle): __tablename__ = "_clan" __table_args__ = {"comment": "various clans"} - name = sqlalchemy.Column(sqlalchemy.Text, unique=True, nullable=False) - fetched = sqlalchemy.Column( + name = mapped_column(sqlalchemy.Text, unique=True, nullable=False) + fetched = mapped_column( sqlalchemy.Text, server_default=sqlalchemy.FetchedValue(), comment="fetched value", ) - for_the_children = sqlalchemy.Column(sqlalchemy.Boolean, default=True) - nice = sqlalchemy.Column(sqlalchemy.String(length=69)) + for_the_children = mapped_column(sqlalchemy.Boolean, default=True) + nice = mapped_column(sqlalchemy.String(length=69)) class ClanMember(Muddle): __tablename__ = "_clan_member" - name = sqlalchemy.Column(sqlalchemy.Text, nullable=False) - clan_id = sqlalchemy.Column( + name = mapped_column(sqlalchemy.Text, nullable=False) + clan_id = mapped_column( UUID(as_uuid=True), sqlalchemy.ForeignKey("_clan.id", deferrable=True, initially="DEFERRED"), ) - joined = sqlalchemy.Column( + joined = mapped_column( sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.sql.func.now(), ) - departed = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True) + departed = mapped_column(sqlalchemy.DateTime, nullable=True) outpath = tmp_path_factory.mktemp("out") / "wutang.dbml" warehouse.cli.db.dbml.generate_dbml_file(Muddle.metadata.tables.values(), outpath) @@ -380,38 +377,35 @@ class ClanMember(Muddle): def test_generate_dbml_console(capsys, monkeypatch): class Muddle(warehouse.db.Model): __abstract__ = True - - metadata = sqlalchemy.MetaData() - - Muddle = declarative_base(cls=Muddle, metadata=metadata) # noqa, type: ignore + metadata = sqlalchemy.MetaData() class Clan(Muddle): __tablename__ = "_clan" __table_args__ = {"comment": "various clans"} - name = sqlalchemy.Column(sqlalchemy.Text, unique=True, nullable=False) - fetched = sqlalchemy.Column( + name = mapped_column(sqlalchemy.Text, unique=True, nullable=False) + fetched = mapped_column( sqlalchemy.Text, server_default=sqlalchemy.FetchedValue(), comment="fetched value", ) - for_the_children = sqlalchemy.Column(sqlalchemy.Boolean, default=True) - nice = sqlalchemy.Column(sqlalchemy.String(length=69)) + for_the_children = mapped_column(sqlalchemy.Boolean, default=True) + nice = mapped_column(sqlalchemy.String(length=69)) class ClanMember(Muddle): __tablename__ = "_clan_member" - name = sqlalchemy.Column(sqlalchemy.Text, nullable=False) - clan_id = sqlalchemy.Column( + name = mapped_column(sqlalchemy.Text, nullable=False) + clan_id = mapped_column( UUID(as_uuid=True), sqlalchemy.ForeignKey("_clan.id", deferrable=True, initially="DEFERRED"), ) - joined = sqlalchemy.Column( + joined = mapped_column( sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.sql.func.now(), ) - departed = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True) + departed = mapped_column(sqlalchemy.DateTime, nullable=True) warehouse.cli.db.dbml.generate_dbml_file(Muddle.metadata.tables.values(), None) captured = capsys.readouterr() @@ -422,10 +416,7 @@ class ClanMember(Muddle): def test_generate_dbml_bad_conversion(): class Muddle(warehouse.db.Model): __abstract__ = True - - metadata = sqlalchemy.MetaData() - - Muddle = declarative_base(cls=Muddle, metadata=metadata) # noqa, type: ignore + metadata = sqlalchemy.MetaData() class BadText(sqlalchemy.Text): pass @@ -434,7 +425,7 @@ class Puddle(Muddle): __tablename__ = "puddle" __table_args__ = {"comment": "various clans"} - name = sqlalchemy.Column(BadText, unique=True, nullable=False) + name = mapped_column(BadText, unique=True, nullable=False) with pytest.raises(SystemExit): warehouse.cli.db.dbml.generate_dbml_file(Muddle.metadata.tables.values(), None) diff --git a/tests/unit/test_db.py b/tests/unit/test_db.py index 453c8d1c4938..ba8c4101286b 100644 --- a/tests/unit/test_db.py +++ b/tests/unit/test_db.py @@ -47,10 +47,10 @@ def inspect(item): original_repr = model.__repr__ - assert repr(model) == "Base(foo={})".format(repr("bar")) + assert repr(model) == "ModelBase(foo={})".format(repr("bar")) assert inspect.calls == [pretend.call(model)] assert model.__repr__ is not original_repr - assert repr(model) == "Base(foo={})".format(repr("bar")) + assert repr(model) == "ModelBase(foo={})".format(repr("bar")) def test_listens_for(monkeypatch): diff --git a/warehouse/accounts/models.py b/warehouse/accounts/models.py index 556cf96c3ad3..bb0de797dfe8 100644 --- a/warehouse/accounts/models.py +++ b/warehouse/accounts/models.py @@ -9,15 +9,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import datetime import enum +from typing import TYPE_CHECKING, Annotated +from uuid import UUID + from pyramid.authorization import Allow, Authenticated from sqlalchemy import ( Boolean, CheckConstraint, - Column, DateTime, Enum, ForeignKey, @@ -31,9 +34,10 @@ select, sql, ) -from sqlalchemy.dialects.postgresql import CITEXT, UUID +from sqlalchemy.dialects.postgresql import CITEXT, UUID as PG_UUID from sqlalchemy.exc import NoResultFound from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.orm import Mapped, mapped_column from warehouse import db from warehouse.events.models import HasEvents @@ -41,6 +45,14 @@ from warehouse.utils.attrs import make_repr from warehouse.utils.db.types import TZDateTime +if TYPE_CHECKING: + from warehouse.macaroons.models import Macaroon + from warehouse.oidc.models import PendingOIDCPublisher + + +# Custom column types +bool_false = Annotated[bool, mapped_column(Boolean, server_default=sql.false())] + class UserFactory: def __init__(self, request): @@ -70,49 +82,53 @@ class User(SitemapMixin, HasEvents, db.Model): __repr__ = make_repr("username") - username = Column(CITEXT, nullable=False, unique=True) - name = Column(String(length=100), nullable=False) - password = Column(String(length=128), nullable=False) - password_date = Column(TZDateTime, nullable=True, server_default=sql.func.now()) - is_active = Column(Boolean, nullable=False, server_default=sql.false()) - is_frozen = Column(Boolean, nullable=False, server_default=sql.false()) - is_superuser = Column(Boolean, nullable=False, server_default=sql.false()) - is_moderator = Column(Boolean, nullable=False, server_default=sql.false()) - is_psf_staff = Column(Boolean, nullable=False, server_default=sql.false()) - prohibit_password_reset = Column( - Boolean, nullable=False, server_default=sql.false() + username: Mapped[CITEXT] = mapped_column(CITEXT, unique=True) + name: Mapped[str] = mapped_column(String(length=100)) + password: Mapped[str] = mapped_column(String(length=128)) + password_date: Mapped[datetime.datetime | None] = mapped_column( + TZDateTime, server_default=sql.func.now() + ) + is_active: Mapped[bool_false] + is_frozen: Mapped[bool_false] + is_superuser: Mapped[bool_false] + is_moderator: Mapped[bool_false] + is_psf_staff: Mapped[bool_false] + prohibit_password_reset: Mapped[bool_false] + hide_avatar: Mapped[bool_false] + date_joined: Mapped[datetime.datetime | None] = mapped_column( + DateTime, + server_default=sql.func.now(), ) - hide_avatar = Column(Boolean, nullable=False, server_default=sql.false()) - date_joined = Column(DateTime, server_default=sql.func.now()) - last_login = Column(TZDateTime, nullable=True, server_default=sql.func.now()) - disabled_for = Column( # type: ignore[var-annotated] + last_login: Mapped[datetime.datetime | None] = mapped_column( + TZDateTime, server_default=sql.func.now() + ) + disabled_for: Mapped[Enum | None] = mapped_column( Enum(DisableReason, values_callable=lambda x: [e.value for e in x]), - nullable=True, ) - totp_secret = Column(LargeBinary(length=20), nullable=True) - last_totp_value = Column(String, nullable=True) + totp_secret: Mapped[int | None] = mapped_column(LargeBinary(length=20)) + last_totp_value: Mapped[str | None] - webauthn = orm.relationship( + webauthn: Mapped[list[WebAuthn]] = orm.relationship( "WebAuthn", backref="user", cascade="all, delete-orphan", lazy=True ) - recovery_codes = orm.relationship( + recovery_codes: Mapped[list[RecoveryCode]] = orm.relationship( "RecoveryCode", backref="user", cascade="all, delete-orphan", lazy="dynamic" ) - emails = orm.relationship( + emails: Mapped[list[Email]] = orm.relationship( "Email", backref="user", cascade="all, delete-orphan", lazy=False ) - macaroons = orm.relationship( + macaroons: Mapped[list[Macaroon]] = orm.relationship( "Macaroon", cascade="all, delete-orphan", lazy=True, order_by="Macaroon.created.desc()", ) - pending_oidc_publishers = orm.relationship( + pending_oidc_publishers: Mapped[list[PendingOIDCPublisher]] = orm.relationship( "PendingOIDCPublisher", backref="added_by", cascade="all, delete-orphan", @@ -225,30 +241,32 @@ class WebAuthn(db.Model): UniqueConstraint("label", "user_id", name="_user_security_keys_label_uc"), ) - user_id = Column( - UUID(as_uuid=True), + user_id: Mapped[UUID] = mapped_column( + PG_UUID(as_uuid=True), ForeignKey("users.id", deferrable=True, initially="DEFERRED"), nullable=False, index=True, ) - label = Column(String, nullable=False) - credential_id = Column(String, unique=True, nullable=False) - public_key = Column(String, unique=True, nullable=True) - sign_count = Column(Integer, default=0) + label: Mapped[str] + credential_id: Mapped[str] = mapped_column(String, unique=True) + public_key: Mapped[str | None] = mapped_column(String, unique=True) + sign_count: Mapped[int | None] = mapped_column(Integer, default=0) class RecoveryCode(db.Model): __tablename__ = "user_recovery_codes" - user_id = Column( - UUID(as_uuid=True), + user_id: Mapped[UUID] = mapped_column( + PG_UUID(as_uuid=True), ForeignKey("users.id", deferrable=True, initially="DEFERRED"), nullable=False, index=True, ) - code = Column(String(length=128), nullable=False) - generated = Column(DateTime, nullable=False, server_default=sql.func.now()) - burned = Column(DateTime, nullable=True) + code: Mapped[str] = mapped_column(String(length=128)) + generated: Mapped[datetime.datetime] = mapped_column( + DateTime, server_default=sql.func.now() + ) + burned: Mapped[datetime.datetime | None] class UnverifyReasons(enum.Enum): @@ -264,23 +282,23 @@ class Email(db.ModelBase): Index("user_emails_user_id", "user_id"), ) - id = Column(Integer, primary_key=True, nullable=False) - user_id = Column( - UUID(as_uuid=True), + id: Mapped[int] = mapped_column(Integer, primary_key=True) + user_id: Mapped[UUID] = mapped_column( + PG_UUID(as_uuid=True), ForeignKey("users.id", deferrable=True, initially="DEFERRED"), - nullable=False, ) - email = Column(String(length=254), nullable=False) - primary = Column(Boolean, nullable=False) - verified = Column(Boolean, nullable=False) - public = Column(Boolean, nullable=False, server_default=sql.false()) + email: Mapped[str] = mapped_column(String(length=254)) + primary: Mapped[bool] + verified: Mapped[bool] + public: Mapped[bool_false] # Deliverability information - unverify_reason = Column( # type: ignore[var-annotated] + unverify_reason: Mapped[Enum | None] = mapped_column( Enum(UnverifyReasons, values_callable=lambda x: [e.value for e in x]), - nullable=True, ) - transient_bounces = Column(Integer, nullable=False, server_default=sql.text("0")) + transient_bounces: Mapped[int] = mapped_column( + Integer, server_default=sql.text("0") + ) class ProhibitedUserName(db.Model): @@ -297,12 +315,15 @@ class ProhibitedUserName(db.Model): __repr__ = make_repr("name") - created = Column( - DateTime(timezone=False), nullable=False, server_default=sql.func.now() + created: Mapped[datetime.datetime] = mapped_column( + DateTime(timezone=False), server_default=sql.func.now() ) - name = Column(Text, unique=True, nullable=False) - _prohibited_by = Column( - "prohibited_by", UUID(as_uuid=True), ForeignKey("users.id"), index=True + name: Mapped[str] = mapped_column(Text, unique=True) + _prohibited_by: Mapped[UUID | None] = mapped_column( + "prohibited_by", + PG_UUID(as_uuid=True), + ForeignKey("users.id"), + index=True, ) - prohibited_by = orm.relationship(User) - comment = Column(Text, nullable=False, server_default="") + prohibited_by: Mapped[User] = orm.relationship(User) + comment: Mapped[str] = mapped_column(Text, server_default="") diff --git a/warehouse/admin/flags.py b/warehouse/admin/flags.py index b3ad238fb8d5..c8597dd0b781 100644 --- a/warehouse/admin/flags.py +++ b/warehouse/admin/flags.py @@ -12,7 +12,8 @@ import enum -from sqlalchemy import Boolean, Column, Text, sql +from sqlalchemy import Boolean, Text, sql +from sqlalchemy.orm import mapped_column from warehouse import db @@ -32,10 +33,10 @@ class AdminFlagValue(enum.Enum): class AdminFlag(db.ModelBase): __tablename__ = "admin_flags" - id = Column(Text, primary_key=True, nullable=False) - description = Column(Text, nullable=False) - enabled = Column(Boolean, nullable=False) - notify = Column(Boolean, nullable=False, server_default=sql.false()) + id = mapped_column(Text, primary_key=True, nullable=False) + description = mapped_column(Text, nullable=False) + enabled = mapped_column(Boolean, nullable=False) + notify = mapped_column(Boolean, nullable=False, server_default=sql.false()) class Flags: diff --git a/warehouse/banners/models.py b/warehouse/banners/models.py index d7cf39f2e4b9..d24bf8141f3d 100644 --- a/warehouse/banners/models.py +++ b/warehouse/banners/models.py @@ -11,7 +11,8 @@ # limitations under the License. from datetime import date -from sqlalchemy import Boolean, Column, Date, String, Text +from sqlalchemy import Boolean, Date, String, Text +from sqlalchemy.orm import mapped_column from warehouse import db from warehouse.utils.attrs import make_repr @@ -24,17 +25,17 @@ class Banner(db.Model): DEFAULT_BTN_LABEL = "See more" # internal name - name = Column(String, nullable=False) + name = mapped_column(String, nullable=False) # banner display configuration - text = Column(Text, nullable=False) - link_url = Column(Text, nullable=False) - link_label = Column(String, nullable=False, default=DEFAULT_BTN_LABEL) - fa_icon = Column(String, nullable=False, default=DEFAULT_FA_ICON) + text = mapped_column(Text, nullable=False) + link_url = mapped_column(Text, nullable=False) + link_label = mapped_column(String, nullable=False, default=DEFAULT_BTN_LABEL) + fa_icon = mapped_column(String, nullable=False, default=DEFAULT_FA_ICON) # visibility control - active = Column(Boolean, nullable=False, default=False) - end = Column(Date, nullable=False) + active = mapped_column(Boolean, nullable=False, default=False) + end = mapped_column(Date, nullable=False) @property def is_live(self): diff --git a/warehouse/classifiers/models.py b/warehouse/classifiers/models.py index c3ef10d64e0b..498b21a484c7 100644 --- a/warehouse/classifiers/models.py +++ b/warehouse/classifiers/models.py @@ -10,7 +10,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from sqlalchemy import CheckConstraint, Column, Integer, Text +from sqlalchemy import CheckConstraint, Integer, Text +from sqlalchemy.orm import mapped_column from warehouse import db from warehouse.utils.attrs import make_repr @@ -25,6 +26,6 @@ class Classifier(db.ModelBase): __repr__ = make_repr("classifier") - id = Column(Integer, primary_key=True, nullable=False) - classifier = Column(Text, unique=True) - ordering = Column(Integer, nullable=True) + id = mapped_column(Integer, primary_key=True, nullable=False) + classifier = mapped_column(Text, unique=True) + ordering = mapped_column(Integer, nullable=True) diff --git a/warehouse/db.py b/warehouse/db.py index 4df7be5b8a82..b30eba2e923e 100644 --- a/warehouse/db.py +++ b/warehouse/db.py @@ -22,7 +22,7 @@ from sqlalchemy import event, inspect from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.exc import IntegrityError, OperationalError -from sqlalchemy.orm import declarative_base, sessionmaker +from sqlalchemy.orm import DeclarativeBase, mapped_column, sessionmaker from warehouse.metrics import IMetricsService from warehouse.utils.attrs import make_repr @@ -63,7 +63,15 @@ class DatabaseNotAvailableError(Exception): ... -class ModelBase: +# The Global metadata object. +metadata = sqlalchemy.MetaData() + + +class ModelBase(DeclarativeBase): + """Base class for models using declarative syntax.""" + + metadata = metadata + def __repr__(self): inst = inspect(self) self.__repr__ = make_repr( @@ -72,18 +80,10 @@ def __repr__(self): return self.__repr__() -# The Global metadata object. -metadata = sqlalchemy.MetaData() - - -# Base class for models using declarative syntax -ModelBase = declarative_base(cls=ModelBase, metadata=metadata) # type: ignore - - class Model(ModelBase): __abstract__ = True - id = sqlalchemy.Column( + id = mapped_column( UUID(as_uuid=True), primary_key=True, server_default=sqlalchemy.text("gen_random_uuid()"), diff --git a/warehouse/email/ses/models.py b/warehouse/email/ses/models.py index 706d9c67b7a6..7d504ae18cb3 100644 --- a/warehouse/email/ses/models.py +++ b/warehouse/email/ses/models.py @@ -14,9 +14,10 @@ import automat -from sqlalchemy import Boolean, Column, DateTime, Enum, ForeignKey, Text, orm, sql +from sqlalchemy import Boolean, DateTime, Enum, ForeignKey, Text, orm, sql from sqlalchemy.dialects.postgresql import JSONB, UUID from sqlalchemy.ext.mutable import MutableDict +from sqlalchemy.orm import mapped_column from sqlalchemy.orm.session import object_session from warehouse import db @@ -229,18 +230,18 @@ def _get_email(self): class EmailMessage(db.Model): __tablename__ = "ses_emails" - created = Column(DateTime, nullable=False, server_default=sql.func.now()) - status = Column( # type: ignore[var-annotated] + created = mapped_column(DateTime, nullable=False, server_default=sql.func.now()) + status = mapped_column( Enum(EmailStatuses, values_callable=lambda x: [e.value for e in x]), nullable=False, server_default=EmailStatuses.Accepted.value, ) - message_id = Column(Text, nullable=False, unique=True, index=True) - from_ = Column("from", Text, nullable=False) - to = Column(Text, nullable=False, index=True) - subject = Column(Text, nullable=False) - missing = Column(Boolean, nullable=False, server_default=sql.false()) + message_id = mapped_column(Text, nullable=False, unique=True, index=True) + from_ = mapped_column("from", Text, nullable=False) + to = mapped_column(Text, nullable=False, index=True) + subject = mapped_column(Text, nullable=False) + missing = mapped_column(Boolean, nullable=False, server_default=sql.false()) # Relationships! events = orm.relationship( @@ -261,9 +262,9 @@ class EventTypes(enum.Enum): class Event(db.Model): __tablename__ = "ses_events" - created = Column(DateTime, nullable=False, server_default=sql.func.now()) + created = mapped_column(DateTime, nullable=False, server_default=sql.func.now()) - email_id = Column( + email_id = mapped_column( UUID(as_uuid=True), ForeignKey( "ses_emails.id", deferrable=True, initially="DEFERRED", ondelete="CASCADE" @@ -272,11 +273,11 @@ class Event(db.Model): index=True, ) - event_id = Column(Text, nullable=False, unique=True, index=True) - event_type = Column( # type: ignore[var-annotated] + event_id = mapped_column(Text, nullable=False, unique=True, index=True) + event_type = mapped_column( Enum(EventTypes, values_callable=lambda x: [e.value for e in x]), nullable=False ) - data = Column( # type: ignore[var-annotated] + data = mapped_column( MutableDict.as_mutable(JSONB), nullable=False, server_default=sql.text("'{}'") # type: ignore[arg-type] # noqa: E501 ) diff --git a/warehouse/events/models.py b/warehouse/events/models.py index fe2f80af2d20..d4384410984d 100644 --- a/warehouse/events/models.py +++ b/warehouse/events/models.py @@ -16,10 +16,9 @@ from dataclasses import dataclass from linehaul.ua import parser as linehaul_user_agent_parser -from sqlalchemy import Column, DateTime, ForeignKey, Index, String, orm, sql +from sqlalchemy import DateTime, ForeignKey, Index, String, orm, sql from sqlalchemy.dialects.postgresql import JSONB, UUID -from sqlalchemy.ext.declarative import AbstractConcreteBase -from sqlalchemy.orm import declared_attr +from sqlalchemy.orm import declared_attr, mapped_column from ua_parser import user_agent_parser from warehouse import db @@ -115,52 +114,19 @@ def display(self) -> str: return "Unknown User-Agent" -class Event(AbstractConcreteBase): - tag = Column(String, nullable=False) - time = Column(DateTime, nullable=False, server_default=sql.func.now()) - additional = Column(JSONB, nullable=True) +class Event: + tag = mapped_column(String, nullable=False) + time = mapped_column(DateTime, nullable=False, server_default=sql.func.now()) + additional = mapped_column(JSONB, nullable=True) @declared_attr def ip_address_id(cls): # noqa: N805 - return Column( + return mapped_column( UUID(as_uuid=True), ForeignKey("ip_addresses.id", onupdate="CASCADE", ondelete="CASCADE"), nullable=True, ) - @declared_attr - def __tablename__(cls): # noqa: N805 - return "_".join([cls.__name__.removesuffix("Event").lower(), "events"]) - - @declared_attr - def __table_args__(cls): # noqa: N805 - return (Index(f"ix_{ cls.__tablename__ }_source_id", "source_id"),) - - @declared_attr - def __mapper_args__(cls): # noqa: N805 - return ( - {"polymorphic_identity": cls.__name__, "concrete": True} - if cls.__name__ != "Event" - else {} - ) - - @declared_attr - def source_id(cls): # noqa: N805 - return Column( - UUID(as_uuid=True), - ForeignKey( - "%s.id" % cls._parent_class.__tablename__, - deferrable=True, - initially="DEFERRED", - ondelete="CASCADE", - ), - nullable=False, - ) - - @declared_attr - def source(cls): # noqa: N805 - return orm.relationship(cls._parent_class, back_populates="events") - @declared_attr def ip_address(cls): # noqa: N805 return orm.relationship(IpAddress) @@ -194,23 +160,33 @@ def user_agent_info(cls) -> str: # noqa: N805 return "No User-Agent" - def __init_subclass__(cls, /, parent_class, **kwargs): - cls._parent_class = parent_class - return cls - class HasEvents: Event: typing.ClassVar[type] - def __init_subclass__(cls, /, **kwargs): - super().__init_subclass__(**kwargs) - cls.Event = type( - f"{cls.__name__}Event", (Event, db.Model), dict(), parent_class=cls - ) - return cls - @declared_attr def events(cls): # noqa: N805 + cls.Event = type( + f"{cls.__name__}Event", + (Event, db.Model), + dict( + __tablename__=f"{cls.__name__.lower()}_events", + __table_args__=( + Index(f"ix_{cls.__name__.lower()}_events_source_id", "source_id"), + ), + source_id=mapped_column( + UUID(as_uuid=True), + ForeignKey( + f"{cls.__tablename__}.id", + deferrable=True, + initially="DEFERRED", + ondelete="CASCADE", + ), + nullable=False, + ), + source=orm.relationship(cls, back_populates="events"), + ), + ) return orm.relationship( cls.Event, cascade="all, delete-orphan", diff --git a/warehouse/integrations/vulnerabilities/models.py b/warehouse/integrations/vulnerabilities/models.py index c25b63c93f5a..c467a43a33c4 100644 --- a/warehouse/integrations/vulnerabilities/models.py +++ b/warehouse/integrations/vulnerabilities/models.py @@ -13,9 +13,11 @@ from sqlalchemy import Column, ForeignKey, ForeignKeyConstraint, String, Table, orm from sqlalchemy.dialects.postgresql import ARRAY, TIMESTAMP +from sqlalchemy.orm import mapped_column from warehouse import db +# TODO: convert to Declarative API release_vulnerabilities = Table( "release_vulnerabilities", db.metadata, @@ -47,28 +49,28 @@ class VulnerabilityRecord(db.ModelBase): __tablename__ = "vulnerabilities" - source = Column(String, primary_key=True) - id = Column(String, primary_key=True) + source = mapped_column(String, primary_key=True) + id = mapped_column(String, primary_key=True) # The URL for the vulnerability report at the source # e.g. "https://osv.dev/vulnerability/PYSEC-2021-314" - link = Column(String) + link = mapped_column(String) # Alternative IDs for this vulnerability # e.g. "CVE-2021-12345" - aliases = Column(ARRAY(String)) # type: ignore[var-annotated] + aliases = mapped_column(ARRAY(String)) # type: ignore[var-annotated] # Details about the vulnerability - details = Column(String) + details = mapped_column(String) # A short, plaintext summary of the vulnerability - summary = Column(String) + summary = mapped_column(String) # Events of introduced/fixed versions - fixed_in = Column(ARRAY(String)) # type: ignore[var-annotated] + fixed_in = mapped_column(ARRAY(String)) # type: ignore[var-annotated] # When the vulnerability was withdrawn, if it has been withdrawn. - withdrawn = Column(TIMESTAMP, nullable=True) + withdrawn = mapped_column(TIMESTAMP, nullable=True) releases = orm.relationship( "Release", diff --git a/warehouse/ip_addresses/models.py b/warehouse/ip_addresses/models.py index 88dc552aa247..a4558f62d58e 100644 --- a/warehouse/ip_addresses/models.py +++ b/warehouse/ip_addresses/models.py @@ -14,18 +14,9 @@ import sentry_sdk -from sqlalchemy import ( - Boolean, - CheckConstraint, - Column, - DateTime, - Enum, - Index, - Text, - sql, -) +from sqlalchemy import Boolean, CheckConstraint, DateTime, Enum, Index, Text, sql from sqlalchemy.dialects.postgresql import INET, JSONB -from sqlalchemy.orm import validates +from sqlalchemy.orm import mapped_column, validates from warehouse import db @@ -51,30 +42,30 @@ def __repr__(self) -> str: def __lt__(self, other): return self.id < other.id - ip_address = Column( + ip_address = mapped_column( INET, nullable=False, unique=True, comment="Structured IP Address value" ) - hashed_ip_address = Column( + hashed_ip_address = mapped_column( Text, nullable=True, unique=True, comment="Hash that represents an IP Address" ) - geoip_info = Column( + geoip_info = mapped_column( JSONB, nullable=True, comment="JSON containing GeoIP data associated with an IP Address", ) - is_banned = Column( + is_banned = mapped_column( Boolean, nullable=False, server_default=sql.false(), comment="If True, this IP Address will be marked as banned", ) - ban_reason = Column( # type: ignore[var-annotated] + ban_reason = mapped_column( # type: ignore[var-annotated] Enum(BanReason, values_callable=lambda x: [e.value for e in x]), nullable=True, comment="Reason for banning, must be in the BanReason enumeration", ) - ban_date = Column( + ban_date = mapped_column( DateTime, nullable=True, comment="Date that IP Address was last marked as banned", diff --git a/warehouse/macaroons/models.py b/warehouse/macaroons/models.py index 1d078f381cd7..5a78a473c15e 100644 --- a/warehouse/macaroons/models.py +++ b/warehouse/macaroons/models.py @@ -14,7 +14,6 @@ from sqlalchemy import ( CheckConstraint, - Column, DateTime, ForeignKey, LargeBinary, @@ -24,6 +23,7 @@ sql, ) from sqlalchemy.dialects.postgresql import JSONB, UUID +from sqlalchemy.orm import mapped_column from warehouse import db @@ -51,27 +51,29 @@ class Macaroon(db.Model): # * In the project case, a Macaroon does *not* have an explicit associated # project. Instead, depending on how its used (its request context), # it identifies one of the projects scoped in its caveats. - user_id = Column( + user_id = mapped_column( UUID(as_uuid=True), ForeignKey("users.id"), nullable=True, index=True ) - oidc_publisher_id = Column( + oidc_publisher_id = mapped_column( UUID(as_uuid=True), ForeignKey("oidc_publishers.id"), nullable=True, index=True ) # Store some information about the Macaroon to give users some mechanism # to differentiate between them. - description = Column(String, nullable=False) - created = Column(DateTime, nullable=False, server_default=sql.func.now()) - last_used = Column(DateTime, nullable=True) + description = mapped_column(String, nullable=False) + created = mapped_column(DateTime, nullable=False, server_default=sql.func.now()) + last_used = mapped_column(DateTime, nullable=True) # Human-readable "permissions" for this macaroon, corresponding to the # body of the permissions ("V1") caveat. - permissions_caveat = Column(JSONB, nullable=False, server_default=sql.text("'{}'")) + permissions_caveat = mapped_column( + JSONB, nullable=False, server_default=sql.text("'{}'") + ) # Additional state associated with this macaroon. # For OIDC publisher-issued macaroons, this will contain a subset of OIDC claims. - additional = Column(JSONB, nullable=True) + additional = mapped_column(JSONB, nullable=True) # It might be better to move this default into the database, that way we # make it less likely that something does it incorrectly (since the @@ -80,7 +82,7 @@ class Macaroon(db.Model): # instead of urandom. This is less than optimal, and we would generally # prefer to just always use urandom. Thus we'll do this ourselves here # in our application. - key = Column(LargeBinary, nullable=False, default=_generate_key) + key = mapped_column(LargeBinary, nullable=False, default=_generate_key) # Intentionally not using a back references here, since we express # relationships in terms of the "other" side of the relationship. diff --git a/warehouse/oidc/models/_core.py b/warehouse/oidc/models/_core.py index b6814a1efe62..47fd8a061113 100644 --- a/warehouse/oidc/models/_core.py +++ b/warehouse/oidc/models/_core.py @@ -17,8 +17,9 @@ import sentry_sdk -from sqlalchemy import Column, ForeignKey, String, orm +from sqlalchemy import ForeignKey, String, orm from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import mapped_column from warehouse import db from warehouse.macaroons.models import Macaroon @@ -62,13 +63,13 @@ def wrapper(ground_truth: C, signed_claim: C, all_signed_claims: SignedClaims): class OIDCPublisherProjectAssociation(db.Model): __tablename__ = "oidc_publisher_project_association" - oidc_publisher_id = Column( + oidc_publisher_id = mapped_column( UUID(as_uuid=True), ForeignKey("oidc_publishers.id"), nullable=False, primary_key=True, ) - project_id = Column( + project_id = mapped_column( UUID(as_uuid=True), ForeignKey("projects.id"), nullable=False, primary_key=True ) @@ -82,7 +83,7 @@ class OIDCPublisherMixin: # Each hierarchy of OIDC publishers (both `OIDCPublisher` and # `PendingOIDCPublisher`) use a `discriminator` column for model # polymorphism, but the two are not mutually polymorphic at the DB level. - discriminator = Column(String) + discriminator = mapped_column(String) # A map of claim names to "check" functions, each of which # has the signature `check(ground-truth, signed-claim, all-signed-claims) -> bool`. @@ -244,8 +245,8 @@ class PendingOIDCPublisher(OIDCPublisherMixin, db.Model): __tablename__ = "pending_oidc_publishers" - project_name = Column(String, nullable=False) - added_by_id = Column( + project_name = mapped_column(String, nullable=False) + added_by_id = mapped_column( UUID(as_uuid=True), ForeignKey("users.id"), nullable=False, index=True ) diff --git a/warehouse/oidc/models/github.py b/warehouse/oidc/models/github.py index 7e5cb7456ef0..4ed368d1e62e 100644 --- a/warehouse/oidc/models/github.py +++ b/warehouse/oidc/models/github.py @@ -10,9 +10,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from sqlalchemy import Column, ForeignKey, String, UniqueConstraint +from sqlalchemy import ForeignKey, String, UniqueConstraint from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.orm import Query +from sqlalchemy.orm import Query, mapped_column from sqlalchemy.sql.expression import func, literal from warehouse.oidc.interfaces import SignedClaims @@ -86,11 +86,11 @@ class GitHubPublisherMixin: Common functionality for both pending and concrete GitHub OIDC publishers. """ - repository_name = Column(String, nullable=False) - repository_owner = Column(String, nullable=False) - repository_owner_id = Column(String, nullable=False) - workflow_filename = Column(String, nullable=False) - environment = Column(String, nullable=True) + repository_name = mapped_column(String, nullable=False) + repository_owner = mapped_column(String, nullable=False) + repository_owner_id = mapped_column(String, nullable=False) + workflow_filename = mapped_column(String, nullable=False) + environment = mapped_column(String, nullable=True) __required_verifiable_claims__ = { "sub": _check_sub, @@ -224,7 +224,9 @@ class GitHubPublisher(GitHubPublisherMixin, OIDCPublisher): ), ) - id = Column(UUID(as_uuid=True), ForeignKey(OIDCPublisher.id), primary_key=True) + id = mapped_column( + UUID(as_uuid=True), ForeignKey(OIDCPublisher.id), primary_key=True + ) class PendingGitHubPublisher(GitHubPublisherMixin, PendingOIDCPublisher): @@ -240,7 +242,7 @@ class PendingGitHubPublisher(GitHubPublisherMixin, PendingOIDCPublisher): ), ) - id = Column( + id = mapped_column( UUID(as_uuid=True), ForeignKey(PendingOIDCPublisher.id), primary_key=True ) diff --git a/warehouse/oidc/models/google.py b/warehouse/oidc/models/google.py index d7c86dc3b12c..f258a7c88a82 100644 --- a/warehouse/oidc/models/google.py +++ b/warehouse/oidc/models/google.py @@ -12,9 +12,9 @@ from typing import Any -from sqlalchemy import Column, ForeignKey, String, UniqueConstraint +from sqlalchemy import ForeignKey, String, UniqueConstraint from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.orm import Query +from sqlalchemy.orm import Query, mapped_column from warehouse.oidc.interfaces import SignedClaims from warehouse.oidc.models._core import ( @@ -48,8 +48,8 @@ class GooglePublisherMixin: providers. """ - email = Column(String, nullable=False) - sub = Column(String, nullable=True) + email = mapped_column(String, nullable=False) + sub = mapped_column(String, nullable=True) __required_verifiable_claims__: dict[str, CheckClaimCallable[Any]] = { "email": check_claim_binary(str.__eq__), @@ -105,7 +105,9 @@ class GooglePublisher(GooglePublisherMixin, OIDCPublisher): ), ) - id = Column(UUID(as_uuid=True), ForeignKey(OIDCPublisher.id), primary_key=True) + id = mapped_column( + UUID(as_uuid=True), ForeignKey(OIDCPublisher.id), primary_key=True + ) class PendingGooglePublisher(GooglePublisherMixin, PendingOIDCPublisher): @@ -119,7 +121,7 @@ class PendingGooglePublisher(GooglePublisherMixin, PendingOIDCPublisher): ), ) - id = Column( + id = mapped_column( UUID(as_uuid=True), ForeignKey(PendingOIDCPublisher.id), primary_key=True ) diff --git a/warehouse/organizations/models.py b/warehouse/organizations/models.py index 09e68321b13a..b0dc4daf21b3 100644 --- a/warehouse/organizations/models.py +++ b/warehouse/organizations/models.py @@ -19,7 +19,6 @@ from sqlalchemy import ( Boolean, CheckConstraint, - Column, DateTime, Enum, ForeignKey, @@ -32,7 +31,7 @@ ) from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.exc import NoResultFound -from sqlalchemy.orm import declared_attr +from sqlalchemy.orm import declared_attr, mapped_column from warehouse import db from warehouse.accounts.models import User @@ -64,14 +63,14 @@ class OrganizationRole(db.Model): __repr__ = make_repr("role_name") - role_name = Column( # type: ignore[var-annotated] + role_name = mapped_column( Enum(OrganizationRoleType, values_callable=lambda x: [e.value for e in x]), nullable=False, ) - user_id = Column( # type: ignore[var-annotated] + user_id = mapped_column( ForeignKey("users.id", onupdate="CASCADE", ondelete="CASCADE"), nullable=False ) - organization_id = Column( # type: ignore[var-annotated] + organization_id = mapped_column( ForeignKey("organizations.id", onupdate="CASCADE", ondelete="CASCADE"), nullable=False, ) @@ -94,11 +93,11 @@ class OrganizationProject(db.Model): __repr__ = make_repr("project_id", "organization_id") - organization_id = Column( # type: ignore[var-annotated] + organization_id = mapped_column( ForeignKey("organizations.id", onupdate="CASCADE", ondelete="CASCADE"), nullable=False, ) - project_id = Column( # type: ignore[var-annotated] + project_id = mapped_column( ForeignKey("projects.id", onupdate="CASCADE", ondelete="CASCADE"), nullable=False, ) @@ -125,11 +124,11 @@ class OrganizationStripeSubscription(db.Model): __repr__ = make_repr("organization_id", "subscription_id") - organization_id = Column( # type: ignore[var-annotated] + organization_id = mapped_column( ForeignKey("organizations.id", onupdate="CASCADE", ondelete="CASCADE"), nullable=False, ) - subscription_id = Column( # type: ignore[var-annotated] + subscription_id = mapped_column( ForeignKey("stripe_subscriptions.id", onupdate="CASCADE", ondelete="CASCADE"), nullable=False, ) @@ -154,11 +153,11 @@ class OrganizationStripeCustomer(db.Model): __repr__ = make_repr("organization_id", "stripe_customer_id") - organization_id = Column( # type: ignore[var-annotated] + organization_id = mapped_column( ForeignKey("organizations.id", onupdate="CASCADE", ondelete="CASCADE"), nullable=False, ) - stripe_customer_id = Column( # type: ignore[var-annotated] + stripe_customer_id = mapped_column( ForeignKey("stripe_customers.id", onupdate="CASCADE", ondelete="CASCADE"), nullable=False, ) @@ -229,28 +228,30 @@ def __table_args__(cls): # noqa: N805 ), ) - name = Column(Text, nullable=False, comment="The account name used in URLS") + name = mapped_column(Text, nullable=False, comment="The account name used in URLS") @declared_attr def normalized_name(cls): # noqa: N805 return orm.column_property(func.normalize_pep426_name(cls.name)) - display_name = Column(Text, nullable=False, comment="Display name used in UI") - orgtype = Column( # type: ignore[var-annotated] + display_name = mapped_column( + Text, nullable=False, comment="Display name used in UI" + ) + orgtype = mapped_column( Enum(OrganizationType, values_callable=lambda x: [e.value for e in x]), nullable=False, comment="What type of organization such as Community or Company", ) - link_url = Column( + link_url = mapped_column( Text, nullable=False, comment="External URL associated with the organization" ) - description = Column( + description = mapped_column( Text, nullable=False, comment="Description of the business or project the organization represents", ) - is_approved = Column( + is_approved = mapped_column( Boolean, comment="Status of administrator approval of the request" ) @@ -262,20 +263,20 @@ class Organization(OrganizationMixin, HasEvents, db.Model): __repr__ = make_repr("name") - is_active = Column( + is_active = mapped_column( Boolean, nullable=False, server_default=sql.false(), comment="When True, the organization is active and all features are available.", ) - created = Column( + created = mapped_column( DateTime(timezone=False), nullable=False, server_default=sql.func.now(), index=True, comment="Datetime the organization was created.", ) - date_approved = Column( + date_approved = mapped_column( DateTime(timezone=False), nullable=True, onupdate=func.now(), @@ -451,7 +452,7 @@ class OrganizationApplication(OrganizationMixin, db.Model): __tablename__ = "organization_applications" __repr__ = make_repr("name") - submitted_by_id = Column( + submitted_by_id = mapped_column( UUID(as_uuid=True), ForeignKey( User.id, @@ -462,14 +463,14 @@ class OrganizationApplication(OrganizationMixin, db.Model): nullable=False, comment="ID of the User which submitted the request", ) - submitted = Column( + submitted = mapped_column( DateTime(timezone=False), nullable=False, server_default=sql.func.now(), index=True, comment="Datetime the request was submitted", ) - organization_id = Column( + organization_id = mapped_column( UUID(as_uuid=True), ForeignKey( Organization.id, @@ -503,8 +504,8 @@ class OrganizationNameCatalog(db.Model): __repr__ = make_repr("normalized_name", "organization_id") - normalized_name = Column(Text, nullable=False, index=True) - organization_id = Column(UUID(as_uuid=True), nullable=True, index=True) + normalized_name = mapped_column(Text, nullable=False, index=True) + organization_id = mapped_column(UUID(as_uuid=True), nullable=True, index=True) class OrganizationInvitationStatus(enum.Enum): @@ -525,19 +526,19 @@ class OrganizationInvitation(db.Model): __repr__ = make_repr("invite_status", "user", "organization") - invite_status = Column( # type: ignore[var-annotated] + invite_status = mapped_column( Enum( OrganizationInvitationStatus, values_callable=lambda x: [e.value for e in x] ), nullable=False, ) - token = Column(Text, nullable=False) - user_id = Column( # type: ignore[var-annotated] + token = mapped_column(Text, nullable=False) + user_id = mapped_column( ForeignKey("users.id", onupdate="CASCADE", ondelete="CASCADE"), nullable=False, index=True, ) - organization_id = Column( # type: ignore[var-annotated] + organization_id = mapped_column( ForeignKey("organizations.id", onupdate="CASCADE", ondelete="CASCADE"), nullable=False, index=True, @@ -565,14 +566,14 @@ class TeamRole(db.Model): __repr__ = make_repr("role_name", "team", "user") - role_name = Column( # type: ignore[var-annotated] + role_name = mapped_column( Enum(TeamRoleType, values_callable=lambda x: [e.value for e in x]), nullable=False, ) - user_id = Column( # type: ignore[var-annotated] + user_id = mapped_column( ForeignKey("users.id", onupdate="CASCADE", ondelete="CASCADE"), nullable=False ) - team_id = Column( # type: ignore[var-annotated] + team_id = mapped_column( ForeignKey("teams.id", onupdate="CASCADE", ondelete="CASCADE"), nullable=False, ) @@ -600,15 +601,15 @@ class TeamProjectRole(db.Model): __repr__ = make_repr("role_name", "team", "project") - role_name = Column( # type: ignore[var-annotated] + role_name = mapped_column( Enum(TeamProjectRoleType, values_callable=lambda x: [e.value for e in x]), nullable=False, ) - project_id = Column( # type: ignore[var-annotated] + project_id = mapped_column( ForeignKey("projects.id", onupdate="CASCADE", ondelete="CASCADE"), nullable=False, ) - team_id = Column( # type: ignore[var-annotated] + team_id = mapped_column( ForeignKey("teams.id", onupdate="CASCADE", ondelete="CASCADE"), nullable=False, ) @@ -658,13 +659,13 @@ class Team(HasEvents, db.Model): __repr__ = make_repr("name", "organization") - name = Column(Text, nullable=False) + name = mapped_column(Text, nullable=False) normalized_name = orm.column_property(func.normalize_team_name(name)) # type: ignore[var-annotated] # noqa: E501 - organization_id = Column( # type: ignore[var-annotated] + organization_id = mapped_column( ForeignKey("organizations.id", onupdate="CASCADE", ondelete="CASCADE"), nullable=False, ) - created = Column( + created = mapped_column( DateTime(timezone=False), nullable=False, server_default=sql.func.now(), diff --git a/warehouse/packaging/models.py b/warehouse/packaging/models.py index 47737c90a24f..b6a58d981747 100644 --- a/warehouse/packaging/models.py +++ b/warehouse/packaging/models.py @@ -43,7 +43,7 @@ from sqlalchemy.exc import MultipleResultsFound, NoResultFound from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy.orm import attribute_keyed_dict, declared_attr, validates +from sqlalchemy.orm import attribute_keyed_dict, declared_attr, mapped_column, validates from warehouse import db from warehouse.accounts.models import User @@ -72,11 +72,11 @@ class Role(db.Model): __repr__ = make_repr("role_name") - role_name = Column(Text, nullable=False) - user_id = Column( # type: ignore[var-annotated] + role_name = mapped_column(Text, nullable=False) + user_id = mapped_column( ForeignKey("users.id", onupdate="CASCADE", ondelete="CASCADE"), nullable=False ) - project_id = Column( # type: ignore[var-annotated] + project_id = mapped_column( ForeignKey("projects.id", onupdate="CASCADE", ondelete="CASCADE"), nullable=False, ) @@ -101,17 +101,17 @@ class RoleInvitation(db.Model): __repr__ = make_repr("invite_status", "user", "project") - invite_status = Column( # type: ignore[var-annotated] + invite_status = mapped_column( Enum(RoleInvitationStatus, values_callable=lambda x: [e.value for e in x]), nullable=False, ) - token = Column(Text, nullable=False) - user_id = Column( # type: ignore[var-annotated] + token = mapped_column(Text, nullable=False) + user_id = mapped_column( ForeignKey("users.id", onupdate="CASCADE", ondelete="CASCADE"), nullable=False, index=True, ) - project_id = Column( # type: ignore[var-annotated] + project_id = mapped_column( ForeignKey("projects.id", onupdate="CASCADE", ondelete="CASCADE"), nullable=False, index=True, @@ -146,9 +146,13 @@ def __contains__(self, project): class TwoFactorRequireable: # Project owner requires 2FA for this project - owners_require_2fa = Column(Boolean, nullable=False, server_default=sql.false()) + owners_require_2fa = mapped_column( + Boolean, nullable=False, server_default=sql.false() + ) # PyPI requires 2FA for this project - pypi_mandates_2fa = Column(Boolean, nullable=False, server_default=sql.false()) + pypi_mandates_2fa = mapped_column( + Boolean, nullable=False, server_default=sql.false() + ) @hybrid_property def two_factor_required(self): @@ -159,25 +163,25 @@ class Project(SitemapMixin, TwoFactorRequireable, HasEvents, db.Model): __tablename__ = "projects" __repr__ = make_repr("name") - name = Column(Text, nullable=False) - normalized_name = Column( + name = mapped_column(Text, nullable=False) + normalized_name = mapped_column( Text, nullable=False, unique=True, server_default=FetchedValue(), server_onupdate=FetchedValue(), ) - created = Column( + created = mapped_column( DateTime(timezone=False), nullable=True, server_default=sql.func.now(), index=True, ) - has_docs = Column(Boolean) - upload_limit = Column(Integer, nullable=True) - total_size_limit = Column(BigInteger, nullable=True) - last_serial = Column(Integer, nullable=False, server_default=sql.text("0")) - total_size = Column(BigInteger, server_default=sql.text("0")) + has_docs = mapped_column(Boolean) + upload_limit = mapped_column(Integer, nullable=True) + total_size_limit = mapped_column(BigInteger, nullable=True) + last_serial = mapped_column(Integer, nullable=False, server_default=sql.text("0")) + total_size = mapped_column(BigInteger, server_default=sql.text("0")) organization = orm.relationship( Organization, @@ -368,12 +372,12 @@ class Dependency(db.Model): ) __repr__ = make_repr("release", "kind", "specifier") - release_id = Column( # type: ignore[var-annotated] + release_id = mapped_column( ForeignKey("releases.id", onupdate="CASCADE", ondelete="CASCADE"), nullable=False, ) - kind = Column(Integer) - specifier = Column(Text) + kind = mapped_column(Integer) + specifier = mapped_column(Text) def _dependency_relation(kind): @@ -389,10 +393,10 @@ def _dependency_relation(kind): class Description(db.Model): __tablename__ = "release_descriptions" - content_type = Column(Text) - raw = Column(Text, nullable=False) - html = Column(Text, nullable=False) - rendered_by = Column(Text, nullable=False) + content_type = mapped_column(Text) + raw = mapped_column(Text, nullable=False) + html = mapped_column(Text, nullable=False) + rendered_by = mapped_column(Text, nullable=False) class ReleaseURL(db.Model): @@ -406,14 +410,14 @@ class ReleaseURL(db.Model): ) __repr__ = make_repr("name", "url") - release_id = Column( # type: ignore[var-annotated] + release_id = mapped_column( ForeignKey("releases.id", onupdate="CASCADE", ondelete="CASCADE"), nullable=False, index=True, ) - name = Column(String(32), nullable=False) - url = Column(Text, nullable=False) + name = mapped_column(String(32), nullable=False) + url = mapped_column(Text, nullable=False) class Release(db.Model): @@ -433,30 +437,30 @@ def __table_args__(cls): # noqa __parent__ = dotted_navigator("project") __name__ = dotted_navigator("version") - project_id = Column( # type: ignore[var-annotated] + project_id = mapped_column( ForeignKey("projects.id", onupdate="CASCADE", ondelete="CASCADE"), nullable=False, ) - version = Column(Text, nullable=False) - canonical_version = Column(Text, nullable=False) - is_prerelease = Column(Boolean, nullable=False, server_default=sql.false()) - author = Column(Text) - author_email = Column(Text) - maintainer = Column(Text) - maintainer_email = Column(Text) - home_page = Column(Text) - license = Column(Text) - summary = Column(Text) - keywords = Column(Text) - platform = Column(Text) - download_url = Column(Text) - _pypi_ordering = Column(Integer) - requires_python = Column(Text) - created = Column( + version = mapped_column(Text, nullable=False) + canonical_version = mapped_column(Text, nullable=False) + is_prerelease = mapped_column(Boolean, nullable=False, server_default=sql.false()) + author = mapped_column(Text) + author_email = mapped_column(Text) + maintainer = mapped_column(Text) + maintainer_email = mapped_column(Text) + home_page = mapped_column(Text) + license = mapped_column(Text) + summary = mapped_column(Text) + keywords = mapped_column(Text) + platform = mapped_column(Text) + download_url = mapped_column(Text) + _pypi_ordering = mapped_column(Integer) + requires_python = mapped_column(Text) + created = mapped_column( DateTime(timezone=False), nullable=False, server_default=sql.func.now() ) - description_id = Column( # type: ignore[var-annotated] + description_id = mapped_column( ForeignKey("release_descriptions.id", onupdate="CASCADE", ondelete="CASCADE"), nullable=False, index=True, @@ -473,9 +477,9 @@ def __table_args__(cls): # noqa ), ) - yanked = Column(Boolean, nullable=False, server_default=sql.false()) + yanked = mapped_column(Boolean, nullable=False, server_default=sql.false()) - yanked_reason = Column(Text, nullable=False, server_default="") + yanked_reason = mapped_column(Text, nullable=False, server_default="") _classifiers = orm.relationship( Classifier, @@ -544,13 +548,13 @@ def __table_args__(cls): # noqa _requires_external = _dependency_relation(DependencyKind.requires_external) requires_external = association_proxy("_requires_external", "specifier") - uploader_id = Column( # type: ignore[var-annotated] + uploader_id = mapped_column( ForeignKey("users.id", onupdate="CASCADE", ondelete="SET NULL"), nullable=True, index=True, ) uploader = orm.relationship(User) - uploaded_via = Column(Text) + uploaded_via = mapped_column(Text) @property def urls(self): @@ -642,13 +646,13 @@ def __table_args__(cls): # noqa Index("release_files_cached_idx", "cached"), ) - release_id = Column( # type: ignore[var-annotated] + release_id = mapped_column( ForeignKey("releases.id", onupdate="CASCADE", ondelete="CASCADE"), nullable=False, ) - python_version = Column(Text) - requires_python = Column(Text) - packagetype = Column( # type: ignore[var-annotated] + python_version = mapped_column(Text) + requires_python = mapped_column(Text) + packagetype = mapped_column( Enum( "bdist_dmg", "bdist_dumb", @@ -660,32 +664,34 @@ def __table_args__(cls): # noqa "sdist", ) ) - comment_text = Column(Text) - filename = Column(Text, unique=True) - path = Column(Text, unique=True, nullable=False) - size = Column(Integer) - md5_digest = Column(Text, unique=True, nullable=False) - sha256_digest = Column(CITEXT, unique=True, nullable=False) - blake2_256_digest = Column(CITEXT, unique=True, nullable=False) - upload_time = Column(DateTime(timezone=False), server_default=func.now()) - uploaded_via = Column(Text) + comment_text = mapped_column(Text) + filename = mapped_column(Text, unique=True) + path = mapped_column(Text, unique=True, nullable=False) + size = mapped_column(Integer) + md5_digest = mapped_column(Text, unique=True, nullable=False) + sha256_digest = mapped_column(CITEXT, unique=True, nullable=False) + blake2_256_digest = mapped_column(CITEXT, unique=True, nullable=False) + upload_time = mapped_column(DateTime(timezone=False), server_default=func.now()) + uploaded_via = mapped_column(Text) # PEP 658 - metadata_file_sha256_digest = Column(CITEXT, nullable=True) - metadata_file_blake2_256_digest = Column(CITEXT, nullable=True) + metadata_file_sha256_digest = mapped_column(CITEXT, nullable=True) + metadata_file_blake2_256_digest = mapped_column(CITEXT, nullable=True) # We need this column to allow us to handle the currently existing "double" # sdists that exist in our database. Eventually we should try to get rid # of all of them and then remove this column. - allow_multiple_sdist = Column(Boolean, nullable=False, server_default=sql.false()) + allow_multiple_sdist = mapped_column( + Boolean, nullable=False, server_default=sql.false() + ) - cached = Column( + cached = mapped_column( Boolean, comment="If True, the object has been populated to our cache bucket.", nullable=False, server_default=sql.false(), ) - archived = Column( + archived = mapped_column( Boolean, comment="If True, the object has been archived to our archival bucket.", nullable=False, @@ -708,10 +714,11 @@ def validates_requires_python(self, *args, **kwargs): class Filename(db.ModelBase): __tablename__ = "file_registry" - id = Column(Integer, primary_key=True, nullable=False) - filename = Column(Text, unique=True, nullable=False) + id = mapped_column(Integer, primary_key=True, nullable=False) + filename = mapped_column(Text, unique=True, nullable=False) +# TODO: Convert to Declarative API release_classifiers = Table( "release_classifiers", db.metadata, @@ -739,14 +746,14 @@ def __table_args__(cls): # noqa Index("journals_submitted_date_id_idx", cls.submitted_date, cls.id), ) - id = Column(Integer, primary_key=True, nullable=False) - name = Column(Text) - version = Column(Text) - action = Column(Text) - submitted_date = Column( + id = mapped_column(Integer, primary_key=True, nullable=False) + name = mapped_column(Text) + version = mapped_column(Text) + action = mapped_column(Text) + submitted_date = mapped_column( DateTime(timezone=False), nullable=False, server_default=sql.func.now() ) - _submitted_by = Column( + _submitted_by = mapped_column( "submitted_by", CITEXT, ForeignKey("users.username", onupdate="CASCADE"), @@ -766,12 +773,12 @@ class ProhibitedProjectName(db.Model): __repr__ = make_repr("name") - created = Column( + created = mapped_column( DateTime(timezone=False), nullable=False, server_default=sql.func.now() ) - name = Column(Text, unique=True, nullable=False) - _prohibited_by = Column( + name = mapped_column(Text, unique=True, nullable=False) + _prohibited_by = mapped_column( "prohibited_by", UUID(as_uuid=True), ForeignKey("users.id"), index=True ) prohibited_by = orm.relationship(User) - comment = Column(Text, nullable=False, server_default="") + comment = mapped_column(Text, nullable=False, server_default="") diff --git a/warehouse/sitemap/models.py b/warehouse/sitemap/models.py index 4d56588e5634..084eb179f2b2 100644 --- a/warehouse/sitemap/models.py +++ b/warehouse/sitemap/models.py @@ -10,11 +10,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from sqlalchemy import Column, FetchedValue, Text +from sqlalchemy import FetchedValue, Text +from sqlalchemy.orm import mapped_column class SitemapMixin: - sitemap_bucket = Column( + sitemap_bucket = mapped_column( Text, nullable=False, server_default=FetchedValue(), diff --git a/warehouse/sponsors/models.py b/warehouse/sponsors/models.py index 1a8e06b9df62..f4076fc1d5ec 100644 --- a/warehouse/sponsors/models.py +++ b/warehouse/sponsors/models.py @@ -10,7 +10,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from sqlalchemy import Boolean, Column, Integer, String, Text +from sqlalchemy import Boolean, Integer, String, Text +from sqlalchemy.orm import mapped_column from warehouse import db from warehouse.utils import readme @@ -21,27 +22,27 @@ class Sponsor(db.Model): __tablename__ = "sponsors" __repr__ = make_repr("name") - name = Column(String, nullable=False) - service = Column(String) - activity_markdown = Column(Text) + name = mapped_column(String, nullable=False) + service = mapped_column(String) + activity_markdown = mapped_column(Text) - link_url = Column(Text, nullable=False) - color_logo_url = Column(Text, nullable=False) - white_logo_url = Column(Text) + link_url = mapped_column(Text, nullable=False) + color_logo_url = mapped_column(Text, nullable=False) + white_logo_url = mapped_column(Text) # control flags - is_active = Column(Boolean, default=False, nullable=False) - footer = Column(Boolean, default=False, nullable=False) - psf_sponsor = Column(Boolean, default=False, nullable=False) - infra_sponsor = Column(Boolean, default=False, nullable=False) - one_time = Column(Boolean, default=False, nullable=False) - sidebar = Column(Boolean, default=False, nullable=False) + is_active = mapped_column(Boolean, default=False, nullable=False) + footer = mapped_column(Boolean, default=False, nullable=False) + psf_sponsor = mapped_column(Boolean, default=False, nullable=False) + infra_sponsor = mapped_column(Boolean, default=False, nullable=False) + one_time = mapped_column(Boolean, default=False, nullable=False) + sidebar = mapped_column(Boolean, default=False, nullable=False) # pythondotorg integration - origin = Column(String, default="manual") - level_name = Column(String) - level_order = Column(Integer, default=0) - slug = Column(String) + origin = mapped_column(String, default="manual") + level_name = mapped_column(String) + level_order = mapped_column(Integer, default=0) + slug = mapped_column(String) @property def color_logo_img(self): diff --git a/warehouse/subscriptions/models.py b/warehouse/subscriptions/models.py index fa7c859c6f92..46d0a0ae3632 100644 --- a/warehouse/subscriptions/models.py +++ b/warehouse/subscriptions/models.py @@ -14,7 +14,6 @@ from sqlalchemy import ( Boolean, - Column, Enum, ForeignKey, Index, @@ -25,6 +24,7 @@ sql, ) from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import mapped_column from warehouse import db from warehouse.i18n import localize as _ @@ -64,10 +64,10 @@ class StripeCustomer(db.Model): __repr__ = make_repr("customer_id", "billing_email") - customer_id = Column( + customer_id = mapped_column( Text, nullable=False, unique=True ) # generated by Payment Service Provider - billing_email = Column(Text) + billing_email = mapped_column(Text) organization = orm.relationship( Organization, @@ -93,22 +93,22 @@ class StripeSubscription(db.Model): __repr__ = make_repr("subscription_id", "stripe_customer_id") - stripe_customer_id = Column( + stripe_customer_id = mapped_column( UUID(as_uuid=True), ForeignKey("stripe_customers.id", onupdate="CASCADE", ondelete="CASCADE"), nullable=False, ) - subscription_id = Column( + subscription_id = mapped_column( Text, nullable=False ) # generated by Payment Service Provider - subscription_price_id = Column( + subscription_price_id = mapped_column( UUID(as_uuid=True), ForeignKey( "stripe_subscription_prices.id", onupdate="CASCADE", ondelete="CASCADE" ), nullable=False, ) - status = Column( # type: ignore[var-annotated] + status = mapped_column( Enum(StripeSubscriptionStatus, values_callable=lambda x: [e.value for e in x]), nullable=False, ) @@ -147,11 +147,15 @@ class StripeSubscriptionProduct(db.Model): __repr__ = make_repr("product_name") - product_id = Column(Text, nullable=True) # generated by Payment Service Provider - product_name = Column(Text, nullable=False) - description = Column(Text, nullable=False) - is_active = Column(Boolean, nullable=False, server_default=sql.true()) - tax_code = Column(Text, nullable=True) # https://stripe.com/docs/tax/tax-categories + product_id = mapped_column( + Text, nullable=True + ) # generated by Payment Service Provider + product_name = mapped_column(Text, nullable=False) + description = mapped_column(Text, nullable=False) + is_active = mapped_column(Boolean, nullable=False, server_default=sql.true()) + tax_code = mapped_column( + Text, nullable=True + ) # https://stripe.com/docs/tax/tax-categories class StripeSubscriptionPrice(db.Model): @@ -159,25 +163,27 @@ class StripeSubscriptionPrice(db.Model): __repr__ = make_repr("price_id", "unit_amount", "recurring") - price_id = Column(Text, nullable=True) # generated by Payment Service Provider - currency = Column(Text, nullable=False) # https://stripe.com/docs/currencies - subscription_product_id = Column( + price_id = mapped_column( + Text, nullable=True + ) # generated by Payment Service Provider + currency = mapped_column(Text, nullable=False) # https://stripe.com/docs/currencies + subscription_product_id = mapped_column( UUID(as_uuid=True), ForeignKey( "stripe_subscription_products.id", onupdate="CASCADE", ondelete="CASCADE" ), nullable=False, ) - unit_amount = Column(Integer, nullable=False) # positive integer in cents - is_active = Column(Boolean, nullable=False, server_default=sql.true()) - recurring = Column( # type: ignore[var-annotated] + unit_amount = mapped_column(Integer, nullable=False) # positive integer in cents + is_active = mapped_column(Boolean, nullable=False, server_default=sql.true()) + recurring = mapped_column( # type: ignore[var-annotated] Enum( StripeSubscriptionPriceInterval, values_callable=lambda x: [e.value for e in x], ), nullable=False, ) - tax_behavior = Column( + tax_behavior = mapped_column( Text, nullable=True ) # TODO: Enum? inclusive, exclusive, unspecified @@ -191,22 +197,22 @@ class StripeSubscriptionItem(db.Model): "subscription_item_id", "subscription_id", "subscription_price_id", "quantity" ) - subscription_item_id = Column( + subscription_item_id = mapped_column( Text, nullable=True ) # generated by Payment Service Provider - subscription_id = Column( + subscription_id = mapped_column( UUID(as_uuid=True), ForeignKey("stripe_subscriptions.id", onupdate="CASCADE", ondelete="CASCADE"), nullable=False, ) - subscription_price_id = Column( + subscription_price_id = mapped_column( UUID(as_uuid=True), ForeignKey( "stripe_subscription_prices.id", onupdate="CASCADE", ondelete="CASCADE" ), nullable=False, ) - quantity = Column(Integer, nullable=False) # positive integer or zero + quantity = mapped_column(Integer, nullable=False) # positive integer or zero subscription = orm.relationship( "StripeSubscription", lazy=False, back_populates="subscription_item" diff --git a/warehouse/utils/row_counter.py b/warehouse/utils/row_counter.py index 6791acec2634..97af860a9db9 100644 --- a/warehouse/utils/row_counter.py +++ b/warehouse/utils/row_counter.py @@ -10,7 +10,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from sqlalchemy import BigInteger, Column, Text, sql +from sqlalchemy import BigInteger, Text, sql +from sqlalchemy.orm import mapped_column from warehouse import db @@ -18,5 +19,5 @@ class RowCount(db.Model): __tablename__ = "row_counts" - table_name = Column(Text, nullable=False, unique=True) - count = Column(BigInteger, nullable=False, server_default=sql.text("0")) + table_name = mapped_column(Text, nullable=False, unique=True) + count = mapped_column(BigInteger, nullable=False, server_default=sql.text("0"))