Skip to content

Commit

Permalink
Added aggregation results in QueryResultsWrapper (#9)
Browse files Browse the repository at this point in the history
* Added aggregation results in QueryResultsWrapper

---------

Co-authored-by: Maxim <[email protected]>
  • Loading branch information
maximbogatyrev and bogatyrev-maxim authored Oct 8, 2024
1 parent 5629795 commit 04c11a2
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 1 deletion.
3 changes: 3 additions & 0 deletions pyreindexer/lib/include/queryresults_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ class QueryResultsWrapper {
db_->FetchResults(*this);
}

const std::vector<reindexer::AggregationResult>& GetAggregationResults() const& { return qresPtr.GetAggregationResults(); }
const std::vector<reindexer::AggregationResult>& GetAggregationResults() const&& = delete;

private:
friend DBInterface;

Expand Down
33 changes: 33 additions & 0 deletions pyreindexer/lib/src/rawpyreindexer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -428,4 +428,37 @@ static PyObject *QueryResultsWrapperDelete(PyObject *self, PyObject *args) {

Py_RETURN_NONE;
}

static PyObject *GetAggregationResults(PyObject *self, PyObject *args) {
uintptr_t qresWrapperAddr;

if (!PyArg_ParseTuple(args, "k", &qresWrapperAddr)) {
return NULL;
}

QueryResultsWrapper *qresWrapper = getQueryResultsWrapper(qresWrapperAddr);

const auto &aggResults = qresWrapper->GetAggregationResults();
WrSerializer wrSer;
wrSer << "[";
for (size_t i = 0; i < aggResults.size(); ++i) {
if (i > 0) {
wrSer << ',';
}
aggResults[i].GetJSON(wrSer);
}
wrSer << "]";

PyObject *dictFromJson = nullptr;
try {
dictFromJson = PyObjectFromJson(reindexer::giftStr(wrSer.Slice())); // stolen ref
} catch (const Error &err) {
Py_XDECREF(dictFromJson);

return Py_BuildValue("is{}", err.code(), err.what().c_str());
}

return Py_BuildValue("isO", errOK, "", dictFromJson);
}

} // namespace pyreindexer
2 changes: 2 additions & 0 deletions pyreindexer/lib/src/rawpyreindexer.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ static PyObject *EnumNamespaces(PyObject *self, PyObject *args);

static PyObject *QueryResultsWrapperIterate(PyObject *self, PyObject *args);
static PyObject *QueryResultsWrapperDelete(PyObject *self, PyObject *args);
static PyObject *GetAggregationResults(PyObject *self, PyObject *args);

// clang-format off
static PyMethodDef module_methods[] = {
Expand All @@ -98,6 +99,7 @@ static PyMethodDef module_methods[] = {

{"query_results_iterate", QueryResultsWrapperIterate, METH_VARARGS, "get query result"},
{"query_results_delete", QueryResultsWrapperDelete, METH_VARARGS, "free query results buffer"},
{"get_agg_results", GetAggregationResults, METH_VARARGS, "get aggregation results"},

{nullptr, nullptr, 0, nullptr}
};
Expand Down
12 changes: 12 additions & 0 deletions pyreindexer/query_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,15 @@ def _close_iterator(self):

self.qres_iter_count = 0
self.api.query_results_delete(self.qres_wrapper_ptr)


def get_agg_results(self):
"""Returns aggregation results for the current query
"""

self.err_code, self.err_msg, res = self.api.get_agg_results(
self.qres_wrapper_ptr)
if self.err_code:
raise Exception(self.err_msg)
return res
17 changes: 17 additions & 0 deletions pyreindexer/tests/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,20 @@ def test_sql_select_with_syntax_error(self, namespace, index, item):
assert_that(calling(sql_query).with_args(namespace, query),
raises(Exception, matching=has_string(string_contains_in_order(
"Expected", "but found"))), "Error wasn't raised when syntax was incorrect")

def test_sql_select_with_aggregations(self, namespace, index, items):
# Given("Create namespace with item")
db, namespace_name = namespace
# When ("Insert items into namespace")
for _ in range(5):
db.item_insert(namespace_name, {"id": 100}, ["id=serial()"])

select_result = db.select(f'SELECT min(id), max(id), avg(id) FROM {namespace_name}').get_agg_results()
assert_that(len(select_result), 3, "The aggregation result must contain 3 elements")

expected_values = {"min":1,"max":10,"avg":5.5}

# Then ("Check that returned agg results are correct")
for agg in select_result:
assert_that(agg['value'], equal_to(expected_values[agg['type']]),
f"Incorrect aggregation result for {agg['type']}")
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def build_cmake(self, ext):


setup(name=PACKAGE_NAME,
version='0.2.36',
version='0.2.37',
description='A connector that allows to interact with Reindexer',
author='Igor Tulmentyev',
author_email='[email protected]',
Expand Down

0 comments on commit 04c11a2

Please sign in to comment.