Skip to content

Commit

Permalink
Merge branch 'master' of github.com:micmurawski/pynamodb-utils
Browse files Browse the repository at this point in the history
  • Loading branch information
micmurawski committed Jan 30, 2024
2 parents 06aa2de + 849e4d5 commit 5ea0129
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 12 deletions.
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def read(*parts):

setup(
name="pynamodb_utils",
version="1.3.7",
version="1.4.0",
author="Michal Murawski",
author_email="[email protected]",
description="Utilities package for pynamodb.",
Expand All @@ -27,7 +27,7 @@ def read(*parts):
'tests',
)
),
install_requires=["pynamodb>=5.0.0,<6.0.0"],
install_requires=["pynamodb>=6.0.0,<7.0.0"],
include_package_data=True,
python_requires=">=3.6",
license="MIT",
Expand Down
32 changes: 26 additions & 6 deletions src/pynamodb_utils/attributes.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from enum import Enum
from typing import Collection, FrozenSet, Optional, Union

Expand All @@ -12,6 +13,12 @@ class DynamicMapAttribute(MapAttribute):
element_type = None

def __init__(self, *args, of=None, **kwargs):
if "default" in kwargs:
kwargs["default"] = json.dumps(kwargs["default"])

if "default_for_new" in kwargs:
kwargs["default_for_new"] = json.dumps(kwargs["default_for_new"])

if of:
if not issubclass(of, MapAttribute):
raise ValueError("'of' must be subclass of MapAttribute")
Expand Down Expand Up @@ -49,9 +56,6 @@ def keys(self) -> FrozenSet:
def items(self) -> Collection:
return self.as_dict().items()

def __repr__(self) -> dict:
return self.as_dict()

def __str__(self) -> str:
return str(self.__class__)

Expand All @@ -61,20 +65,28 @@ class EnumNumberAttribute(NumberAttribute):

def __init__(
self,
enum,
enum: Enum,
hash_key: bool = False,
range_key: bool = False,
null: Optional[bool] = None,
default: Optional[Enum] = None,
default_for_new: Optional[Enum] = None,
attr_name: Optional[str] = None,
):
if isinstance(enum, Enum):
raise ValueError("enum must be Enum class")
self.enum = enum

if default_for_new is not None and not isinstance(default_for_new, enum):
raise ValueError(f"default_for_new is not instance of {enum}")
if default is not None and not isinstance(default, enum):
raise ValueError(f"default is not instance of {enum}")

super().__init__(
hash_key=hash_key,
range_key=range_key,
default=default.value if default else None,
default_for_new=default_for_new.value if default_for_new else None,
null=null,
attr_name=attr_name,
)
Expand All @@ -99,20 +111,28 @@ def deserialize(self, value: str) -> str:
class EnumUnicodeAttribute(UnicodeAttribute):
def __init__(
self,
enum,
enum: Enum,
hash_key: bool = False,
range_key: bool = False,
null: Optional[bool] = None,
default: Optional[Enum] = None,
default_for_new: Optional[Enum] = None,
attr_name: Optional[str] = None,
):
if isinstance(enum, Enum):
raise ValueError("enum must be Enum class")
self.enum = enum

if default_for_new is not None and not isinstance(default_for_new, enum):
raise ValueError(f"default_for_new is not instance of {enum}")
if default is not None and not isinstance(default, enum):
raise ValueError(f"default is not instance of {enum}")

super().__init__(
hash_key=hash_key,
range_key=range_key,
default=default,
default=default.value if default else None,
default_for_new=default_for_new.value if default_for_new else None,
null=null,
attr_name=attr_name,
)
Expand Down
17 changes: 13 additions & 4 deletions src/tests/test_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,30 @@ def test_general(post_table):
post = post_table
category_enum = post.category.enum

post = post(
post_1 = post(
name="A weekly news.",
sub_name="Shocking revelations",
content="Last week took place...",
category=category_enum.finance,
tags={"type": "news", "topics": ["stock exchange", "NYSE"]},
)
post.save()
post_1.save()
post_2 = post(
name="A boring news.",
sub_name="Nothing interesting...",
content="...",
category=category_enum.finance,
tags={"type": "not-news", "topics": ["stock exchange", "LSE"]},
)
post_2.save()
query = {
"created_at__lte": str(datetime.now()),
"sub_name__exists": None,
"category__equals": "finance",
"OR": {"tags.type__equals": "news", "tags.topics__contains": ["NYSE"]},
}

results = post.make_index_query(query)
results = list(post.make_index_query(query))

expected = {
"content": "Last week took place...",
Expand All @@ -39,7 +47,8 @@ def test_general(post_table):
"updated_at": "2019-01-01T00:00:00+00:00",
}

assert next(results).as_dict() == expected
assert len(results) == 1
assert results[0].as_dict() == expected


def test_bad_field(post_table):
Expand Down

0 comments on commit 5ea0129

Please sign in to comment.