Skip to content

Commit 5207ff0

Browse files
authored
Add bm25 service (primeqa#402)
* Add BM25 retriever and indexer to components and services * add get_engine_type api * save engine_type in information.json * Bump version: 0.10.0 → 0.11.0
1 parent 2930c3d commit 5207ff0

File tree

16 files changed

+147
-73
lines changed

16 files changed

+147
-73
lines changed

.bumpversion.cfg

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[bumpversion]
2-
current_version = 0.10.0
2+
current_version = 0.11.0
33
commit = True
44

55
[bumpversion:file:VERSION]

VERSION

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0.10.0
1+
0.11.0

primeqa/ir/README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ The following command builds an index for BM25 retrieval.
169169
python primeqa/ir/run_ir.py \
170170
--do_index \
171171
--engine_type BM25 \
172-
--corpus_path <document_collection> \
172+
--collection <document_collection> \
173173
--index_path <index_dir>
174174
--threads <num_threads>
175175
```

primeqa/ir/sparse/bm25_engine.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,12 @@ class BM25Engine:
1212
def __init__(self, config: BM25Config):
1313
self.config = config
1414
logger.info(f"Running BM25")
15+
logger.info(config)
1516

1617
def do_index(self):
1718
logger.info("Running BM25 indexing")
1819
indexer = PyseriniIndexer()
19-
rc = indexer.index_collection(self.config.corpus_path, self.config.index_location,
20+
rc = indexer.index_collection(self.config.collection, self.config.index_location,
2021
self.config.fieldnames, self.config.overwrite,
2122
self.config.threads, self.config.additional_indexing_args )
2223
logger.info(f"BM25 Indexing finished with rc: {rc}")

primeqa/ir/sparse/config.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44
@dataclass
55
class IndexingArguments():
66

7-
index_path: str = field(default=None, metadata={"help":"Path to the index directory location"})
7+
index_location: str = field(default=None, metadata={"help":"Path to the index directory location"})
88

99
overwrite: bool = field(default=False, metadata={"help": "Overwrite existing directory"})
1010

11-
corpus_path: str = field(default=None, metadata={"help":"Path to a corpus tsv or json file or directory"})
11+
collection: str = field(default=None, metadata={"help":"Path to a corpus tsv or json file or directory"})
1212

1313
fieldnames: list = field(default=None, metadata={"help":"fields names to use to identify document_id, title, text if corpus tsv has no headings"})
1414

primeqa/ir/sparse/indexer.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ def _run_command(self, cmd):
2929
rc = process.wait()
3030
return rc
3131

32-
def _preprocess_corpus(self, corpus_path, tmpdirname, fieldnames=None):
33-
reader = corpus_reader(corpus_path, fieldnames=fieldnames)
32+
def _preprocess_corpus(self, collection, tmpdirname, fieldnames=None):
33+
reader = corpus_reader(collection, fieldnames=fieldnames)
3434
outf = open( os.path.join(tmpdirname,"corpus_pyserini_fmt.jsonl"), 'w' )
3535
num_docs = 0
3636
for passage in tqdm(reader):
@@ -53,7 +53,7 @@ def _preprocess_corpus(self, corpus_path, tmpdirname, fieldnames=None):
5353
5454
5555
Args:
56-
corpus_path (str) : path to file or directory of documents in tsv or jsonl format.
56+
collection (str) : path to file or directory of documents in tsv or jsonl format.
5757
index_path (str) : output directory path where the index is written
5858
fieldnames ( List, Optional): column headers to be assigned to tsv without headers
5959
overwrite (bool, Optional): overwrite an existing directory, defaults to false
@@ -64,7 +64,7 @@ def _preprocess_corpus(self, corpus_path, tmpdirname, fieldnames=None):
6464
6565
6666
"""
67-
def index_collection(self, corpus_path: str, index_path: str, fieldnames=None, overwrite=False,
67+
def index_collection(self, collection: str, index_path: str, fieldnames=None, overwrite=False,
6868
threads=1, additional_index_cmd_args='--storePositions --storeDocvectors --storeRaw' ):
6969
if not overwrite and os.path.exists(index_path) and os.listdir(index_path) :
7070
raise ValueError(f"Index path not empty '{index_path}' and overwrite not specified")
@@ -73,7 +73,7 @@ def index_collection(self, corpus_path: str, index_path: str, fieldnames=None, o
7373
# create temporary subdirectory for the corpus
7474
with tempfile.TemporaryDirectory(prefix='tmp',dir=index_path) as tmpdirname:
7575
# convert corpus documents to pyserini jsonl
76-
num_docs = self._preprocess_corpus(corpus_path, tmpdirname, fieldnames=fieldnames)
76+
num_docs = self._preprocess_corpus(collection, tmpdirname, fieldnames=fieldnames)
7777
# build index command
7878
cmd1 = f'python -m pyserini.index.lucene -collection JsonCollection ' + \
7979
f'-generator DefaultLuceneDocumentGenerator ' + \
@@ -90,5 +90,6 @@ def index_collection(self, corpus_path: str, index_path: str, fieldnames=None, o
9090
logger.info(f"Index {index_path} contains {searcher.num_docs} documents")
9191
assert(searcher.num_docs == num_docs)
9292
logging.info(f"Index available at {index_path}")
93+
searcher.close()
9394
return rc
9495

primeqa/ir/sparse/retriever.py

+1-49
Original file line numberDiff line numberDiff line change
@@ -6,55 +6,7 @@
66

77
logger = logging.getLogger(__name__)
88

9-
class BaseRetriever(metaclass=ABCMeta):
10-
"""
11-
Base class for Retriever
12-
"""
13-
14-
@abstractmethod
15-
def retrieve(self, query: str, topK: Optional[int] = 10):
16-
"""
17-
18-
Run queries against the index to retrieve ranked list of documents
19-
Return documents that are most relevant to the query.
20-
21-
Args:
22-
query: search
23-
top_k: number of hits to return, defaults to 10
24-
25-
26-
Returns:
27-
List of hits, each hit is a dict containing :
28-
{
29-
"rank": i,
30-
"score": hit.score,
31-
"doc_id": docid,
32-
"title": title,
33-
"text": text
34-
}
35-
36-
37-
"""
38-
pass
39-
40-
@abstractmethod
41-
def batch_retrieve(self, queries: List[str], qids: List[str], topK: int = 10, threads: int = 1):
42-
"""
43-
Run a batch of queries
44-
45-
Args:
46-
queries: list of query strings
47-
qids: list of qid strings corresponding to queries
48-
top_k: number of hits to return, defaults to 10
49-
threads: maximum number of threads to use
50-
51-
Returns:
52-
Dict of qid to hits
53-
54-
"""
55-
pass
56-
57-
class PyseriniRetriever(BaseRetriever):
9+
class PyseriniRetriever:
5810
def __init__(self, index_location: str, use_bm25: bool = True, k1: float = float(0.9), b: float = float(0.4)):
5911
"""
6012
Initialize Pyserini retriever

primeqa/pipelines/components/base.py

+26
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,19 @@ class IndexerComponent(Component):
5454
@abstractmethod
5555
def index(self, collection: Union[List[dict], str], *args, **kwargs):
5656
pass
57+
58+
@abstractmethod
59+
def get_engine_type() -> str:
60+
"""
61+
Return this retriever engine type. Must match with the retriever that will be used to query the index.
62+
63+
Raises:
64+
NotImplementedError:
65+
66+
Returns:
67+
str: engine type
68+
"""
69+
raise NotImplementedError
5770

5871

5972
@dataclass(init=False, repr=False, eq=False)
@@ -91,3 +104,16 @@ def __hash__(self) -> int:
91104
@abstractmethod
92105
def retrieve(self, input_texts: List[str], *args, **kwargs):
93106
pass
107+
108+
@abstractmethod
109+
def get_engine_type() -> str:
110+
"""
111+
Return this retriever engine type. Must match with the indexer used to generate the index.
112+
113+
Raises:
114+
NotImplementedError:
115+
116+
Returns:
117+
str: engine type
118+
"""
119+
raise NotImplementedError

primeqa/pipelines/components/indexer/dense.py

+3
Original file line numberDiff line numberDiff line change
@@ -139,3 +139,6 @@ def index(self, collection: Union[List[dict], str], *args, **kwargs):
139139
collection,
140140
overwrite="overwrite" in kwargs and kwargs["overwrite"],
141141
)
142+
143+
def get_engine_type(self):
144+
return "ColBERT"
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
11
from typing import Union, List
2-
from dataclasses import dataclass
2+
from dataclasses import dataclass, field
3+
import json
34

45
from primeqa.pipelines.components.base import IndexerComponent
6+
from primeqa.ir.sparse.indexer import PyseriniIndexer
7+
58

69

710
@dataclass
811
class BM25Indexer(IndexerComponent):
912
"""_summary_
1013
1114
Args:
15+
index_root (str): Path to root directory where index to be stored.
16+
index_name (str): Index name.
1217
1318
Important:
1419
1. Each field has metadata property which can carry additional information for other downstream usages.
@@ -17,9 +22,49 @@ class BM25Indexer(IndexerComponent):
1722
b. exclude_from_hash (bool,optional): If set to True, that parameter is not considered while building the hash representation for the object. Defaults to False.
1823
1924
"""
20-
25+
26+
num_workers: int = field(
27+
default=1,
28+
metadata={
29+
"name": "Number of worker threads",
30+
},
31+
)
32+
33+
additional_index_args: str = field(
34+
default='--storePositions --storeDocvectors --storeRaw',
35+
metadata={
36+
"name": "Additional index arguments",
37+
},
38+
)
39+
40+
def __post_init__(self):
41+
self._indexer = None
42+
43+
def __hash__(self) -> int:
44+
return hash(
45+
f"{self.__class__.__name__}::{json.dumps({k: v.default for k, v in self.__class__.__dataclass_fields__.items() if not 'exclude_from_hash' in v.metadata or not v.metadata['exclude_from_hash']}, sort_keys=True)}"
46+
)
47+
2148
def load(self, *args, **kwargs):
22-
pass
49+
self._index_path=f"{self.index_root}/{self.index_name}"
50+
self._indexer = PyseriniIndexer()
2351

2452
def index(self, collection: Union[List[dict], str], *args, **kwargs):
25-
pass
53+
if not isinstance(collection, str):
54+
raise TypeError(
55+
"Pyserini indexer expects path to `documents.tsv` as value for `collection` argument."
56+
)
57+
58+
self._indexer.index_collection(collection = collection, index_path=self._index_path,
59+
fieldnames=None,
60+
overwrite="overwrite" in kwargs and kwargs["overwrite"],
61+
threads=kwargs["num_workers"] if "num_workers" in kwargs else 1,
62+
additional_index_cmd_args=kwargs["additional_index_args"] if "additional_index_args" in kwargs
63+
else '--storePositions --storeDocvectors --storeRaw' )
64+
65+
def get_engine_type(self) -> str:
66+
return "BM25"
67+
68+
69+
70+

primeqa/pipelines/components/retriever/dense.py

+3
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,6 @@ def retrieve(self, input_texts: List[str], *args, **kwargs):
127127
[(result[0], result[-1]) for result in results_per_query]
128128
for results_per_query in ranking_results.data.values()
129129
]
130+
131+
def get_engine_type(self):
132+
return "ColBERT"

primeqa/pipelines/components/retriever/sparse.py

+42-7
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from typing import List
22
from dataclasses import dataclass, field
3+
import json
34

45
from primeqa.pipelines.components.base import RetrieverComponent
6+
from primeqa.ir.sparse.retriever import PyseriniRetriever
57

68

79
@dataclass
@@ -29,22 +31,55 @@ class BM25Retriever(RetrieverComponent):
2931
"description": "Path to root directory where index is stored",
3032
},
3133
)
32-
index_name: str = field(
34+
35+
max_num_documents: int = field(
36+
default=5,
3337
metadata={
34-
"name": "Index name",
38+
"name": "Maximum number of retrieved documents",
39+
"range": [1, 100, 1],
40+
"api_support": True,
41+
"exclude_from_hash": True,
3542
},
3643
)
37-
max_num_documents: int = field(
38-
default=5,
39-
metadata={"name": "Maximum number of documents", "range": [1, 100, 1]},
44+
45+
num_workers: int = field(
46+
default=1,
47+
metadata={
48+
"name": "Num worker threads",
49+
"range": [1, 100, 1],
50+
"exclude_from_hash": True,
51+
},
4052
)
4153

4254
def __post_init__(self):
4355
# Placeholder variables
56+
self._index_path=f"{self.index_root}/{self.index_name}"
4457
self._searcher = None
58+
59+
def __hash__(self) -> int:
60+
# Step 1: Identify all fields to be included in the hash
61+
hashable_fields = [
62+
k
63+
for k, v in self.__class__.__dataclass_fields__.items()
64+
if not "exclude_from_hash" in v.metadata
65+
or not v.metadata["exclude_from_hash"]
66+
]
67+
68+
# Step 2: Run
69+
return hash(
70+
f"{self.__class__.__name__}::{json.dumps({k: v for k, v in vars(self).items() if k in hashable_fields}, sort_keys=True)}"
71+
)
4572

4673
def load(self, *args, **kwargs):
47-
pass
74+
self._searcher = PyseriniRetriever(self._index_path)
4875

4976
def retrieve(self, input_texts: List[str], *args, **kwargs):
50-
pass
77+
qids = [str(idx) for idx, query in enumerate(input_texts) ]
78+
hits = self._searcher.batch_retrieve(input_texts, qids, topK=self.max_num_documents, threads=self.num_workers)
79+
return [
80+
[(result['doc_id'], result['score']) for result in results_per_query]
81+
for results_per_query in hits.values()
82+
]
83+
84+
def get_engine_type(self):
85+
return "BM25"

primeqa/services/constants.py

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
ATTR_INDEX_ID = "index_id"
44
ATTR_STATUS = "status"
5+
ATTR_ENGINE_TYPE ="engine_type"
56

67

78
class IndexStatus(str, Enum):

primeqa/services/factories.py

+4
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,23 @@
1212
from primeqa.pipelines.components.reader.extractive import ExtractiveReader
1313

1414
from primeqa.pipelines.components.retriever.dense import ColBERTRetriever
15+
from primeqa.pipelines.components.retriever.sparse import BM25Retriever
1516

1617
from primeqa.pipelines.components.indexer.dense import ColBERTIndexer
18+
from primeqa.pipelines.components.indexer.sparse import BM25Indexer
1719

1820
READERS_REGISTRY = {
1921
ExtractiveReader.__name__: ExtractiveReader,
2022
}
2123

2224
RETRIEVERS_REGISTRY = {
2325
ColBERTRetriever.__name__: ColBERTRetriever,
26+
BM25Retriever.__name__: BM25Retriever,
2427
}
2528

2629
INDEXERS_REGISTRY = {
2730
ColBERTIndexer.__name__: ColBERTIndexer,
31+
BM25Indexer.__name__: BM25Indexer,
2832
}
2933

3034

0 commit comments

Comments
 (0)