diff --git a/redisearch/client.py b/redisearch/client.py index 80971b4..fda9ecf 100644 --- a/redisearch/client.py +++ b/redisearch/client.py @@ -1,3 +1,4 @@ +from typing import Dict, Union from redis import Redis, ConnectionPool import itertools import time @@ -501,7 +502,17 @@ def info(self): it = six.moves.map(to_string, res) return dict(six.moves.zip(it, it)) - def _mk_query_args(self, query): + def get_params_args(self, params: Dict[str, Union[str, int, float]]): + args = [] + if len(params) > 0: + args.append("PARAMS") + args.append(len(params)*2) + for key, value in params.items(): + args.append(key) + args.append(value) + return args + + def _mk_query_args(self, query, query_params): args = [self.index_name] if isinstance(query, six.string_types): @@ -509,11 +520,12 @@ def _mk_query_args(self, query): query = Query(query) if not isinstance(query, Query): raise ValueError("Bad query type %s" % type(query)) - args += query.get_args() + if query_params is not None: + args+= self.get_params_args(query_params) return args, query - def search(self, query): + def search(self, query, query_params: Dict[str, Union[str, int, float]] = None): """ Search the index for a given query, and return a result of documents @@ -522,7 +534,7 @@ def search(self, query): - **query**: the search query. Either a text for simple queries with default parameters, or a Query object for complex queries. See RediSearch's documentation on query format """ - args, query = self._mk_query_args(query) + args, query = self._mk_query_args(query, query_params=query_params) st = time.time() res = self.redis.execute_command(self.SEARCH_CMD, *args) @@ -532,11 +544,11 @@ def search(self, query): has_payload=query._with_payloads, with_scores=query._with_scores) - def explain(self, query): - args, query_text = self._mk_query_args(query) + def explain(self, query, query_params: Dict[str, Union[str, int, float]] = None): + args, query_text = self._mk_query_args(query, query_params=query_params) return self.redis.execute_command(self.EXPLAIN_CMD, *args) - def aggregate(self, query): + def aggregate(self, query, query_params: Dict[str, Union[str, int, float]] = None): """ Issue an aggregation query @@ -556,7 +568,8 @@ def aggregate(self, query): self.index_name] + query.build_args() else: raise ValueError('Bad query', query) - + if query_params is not None: + cmd+= self.get_params_args(query_params) raw = self.redis.execute_command(*cmd) if has_cursor: if isinstance(query, Cursor): diff --git a/redisearch/query.py b/redisearch/query.py index f741a6d..e03d2a4 100644 --- a/redisearch/query.py +++ b/redisearch/query.py @@ -209,6 +209,7 @@ def get_args(self): args += self._summarize_fields + self._highlight_fields args += ["LIMIT", self._offset, self._num] + return args def paging(self, offset, num): @@ -288,7 +289,6 @@ def sort_by(self, field, asc=True): self._sortby = SortbyField(field, asc) return self - class Filter(object): def __init__(self, keyword, field, *args): diff --git a/test/test.py b/test/test.py index 6253a43..f2b9f4c 100644 --- a/test/test.py +++ b/test/test.py @@ -1189,6 +1189,69 @@ def testSearchReturnFields(self): self.assertEqual('doc:1', total[0].id) self.assertEqual('telmatosaurus', total[0].txt) + def test_text_params(self): + conn = self.redis() + + with conn as r: + # Creating a client with a given index name + client = Client('idx', port=conn.port) + client.redis.flushdb() + client.create_index((TextField('name'),)) + + client.add_document('doc1', name='Alice') + client.add_document('doc2', name='Bob') + client.add_document('doc3', name='Carol') + + params_dict = {"name1":"Alice", "name2":"Bob"} + q = Query("@name:($name1 | $name2 )") + res = client.search(q, query_params=params_dict) + self.assertEqual(2, res.total) + self.assertEqual('doc1', res.docs[0].id) + self.assertEqual('doc2', res.docs[1].id) + + + def test_numeric_params(self): + conn = self.redis() + + with conn as r: + # Creating a client with a given index name + client = Client('idx', port=conn.port) + client.redis.flushdb() + client.create_index((NumericField('numval'),)) + + client.add_document('doc1', numval=101) + client.add_document('doc2', numval=102) + client.add_document('doc3', numval=103) + + params_dict = {"min":101, "max":102} + q = Query('@numval:[$min $max]') + res = client.search(q, query_params=params_dict) + self.assertEqual(2, res.total) + self.assertEqual('doc1', res.docs[0].id) + self.assertEqual('doc2', res.docs[1].id) + + def test_geo_params(self): + conn = self.redis() + + with conn as r: + # Creating a client with a given index name + client = Client('idx', port=conn.port) + client.redis.flushdb() + client.create_index((GeoField('g'),)) + + client.add_document('doc1', g='29.69465, 34.95126') + client.add_document('doc2', g='29.69350, 34.94737') + client.add_document('doc3', g='29.68746, 34.94882') + + params_dict = {"lat":'34.95126', "lon":'29.69465', "radius":10, "units":"km"} + q = Query('@g:[$lon $lat $radius $units]') + res = client.search(q, query_params=params_dict) + self.assertEqual(3, res.total) + self.assertEqual('doc1', res.docs[0].id) + self.assertEqual('doc2', res.docs[1].id) + self.assertEqual('doc3', res.docs[2].id) + + if __name__ == '__main__': unittest.main()