Skip to content

Commit

Permalink
Add support for sqlalchemy collection_class property (#625)
Browse files Browse the repository at this point in the history
* Add support for sqlalchemy collection_class property

* fix ci
  • Loading branch information
jowilf authored Feb 2, 2025
1 parent 99a1425 commit 33abffb
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 8 deletions.
8 changes: 7 additions & 1 deletion starlette_admin/contrib/sqla/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,13 @@ def convert_fields_list(
):
converted_fields.append(HasOne(attr.key, identity=identity))
else:
converted_fields.append(HasMany(attr.key, identity=identity))
converted_fields.append(
HasMany(
attr.key,
identity=identity,
collection_class=attr.collection_class or list,
)
)
elif isinstance(attr, ColumnProperty):
assert (
len(attr.columns) == 1
Expand Down
10 changes: 4 additions & 6 deletions starlette_admin/contrib/sqla/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from sqlalchemy.sql import Select
from starlette.requests import Request
from starlette.responses import Response
from starlette_admin import BaseField
from starlette_admin import BaseField, HasMany
from starlette_admin._types import RequestAction
from starlette_admin.contrib.sqla.converters import (
BaseSQLAModelConverter,
Expand Down Expand Up @@ -521,12 +521,10 @@ async def _arrange_data(
for field in self.get_fields_list(request, request.state.action):
if isinstance(field, RelationField) and data[field.name] is not None:
foreign_model = self._find_foreign_model(field.identity) # type: ignore
if not field.multiple:
arranged_data[field.name] = await foreign_model.find_by_pk(
request, data[field.name]
)
if isinstance(field, HasMany):
arranged_data[field.name] = field.collection_class(await foreign_model.find_by_pks(request, data[field.name])) # type: ignore[call-arg]
else:
arranged_data[field.name] = await foreign_model.find_by_pks(
arranged_data[field.name] = await foreign_model.find_by_pk(
request, data[field.name]
)
else:
Expand Down
2 changes: 2 additions & 0 deletions starlette_admin/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import (
Any,
Callable,
Collection,
Dict,
List,
Optional,
Expand Down Expand Up @@ -1094,6 +1095,7 @@ class HasMany(RelationField):
"""A field representing a "has-many" relationship between two models."""

multiple: bool = True
collection_class: Union[Type[Collection[Any]], Callable[[], Collection[Any]]] = list


@dataclass(init=False)
Expand Down
2 changes: 1 addition & 1 deletion tests/sqla/test_sqla_and_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class User(Base, IDMixin):

name = Column(String(100))

todos = relationship("Todo", back_populates="user")
todos = relationship("Todo", back_populates="user", collection_class=set)


class Todo(Base, IDMixin):
Expand Down

0 comments on commit 33abffb

Please sign in to comment.