Skip to content

Commit e64a6c8

Browse files
committed
Added tests for predefined row factories
Signed-off-by: chandr-andr (Kiselev Aleksandr) <[email protected]>
1 parent 331e643 commit e64a6c8

File tree

2 files changed

+90
-1
lines changed

2 files changed

+90
-1
lines changed

python/psqlpy/_internal/row_factories.pyi

+22-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,28 @@ def tuple_row(row: dict[str, Any]) -> Tuple[Tuple[str, Any]]:
2727
"""
2828

2929
class class_row(Generic[_CustomClass]): # noqa: N801
30-
"""Row converter to specified class."""
30+
"""Row converter to specified class.
31+
32+
### Example:
33+
```python
34+
from psqlpy.row_factories import class_row
35+
36+
37+
class ValidationModel:
38+
name: str
39+
views_count: int
40+
41+
42+
async def main:
43+
res = await db_pool.execute(
44+
"SELECT * FROM users",
45+
)
46+
47+
results: list[ValidationModel] = res.row_factory(
48+
class_row(ValidationModel),
49+
)
50+
```
51+
"""
3152

3253
def __init__(self: Self, class_: Type[_CustomClass]) -> None:
3354
"""Construct new `class_row`.

python/tests/test_row_factories.py

+68
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from dataclasses import dataclass
2+
from typing import Any, Callable, Dict, Type
3+
4+
import pytest
5+
6+
from psqlpy import ConnectionPool
7+
from psqlpy.row_factories import class_row, tuple_row
8+
9+
pytestmark = pytest.mark.anyio
10+
11+
12+
async def test_tuple_row(
13+
psql_pool: ConnectionPool,
14+
table_name: str,
15+
number_database_records: int,
16+
) -> None:
17+
conn_result = await psql_pool.execute(
18+
querystring=f"SELECT * FROM {table_name}",
19+
)
20+
tuple_res = conn_result.row_factory(row_factory=tuple_row)
21+
22+
assert len(tuple_res) == number_database_records
23+
assert isinstance(tuple_res[0], tuple)
24+
25+
26+
async def test_class_row(
27+
psql_pool: ConnectionPool,
28+
table_name: str,
29+
number_database_records: int,
30+
) -> None:
31+
@dataclass
32+
class ValidationTestModel:
33+
id: int
34+
name: str
35+
36+
conn_result = await psql_pool.execute(
37+
querystring=f"SELECT * FROM {table_name}",
38+
)
39+
class_res = conn_result.row_factory(row_factory=class_row(ValidationTestModel))
40+
assert len(class_res) == number_database_records
41+
assert isinstance(class_res[0], ValidationTestModel)
42+
43+
44+
async def test_custom_row_factory(
45+
psql_pool: ConnectionPool,
46+
table_name: str,
47+
number_database_records: int,
48+
) -> None:
49+
@dataclass
50+
class ValidationTestModel:
51+
id: int
52+
name: str
53+
54+
def to_class(
55+
class_: Type[ValidationTestModel],
56+
) -> Callable[[Dict[str, Any]], ValidationTestModel]:
57+
def to_class_inner(row: Dict[str, Any]) -> ValidationTestModel:
58+
return class_(**row)
59+
60+
return to_class_inner
61+
62+
conn_result = await psql_pool.execute(
63+
querystring=f"SELECT * FROM {table_name}",
64+
)
65+
class_res = conn_result.row_factory(row_factory=to_class(ValidationTestModel))
66+
67+
assert len(class_res) == number_database_records
68+
assert isinstance(class_res[0], ValidationTestModel)

0 commit comments

Comments
 (0)