Skip to content

✨ Add support for hybrid_property #482

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from sqlalchemy import Boolean, Column, Date, DateTime
from sqlalchemy import Enum as sa_Enum
from sqlalchemy import Float, ForeignKey, Integer, Interval, Numeric, inspect
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import RelationshipProperty, declared_attr, registry, relationship
from sqlalchemy.orm.attributes import set_attribute
from sqlalchemy.orm.decl_api import DeclarativeMeta
Expand Down Expand Up @@ -207,6 +208,7 @@ def Relationship(
@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo))
class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
__sqlmodel_relationships__: Dict[str, RelationshipInfo]
__sqlalchemy_constructs__: Dict[str, Any]
__config__: Type[BaseConfig]
__fields__: Dict[str, ModelField]

Expand All @@ -232,6 +234,7 @@ def __new__(
**kwargs: Any,
) -> Any:
relationships: Dict[str, RelationshipInfo] = {}
sqlalchemy_constructs = {}
dict_for_pydantic = {}
original_annotations = resolve_annotations(
class_dict.get("__annotations__", {}), class_dict.get("__module__", None)
Expand All @@ -241,6 +244,8 @@ def __new__(
for k, v in class_dict.items():
if isinstance(v, RelationshipInfo):
relationships[k] = v
elif isinstance(v, hybrid_property):
sqlalchemy_constructs[k] = v
else:
dict_for_pydantic[k] = v
for k, v in original_annotations.items():
Expand All @@ -253,6 +258,7 @@ def __new__(
"__weakref__": None,
"__sqlmodel_relationships__": relationships,
"__annotations__": pydantic_annotations,
"__sqlalchemy_constructs__": sqlalchemy_constructs,
}
# Duplicate logic from Pydantic to filter config kwargs because if they are
# passed directly including the registry Pydantic will pass them over to the
Expand All @@ -276,6 +282,11 @@ def __new__(
**new_cls.__annotations__,
}

# We did not provide the sqlalchemy constructs to Pydantic's new function above
# so that they wouldn't be modified. Instead we set them directly to the class below:
for k, v in sqlalchemy_constructs.items():
setattr(new_cls, k, v)

def get_config(name: str) -> Any:
config_class_value = getattr(new_cls.__config__, name, Undefined)
if config_class_value is not Undefined:
Expand All @@ -290,8 +301,9 @@ def get_config(name: str) -> Any:
# If it was passed by kwargs, ensure it's also set in config
new_cls.__config__.table = config_table
for k, v in new_cls.__fields__.items():
col = get_column_from_field(v)
setattr(new_cls, k, col)
if k in sqlalchemy_constructs:
continue
setattr(new_cls, k, get_column_from_field(v))
# Set a config flag to tell FastAPI that this should be read with a field
# in orm_mode instead of preemptively converting it to a dict.
# This could be done by reading new_cls.__config__.table in FastAPI, but
Expand Down Expand Up @@ -326,6 +338,8 @@ def __init__(
if getattr(cls.__config__, "table", False) and not base_is_table:
dict_used = dict_.copy()
for field_name, field_value in cls.__fields__.items():
if field_name in cls.__sqlalchemy_constructs__:
continue
dict_used[field_name] = get_column_from_field(field_value)
for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
if rel_info.sa_relationship:
Expand Down
9 changes: 8 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pytest
from pydantic import BaseModel
from sqlmodel import SQLModel
from sqlmodel import SQLModel, create_engine
from sqlmodel.main import default_registry

top_level_path = Path(__file__).resolve().parent.parent
Expand All @@ -23,6 +23,13 @@ def clear_sqlmodel():
default_registry.dispose()


@pytest.fixture()
def in_memory_engine(clear_sqlmodel):
engine = create_engine("sqlite:///memory")
yield engine
SQLModel.metadata.drop_all(engine, checkfirst=True)


@pytest.fixture()
def cov_tmp_path(tmp_path: Path):
yield tmp_path
Expand Down
41 changes: 41 additions & 0 deletions tests/test_sqlalchemy_properties.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from typing import Optional

from sqlalchemy import func
from sqlalchemy.ext.hybrid import hybrid_property
from sqlmodel import Field, Session, SQLModel, select


def test_hybrid_property(in_memory_engine):
class Interval(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
length: float

@hybrid_property
def radius(self) -> float:
return abs(self.length) / 2

@radius.expression
def radius(cls) -> float:
return func.abs(cls.length) / 2

class Config:
arbitrary_types_allowed = True

SQLModel.metadata.create_all(in_memory_engine)
session = Session(in_memory_engine)

interval = Interval(length=-2)
assert interval.radius == 1

session.add(interval)
session.commit()
interval_2 = session.exec(select(Interval)).all()[0]
assert interval_2.radius == 1

interval_3 = session.exec(select(Interval).where(Interval.radius == 1)).all()[0]
assert interval_3.radius == 1

intervals = session.exec(select(Interval).where(Interval.radius > 1)).all()
assert len(intervals) == 0

assert session.exec(select(Interval.radius + 1)).all()[0] == 2.0