diff --git a/docs/usage-querybuilder.md b/docs/usage-querybuilder.md
index 60f73d686..6005e7455 100644
--- a/docs/usage-querybuilder.md
+++ b/docs/usage-querybuilder.md
@@ -68,3 +68,33 @@ for i in range(0, 10):
 ```
 
 Note that the results are different, because we've placed more weight on the term `hubble`.
+
+
+Queries can also be built from a single query string, without constructing each term separately.
+
+For example:
+
+```python
+query = querybuilder.get_standard_query('hubble space telescope')
+
+hits = searcher.search(query)
+
+for i in range(0, 10):
+    print(f'{i+1:2} {hits[i].docid:15} {hits[i].score:.5f}')
+```
+
+String queries can also specify boosted terms, using the boost operator (`^`) described in
+Lucene's [query parser syntax documentation](https://lucene.apache.org/core/2_9_4/queryparsersyntax.html).
+
+For example:
+
+```python
+query = querybuilder.get_standard_query('hubble^2 space^1 telescope^1')
+hits = searcher.search(query)
+
+for i in range(0, 10):
+    print(f'{i+1:2} {hits[i].docid:15} {hits[i].score:.5f}')
+```
+
+The results are the same as those obtained by boosting each term with `querybuilder.get_boost_query`.
+
diff --git a/pyserini/search/lucene/querybuilder.py b/pyserini/search/lucene/querybuilder.py
index 7627121c2..122faab89 100644
--- a/pyserini/search/lucene/querybuilder.py
+++ b/pyserini/search/lucene/querybuilder.py
@@ -28,9 +28,16 @@
 
 # Wrapper around Lucene clases
 JTerm = autoclass('org.apache.lucene.index.Term')
+JQuery = autoclass('org.apache.lucene.search.Query')
 JBooleanClause = autoclass('org.apache.lucene.search.BooleanClause')
 JBoostQuery = autoclass('org.apache.lucene.search.BoostQuery')
 JTermQuery = autoclass('org.apache.lucene.search.TermQuery')
+JWildcardQuery = autoclass('org.apache.lucene.search.WildcardQuery')
+JFuzzyQuery = autoclass('org.apache.lucene.search.FuzzyQuery')
+JPrefixQuery = autoclass('org.apache.lucene.search.PrefixQuery')
+JStandardQueryParser = autoclass('org.apache.lucene.queryparser.flexible.standard.StandardQueryParser')
+JComplexPhraseQueryParser = autoclass('org.apache.lucene.queryparser.complexPhrase.ComplexPhraseQueryParser')
+
 
 # Wrappers around Anserini classes
 JQueryGeneratorUtils = autoclass('io.anserini.search.query.QueryGeneratorUtils')
@@ -88,3 +95,54 @@ def get_boost_query(query, boost):
     JBoostQuery
     """
     return JBoostQuery(query, boost)
+
+
+def get_standard_query(query, field="contents", analyzer=None):
+    """Runs Lucene's StandardQueryParser to get a query.
+
+    Parameters
+    ----------
+    query : str
+        The query string, written in Lucene's query parser syntax.
+    field : str
+        Default field to search.
+    analyzer : Analyzer
+        Analyzer to use for tokenizing the query string. If ``None``, the
+        analyzer returned by ``get_lucene_analyzer()`` is used.
+
+    Returns
+    -------
+    JQuery
+    """
+    # Resolve the default at call time rather than in the signature, so the
+    # analyzer is not created as a side effect of defining this function.
+    if analyzer is None:
+        analyzer = get_lucene_analyzer()
+    query_parser = JStandardQueryParser()
+    query_parser.setAnalyzer(analyzer)
+    return query_parser.parse(query, field)
+
+
+def get_complex_phrase_query(query, field="contents", analyzer=None):
+    """Runs Lucene's ComplexPhraseQueryParser to get a query.
+
+    Parameters
+    ----------
+    query : str
+        The query string, written in Lucene's query parser syntax.
+    field : str
+        Default field to search.
+    analyzer : Analyzer
+        Analyzer to use for tokenizing the query string. If ``None``, the
+        analyzer returned by ``get_lucene_analyzer()`` is used.
+
+    Returns
+    -------
+    JQuery
+    """
+    # Resolve the default at call time rather than in the signature, so the
+    # analyzer is not created as a side effect of defining this function.
+    if analyzer is None:
+        analyzer = get_lucene_analyzer()
+    query_parser = JComplexPhraseQueryParser(field, analyzer)
+    return query_parser.parse(query)
diff --git a/tests/test_querybuilder.py b/tests/test_querybuilder.py
index 0ef17481a..740aea006 100644
--- a/tests/test_querybuilder.py
+++ b/tests/test_querybuilder.py
@@ -26,20 +26,35 @@
 
 
 class TestQueryBuilding(unittest.TestCase):
-    def setUp(self):
+    @classmethod
+    def setUpClass(cls):
+        # Class-level fixture: fetch and build the indexes once for all tests in this class.
         # Download pre-built CACM index built using Lucene 9; append a random value to avoid filename clashes.
         r = randint(0, 10000000)
-        self.collection_url = 'https://github.com/castorini/anserini-data/raw/master/CACM/lucene9-index.cacm.tar.gz'
-        self.tarball_name = 'lucene-index.cacm-{}.tar.gz'.format(r)
-        self.index_dir = 'index{}/'.format(r)
+        cls.collection_url = 'https://github.com/castorini/anserini-data/raw/master/CACM/lucene9-index.cacm.tar.gz'
+        cls.tarball_name = 'lucene-index.cacm-{}.tar.gz'.format(r)
+        cls.index_dir = 'index{}/'.format(r)
 
-        filename, headers = urlretrieve(self.collection_url, self.tarball_name)
+        filename, headers = urlretrieve(cls.collection_url, cls.tarball_name)
 
-        tarball = tarfile.open(self.tarball_name)
-        tarball.extractall(self.index_dir)
+        tarball = tarfile.open(cls.tarball_name)
+        tarball.extractall(cls.index_dir)
         tarball.close()
 
-        self.searcher = LuceneSearcher(f'{self.index_dir}lucene9-index.cacm')
+        cls.searcher = LuceneSearcher(f'{cls.index_dir}lucene9-index.cacm')
+
+        # Create index without document vectors
+        # The current directory depends on whether you're running inside an IDE or from the command line.
+        curdir = os.getcwd()
+        corpus_path = 'tests/resources/sample_collection_jsonl'
+        corpus_path = '../' + corpus_path if curdir.endswith('tests') else corpus_path
+
+        cls.no_vec_index_dir = 'no_vec_index'
+        cmd1 = f'python -m pyserini.index.lucene -collection JsonCollection ' + \
+               f'-generator DefaultLuceneDocumentGenerator ' + \
+               f'-threads 1 -input {corpus_path} -index {cls.no_vec_index_dir} -storePositions'
+        os.system(cmd1)
+        cls.no_vec_searcher = LuceneSearcher(cls.no_vec_index_dir)
 
     def testBuildBoostedQuery(self):
         term_query1 = querybuilder.get_term_query('information')
@@ -95,6 +110,40 @@ def testTermQuery(self):
             self.assertEqual(h1.docid, h2.docid)
             self.assertEqual(h1.score, h2.score)
 
+    def testStandardQuery(self):
+        should = querybuilder.JBooleanClauseOccur['should'].value
+        must_not = querybuilder.JBooleanClauseOccur['must_not'].value
+
+        query_builder = querybuilder.get_boolean_query_builder()
+        query_builder.add(querybuilder.get_standard_query('"contents document"~3'), must_not)
+        query_builder.add(querybuilder.get_standard_query('document'), should)
+        query = query_builder.build()
+        hit, = self.no_vec_searcher.search(query)
+        self.assertEqual(hit.docid, "doc3")
+
+        query_builder = querybuilder.get_boolean_query_builder()
+        # Standard query parser doesn't support nested logic, so the asterisk is ignored.
+        query_builder.add(querybuilder.get_standard_query('"contents doc*"~3'), must_not)
+        query_builder.add(querybuilder.get_standard_query('document'), should)
+        query = query_builder.build()
+        doc_ids = ["doc2", "doc3"]
+        hits = self.no_vec_searcher.search(query)
+        self.assertEqual(len(hits), len(doc_ids))
+        for hit, doc_id in zip(hits, doc_ids):
+            self.assertEqual(hit.docid, doc_id)
+
+    def testComplexPhraseQuery(self):
+        should = querybuilder.JBooleanClauseOccur['should'].value
+        must_not = querybuilder.JBooleanClauseOccur['must_not'].value
+
+        query_builder = querybuilder.get_boolean_query_builder()
+        # Complex phrase query parser supports nested logic
+        query_builder.add(querybuilder.get_complex_phrase_query('"contents doc*"~3'), must_not)
+        query_builder.add(querybuilder.get_standard_query('document'), should)
+        query = query_builder.build()
+        hit, = self.no_vec_searcher.search(query)
+        self.assertEqual(hit.docid, "doc3")
+
     def testIncompatabilityWithRM3(self):
         should = querybuilder.JBooleanClauseOccur['should'].value
         query_builder = querybuilder.get_boolean_query_builder()
@@ -129,10 +178,15 @@ def testTermQuery2(self):
             self.assertEqual(h1.docid, h2.docid)
             self.assertEqual(h1.score, h2.score)
 
-    def tearDown(self):
-        self.searcher.close()
-        os.remove(self.tarball_name)
-        shutil.rmtree(self.index_dir)
+    @classmethod
+    def tearDownClass(cls):
+        # Release both searchers, then remove every file and directory created by setUpClass.
+        cls.searcher.close()
+        cls.no_vec_searcher.close()
+        os.remove(cls.tarball_name)
+        shutil.rmtree(cls.index_dir)
+        shutil.rmtree(cls.no_vec_index_dir)
+
 
 
 if __name__ == '__main__':