Skip to content

Commit

Permalink
Merge pull request #21 from RSS-Engineering/add_count
Browse files Browse the repository at this point in the history
Add `count()` for dynamodb backends
  • Loading branch information
explorigin authored Nov 15, 2023
2 parents 91839de + b6d23cb commit 5dd305d
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 2 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ NOTE: Rule complexity is limited by the querying capabilities of the backend.
- Providing a `limit` parameter will limit the number of results. If more results remain, the returned dataset will have an `last_evaluated_key` property that can be passed to `exclusive_start_key` to continue with the next page.
- Providing `order='desc'` will return the result set in descending order. This is not available for query calls that "scan" dynamodb.

`count(query_expr: Optional[Rule], exclusive_start_key: Optional[tuple[Any]], order: str = 'asc'`
- Same as `query` but returns an integer count as total. (When calling `query` with a limit, the count dynamodb returns is <= the limit you provide)


## Backend Configuration Members

`hash_key` - the name of the key field for the backend table
Expand Down
26 changes: 25 additions & 1 deletion pydanticrud/backends/dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ def query(
limit: Optional[int] = None,
exclusive_start_key: Optional[str] = None,
order: str = "asc",
select: Optional[str] = None,
):
table = self.get_table()
f_expr, _ = rule_to_boto_expression(filter_expr) if filter_expr else (None, set())
Expand Down Expand Up @@ -348,6 +349,9 @@ def query(
if order != "asc":
params["ScanIndexForward"] = False

if select:
params["Select"] = select

if index_name:
params["IndexName"] = index_name
elif not keys_used.issubset({self.hash_key, self.range_key}):
Expand Down Expand Up @@ -376,8 +380,28 @@ def query(
raise e

return DynamoIterableResult(
self.cls, resp, (self.serializer.deserialize_record(rec) for rec in resp["Items"])
self.cls,
resp,
(self.serializer.deserialize_record(rec) for rec in resp.get("Items", [])),
)

def count(
self,
query_expr: Optional[Rule] = None,
exclusive_start_key: Optional[str] = None,
order: str = "asc",
) -> int:
"""
Dynamo Query returns a full "scanned_count" but when a limit is specified this count is <= the limit. To
get a full count (i.e. for pagination), a limitless query must be run.
"""
result = self.query(
query_expr=query_expr,
exclusive_start_key=exclusive_start_key,
order=order,
select="COUNT",
)
return result.scanned_count

def get(self, key: Union[Dict, Any]):
if isinstance(key, dict):
Expand Down
4 changes: 4 additions & 0 deletions pydanticrud/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ def query(cls, *args, **kwargs):
res = IterableResult(cls, res)
return res

@classmethod
def count(cls, *args, **kwargs):
return cls.__backend__.count(*args, **kwargs)

@classmethod
def get(cls, *args, **kwargs):
return cls.parse_obj(cls.__backend__.get(*args, **kwargs))
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pydanticrud"
version = "0.4.1"
version = "0.4.2"
description = "Supercharge your Pydantic models with CRUD methods and a pluggable backend"
authors = ["Timothy Farrell <[email protected]>"]
license = "MIT"
Expand Down
7 changes: 7 additions & 0 deletions tests/test_dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,13 @@ def test_pagination_query_with_index_complex(dynamo, complex_query_data):
assert all([r in check_data for r in res])
assert len(res) == page_size

def test_pagination_query_count(dynamo, complex_query_data):
page_size = 2
middle_record = complex_query_data[(len(complex_query_data)//2)]
query_rule = Rule(f"account == '{middle_record['account']}' and category_id >= {middle_record['category_id']}")
check_data = ComplexKeyModel.query(query_rule)
res_count = ComplexKeyModel.count(query_rule)
assert res_count == check_data.scanned_count

def test_query_errors_with_nonprimary_key_complex(dynamo, complex_query_data):
data_by_expires = complex_query_data[:]
Expand Down

0 comments on commit 5dd305d

Please sign in to comment.