Skip to content

Commit 80ccafc

Browse files
authored
Merge pull request #55 from qaspen-python/feature/add_predefined_row_factories
Added predefined row factories
2 parents e40c54a + e64a6c8 commit 80ccafc

File tree

5 files changed

+192
-0
lines changed

5 files changed

+192
-0
lines changed
+67
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from typing import Any, Generic, Tuple, Type, TypeVar
2+
3+
from typing_extensions import Self
4+
5+
_CustomClass = TypeVar(
6+
"_CustomClass",
7+
)
8+
9+
def tuple_row(row: dict[str, Any]) -> Tuple[Tuple[str, Any]]:
10+
"""Convert dict row into tuple row.
11+
12+
### Parameters:
13+
- `row`: row in dictionary.
14+
15+
### Returns:
16+
row as a tuple of tuples.
17+
18+
### Example:
19+
```
20+
dict_ = {
21+
"psqlpy": "is",
22+
"postgresql": "driver",
23+
}
24+
# This function will convert this dict into:
25+
(("psqlpy", "is"), ("postgresql": "driver"))
26+
```
27+
"""
28+
29+
class class_row(Generic[_CustomClass]): # noqa: N801
30+
"""Row converter to specified class.
31+
32+
### Example:
33+
```python
34+
from psqlpy.row_factories import class_row
35+
36+
37+
class ValidationModel:
38+
name: str
39+
views_count: int
40+
41+
42+
async def main:
43+
res = await db_pool.execute(
44+
"SELECT * FROM users",
45+
)
46+
47+
results: list[ValidationModel] = res.row_factory(
48+
class_row(ValidationModel),
49+
)
50+
```
51+
"""
52+
53+
def __init__(self: Self, class_: Type[_CustomClass]) -> None:
54+
"""Construct new `class_row`.
55+
56+
### Parameters:
57+
- `class_`: class to transform row into.
58+
"""
59+
def __call__(self, row: dict[str, Any]) -> _CustomClass:
60+
"""Convert row into specified class.
61+
62+
### Parameters:
63+
- `row`: row in dictionary.
64+
65+
### Returns:
66+
Constructed specified class.
67+
"""

python/psqlpy/row_factories.py

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from ._internal.row_factories import class_row, tuple_row
2+
3+
__all__ = [
4+
"tuple_row",
5+
"class_row",
6+
]

python/tests/test_row_factories.py

+68
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from dataclasses import dataclass
2+
from typing import Any, Callable, Dict, Type
3+
4+
import pytest
5+
6+
from psqlpy import ConnectionPool
7+
from psqlpy.row_factories import class_row, tuple_row
8+
9+
pytestmark = pytest.mark.anyio
10+
11+
12+
async def test_tuple_row(
13+
psql_pool: ConnectionPool,
14+
table_name: str,
15+
number_database_records: int,
16+
) -> None:
17+
conn_result = await psql_pool.execute(
18+
querystring=f"SELECT * FROM {table_name}",
19+
)
20+
tuple_res = conn_result.row_factory(row_factory=tuple_row)
21+
22+
assert len(tuple_res) == number_database_records
23+
assert isinstance(tuple_res[0], tuple)
24+
25+
26+
async def test_class_row(
27+
psql_pool: ConnectionPool,
28+
table_name: str,
29+
number_database_records: int,
30+
) -> None:
31+
@dataclass
32+
class ValidationTestModel:
33+
id: int
34+
name: str
35+
36+
conn_result = await psql_pool.execute(
37+
querystring=f"SELECT * FROM {table_name}",
38+
)
39+
class_res = conn_result.row_factory(row_factory=class_row(ValidationTestModel))
40+
assert len(class_res) == number_database_records
41+
assert isinstance(class_res[0], ValidationTestModel)
42+
43+
44+
async def test_custom_row_factory(
45+
psql_pool: ConnectionPool,
46+
table_name: str,
47+
number_database_records: int,
48+
) -> None:
49+
@dataclass
50+
class ValidationTestModel:
51+
id: int
52+
name: str
53+
54+
def to_class(
55+
class_: Type[ValidationTestModel],
56+
) -> Callable[[Dict[str, Any]], ValidationTestModel]:
57+
def to_class_inner(row: Dict[str, Any]) -> ValidationTestModel:
58+
return class_(**row)
59+
60+
return to_class_inner
61+
62+
conn_result = await psql_pool.execute(
63+
querystring=f"SELECT * FROM {table_name}",
64+
)
65+
class_res = conn_result.row_factory(row_factory=to_class(ValidationTestModel))
66+
67+
assert len(class_res) == number_database_records
68+
assert isinstance(class_res[0], ValidationTestModel)

src/lib.rs

+3
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@ pub mod driver;
44
pub mod exceptions;
55
pub mod extra_types;
66
pub mod query_result;
7+
pub mod row_factories;
78
pub mod runtime;
89
pub mod value_converter;
910

1011
use common::add_module;
1112
use exceptions::python_errors::python_exceptions_module;
1213
use extra_types::extra_types_module;
1314
use pyo3::{pymodule, types::PyModule, wrap_pyfunction, Bound, PyResult, Python};
15+
use row_factories::row_factories_module;
1416

1517
#[pymodule]
1618
#[pyo3(name = "_internal")]
@@ -32,5 +34,6 @@ fn psqlpy(py: Python<'_>, pymod: &Bound<'_, PyModule>) -> PyResult<()> {
3234
pymod.add_class::<query_result::PSQLDriverSinglePyQueryResult>()?;
3335
add_module(py, pymod, "extra_types", extra_types_module)?;
3436
add_module(py, pymod, "exceptions", python_exceptions_module)?;
37+
add_module(py, pymod, "row_factories", row_factories_module)?;
3538
Ok(())
3639
}

src/row_factories.rs

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
use pyo3::{
2+
pyclass, pyfunction, pymethods,
3+
types::{PyDict, PyDictMethods, PyModule, PyModuleMethods, PyTuple},
4+
wrap_pyfunction, Bound, Py, PyAny, PyResult, Python, ToPyObject,
5+
};
6+
7+
use crate::exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult};
8+
9+
#[pyfunction]
10+
#[allow(clippy::needless_pass_by_value)]
11+
fn tuple_row(py: Python<'_>, dict_: Py<PyAny>) -> RustPSQLDriverPyResult<Py<PyAny>> {
12+
let dict_ = dict_.downcast_bound::<PyDict>(py).map_err(|_| {
13+
RustPSQLDriverError::RustToPyValueConversionError(
14+
"as_tuple accepts only dict as a parameter".into(),
15+
)
16+
})?;
17+
Ok(PyTuple::new_bound(py, dict_.items()).to_object(py))
18+
}
19+
20+
#[pyclass]
21+
#[allow(non_camel_case_types)]
22+
struct class_row(Py<PyAny>);
23+
24+
#[pymethods]
25+
impl class_row {
26+
#[new]
27+
fn constract_class(class_: Py<PyAny>) -> Self {
28+
Self(class_)
29+
}
30+
31+
#[allow(clippy::needless_pass_by_value)]
32+
fn __call__(&self, py: Python<'_>, dict_: Py<PyAny>) -> RustPSQLDriverPyResult<Py<PyAny>> {
33+
let dict_ = dict_.downcast_bound::<PyDict>(py).map_err(|_| {
34+
RustPSQLDriverError::RustToPyValueConversionError(
35+
"as_tuple accepts only dict as a parameter".into(),
36+
)
37+
})?;
38+
Ok(self.0.call_bound(py, (), Some(dict_))?)
39+
}
40+
}
41+
42+
#[allow(clippy::module_name_repetitions)]
43+
#[allow(clippy::missing_errors_doc)]
44+
pub fn row_factories_module(_py: Python<'_>, pymod: &Bound<'_, PyModule>) -> PyResult<()> {
45+
pymod.add_function(wrap_pyfunction!(tuple_row, pymod)?)?;
46+
pymod.add_class::<class_row>()?;
47+
Ok(())
48+
}

0 commit comments

Comments
 (0)