Skip to content

Commit

Permalink
Merge pull request #62 from collerek/exclude_default
Browse files Browse the repository at this point in the history
fix for issue-60
  • Loading branch information
collerek authored Dec 2, 2020
2 parents cbd793c + 380bb29 commit 40254b9
Show file tree
Hide file tree
Showing 7 changed files with 164 additions and 4 deletions.
7 changes: 6 additions & 1 deletion docs/releases.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# 0.6.1

* Explicitly set None to excluded nullable fields to avoid pydantic setting a default value (fix [#60][#60]).

# 0.6.0

* **Breaking:** calling instance.load() when the instance row was deleted from db now raises ormar.NoMatch instead of ValueError
Expand Down Expand Up @@ -155,4 +159,5 @@ Add queryset level methods
* Added ManyToMany field and support for many to many relations


[#19]: https://github.com/collerek/ormar/issues/19
[#19]: https://github.com/collerek/ormar/issues/19
[#60]: https://github.com/collerek/ormar/issues/60
2 changes: 1 addition & 1 deletion ormar/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __repr__(self) -> str:

Undefined = UndefinedType()

__version__ = "0.6.0"
__version__ = "0.6.1"
__all__ = [
"Integer",
"BigInteger",
Expand Down
4 changes: 3 additions & 1 deletion ormar/models/metaclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,9 @@ def populate_pydantic_default_values(attrs: Dict) -> Tuple[Dict, Dict]:
field.name = field_name
attrs = populate_default_pydantic_field_value(field, field_name, attrs)
model_fields[field_name] = field
attrs["__annotations__"][field_name] = field.__type__
attrs["__annotations__"][field_name] = (
field.__type__ if not field.nullable else Optional[field.__type__]
)
return attrs, model_fields


Expand Down
3 changes: 3 additions & 0 deletions ormar/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ def from_row( # noqa CCR001

instance: Optional[T] = None
if item.get(cls.Meta.pkname, None) is not None:
item["__excluded__"] = cls.get_names_to_exclude(
fields=fields, exclude_fields=exclude_fields
)
instance = cls(**item)
instance.set_save_status(True)
else:
Expand Down
30 changes: 30 additions & 0 deletions ormar/models/modelproxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,36 @@ def extract_db_own_fields(cls) -> Set:
}
return self_fields

@classmethod
def get_names_to_exclude(
cls,
fields: Optional[Union[Dict, Set]] = None,
exclude_fields: Optional[Union[Dict, Set]] = None,
) -> Set:
fields_names = cls.extract_db_own_fields()
if fields and fields is not Ellipsis:
fields_to_keep = {name for name in fields if name in fields_names}
else:
fields_to_keep = fields_names

fields_to_exclude = fields_names - fields_to_keep

if isinstance(exclude_fields, Set):
fields_to_exclude = fields_to_exclude.union(
{name for name in exclude_fields if name in fields_names}
)
elif isinstance(exclude_fields, Dict):
new_to_exclude = {
name
for name in exclude_fields
if name in fields_names and exclude_fields[name] is Ellipsis
}
fields_to_exclude = fields_to_exclude.union(new_to_exclude)

fields_to_exclude = fields_to_exclude - {cls.Meta.pkname}

return fields_to_exclude

@classmethod
def substitute_models_with_pks(cls, model_dict: Dict) -> Dict: # noqa CCR001
for field in cls.extract_related_names():
Expand Down
9 changes: 8 additions & 1 deletion ormar/models/newbasemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,13 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # type: ignore
)

pk_only = kwargs.pop("__pk_only__", False)
excluded: Set[str] = kwargs.pop("__excluded__", set())

if "pk" in kwargs:
kwargs[self.Meta.pkname] = kwargs.pop("pk")
# build the models to set them and validate but don't register
try:
new_kwargs = {
new_kwargs: Dict[str, Any] = {
k: self._convert_json(
k,
self.Meta.model_fields[k].expand_relationship(
Expand All @@ -111,6 +113,11 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # type: ignore
f"Unknown field '{e.args[0]}' for model {self.get_name(lower=False)}"
)

# explicitly set None to excluded fields with default
# as pydantic populates them with default
for field_to_nullify in excluded:
new_kwargs[field_to_nullify] = None

values, fields_set, validation_error = pydantic.validate_model(
self, new_kwargs # type: ignore
)
Expand Down
113 changes: 113 additions & 0 deletions tests/test_excluding_fields_with_default.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import random
from typing import Optional

import databases
import pytest
import sqlalchemy

import ormar
from tests.settings import DATABASE_URL

database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()


def get_position() -> int:
return random.randint(1, 10)


class Album(ormar.Model):
class Meta:
tablename = "albums"
metadata = metadata
database = database

id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
is_best_seller: bool = ormar.Boolean(default=False, nullable=True)


class Track(ormar.Model):
class Meta:
tablename = "tracks"
metadata = metadata
database = database

id: int = ormar.Integer(primary_key=True)
album: Optional[Album] = ormar.ForeignKey(Album)
title: str = ormar.String(max_length=100)
position: int = ormar.Integer(default=get_position)
play_count: int = ormar.Integer(nullable=True, default=0)


@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine)
yield
metadata.drop_all(engine)


@pytest.mark.asyncio
async def test_excluding_field_with_default():
async with database:
async with database.transaction(force_rollback=True):
album = await Album.objects.create(name="Miami")
await Track.objects.create(title="Vice City", album=album, play_count=10)
await Track.objects.create(title="Beach Sand", album=album, play_count=20)
await Track.objects.create(title="Night Lights", album=album)

album = await Album.objects.fields("name").get()
assert album.is_best_seller is None

album = await Album.objects.exclude_fields({"is_best_seller", "id"}).get()
assert album.is_best_seller is None

album = await Album.objects.exclude_fields({"is_best_seller": ...}).get()
assert album.is_best_seller is None

tracks = await Track.objects.all()
for track in tracks:
assert track.play_count is not None
assert track.position is not None

album = (
await Album.objects.select_related("tracks")
.exclude_fields({"is_best_seller": ..., "tracks": {"play_count"}})
.get(name="Miami")
)
assert album.is_best_seller is None
assert len(album.tracks) == 3
for track in album.tracks:
assert track.play_count is None
assert track.position is not None

album = (
await Album.objects.select_related("tracks")
.exclude_fields(
{
"is_best_seller": ...,
"tracks": {"play_count": ..., "position": ...},
}
)
.get(name="Miami")
)
assert album.is_best_seller is None
assert len(album.tracks) == 3
for track in album.tracks:
assert track.play_count is None
assert track.position is None

album = (
await Album.objects.select_related("tracks")
.exclude_fields(
{"is_best_seller": ..., "tracks": {"play_count", "position"}}
)
.get(name="Miami")
)
assert album.is_best_seller is None
assert len(album.tracks) == 3
for track in album.tracks:
assert track.play_count is None
assert track.position is None

0 comments on commit 40254b9

Please sign in to comment.