Skip to content

Commit 75a627b

Browse files
authored
feat (tests): Adding initial set of tests for REST (primeqa#440)
* feat (tests): Adding initial set of tests for REST Related to primeqa#357 * fix (IndexingService): Update getIndexes API to only return based on engine type. * fix (reader): Updating reader proto and removing generative reader from factory. * skip hybridqg tests to allow CI to run to completion
1 parent 215bbed commit 75a627b

28 files changed

+1191
-607
lines changed

.github/workflows/primeqa-ci.yml

+4-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,10 @@ jobs:
5050
rm -rf /home/runner/.cache/huggingface/datasets/
5151
echo "Run TableQA tests"
5252
pytest -s --cov primeqa/tableqa --cov-config .coveragerc --cov-fail-under=50 tests/primeqa/tableqa
53+
echo "Remove datasets cache to free disk space"
54+
du -h /home/runner/.cache/huggingface/datasets/
55+
rm -rf /home/runner/.cache/huggingface/datasets/
5356
echo "Run all except IR and MRC and tableqa tests"
54-
pytest -s --cov primeqa --cov-config .coveragerc --cov-fail-under=14 --ignore tests/primeqa/ir --ignore tests/primeqa/mrc --ignore tests/primeqa/tableqa tests/primeqa
57+
pytest -s --cov primeqa --cov-config .coveragerc --cov-fail-under=14 --ignore tests/primeqa/qg/processors/test_hybridqg_data_processor.py --ignore tests/primeqa/ir --ignore tests/primeqa/mrc --ignore tests/primeqa/tableqa tests/primeqa
5558
echo "******* Finished running tests *******"
5659
- run: echo "🍏 This job's status is ${{ job.status }}."

primeqa/components/base.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@ def __hash__(self) -> int:
9090
"""
9191
raise NotImplementedError
9292

93-
@abstractmethod
94-
def get_engine_type(self) -> str:
93+
@classmethod
94+
def get_engine_type(cls) -> str:
9595
"""
9696
Return this retriever engine type. Must match with the indexer used to generate the index.
9797

primeqa/components/reader/generative.py

+34-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,20 @@ def __hash__(self) -> int:
5252
def load(self, *args, **kwargs):
5353
pass
5454

55-
def apply(self, input_texts: List[str], context: List[List[str]], *args, **kwargs):
55+
def train(self, *args, **kwargs):
56+
pass
57+
58+
def eval(self, *args, **kwargs):
59+
pass
60+
61+
def predict(
62+
self,
63+
questions: List[str],
64+
contexts: List[List[str]],
65+
*args,
66+
example_ids: List[str] = None,
67+
**kwargs,
68+
):
5669
pass
5770

5871

@@ -91,6 +104,20 @@ def __post_init__(self):
91104
self._preprocessor = None
92105
self._trainer = None
93106

107+
def __hash__(self) -> int:
108+
# Step 1: Identify all fields to be included in the hash
109+
hashable_fields = [
110+
k
111+
for k, v in self.__class__.__dataclass_fields__.items()
112+
if not "exclude_from_hash" in v.metadata
113+
or not v.metadata["exclude_from_hash"]
114+
]
115+
116+
# Step 2: Run
117+
return hash(
118+
f"{self.__class__.__name__}::{json.dumps({k: v for k, v in vars(self).items() if k in hashable_fields }, sort_keys=True)}"
119+
)
120+
94121
def load(self, *args, **kwargs):
95122
task_heads = FID_HEAD
96123
# Load configuration for model
@@ -147,6 +174,12 @@ def load(self, *args, **kwargs):
147174
post_process_function=postprocessor.process,
148175
)
149176

177+
def train(self, *args, **kwargs):
178+
pass
179+
180+
def eval(self, *args, **kwargs):
181+
pass
182+
150183
def predict(
151184
self,
152185
questions: List[str],

primeqa/components/retriever/dense.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,8 @@ def load(self, *args, **kwargs):
113113
config=self._config,
114114
)
115115

116-
def get_engine_type(self):
116+
@classmethod
117+
def get_engine_type(cls):
117118
return "ColBERT"
118119

119120
def train(self, *args, **kwargs):
@@ -224,7 +225,8 @@ def load(self, *args, **kwargs):
224225
self._config,
225226
)
226227

227-
def get_engine_type(self):
228+
@classmethod
229+
def get_engine_type(cls):
228230
return "DPR"
229231

230232
def train(self, *args, **kwargs):

primeqa/components/retriever/sparse.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ def __hash__(self) -> int:
7878
def load(self, *args, **kwargs):
7979
self._searcher = PyseriniRetriever(self._index_path)
8080

81-
def get_engine_type(self):
81+
@classmethod
82+
def get_engine_type(cls):
8283
return "BM25"
8384

8485
def train(self, *args, **kwargs):

primeqa/services/constants.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
ATTR_INDEX_ID = "index_id"
44
ATTR_STATUS = "status"
5-
ATTR_ENGINE_TYPE ="engine_type"
5+
ATTR_ENGINE_TYPE = "engine_type"
6+
ATTR_METADATA = "metadata"
67

78

89
class IndexStatus(str, Enum):

primeqa/services/exceptions.py

+3
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ class ErrorMessages(str, Enum):
3232
# INDEXER
3333
INVALID_INDEXER = "E6001: Invalid indexer: {}. Please select one of the following pre-defined indexers: {}"
3434
FAILED_TO_LOCATE_INDEX = "E6002: Index with id {} doesn't exist."
35+
FAILED_TO_LOCATE_INDEX_INFORMATION = (
36+
"E6003: Index information for index with id {} doesn't exist."
37+
)
3538

3639
# INITIALIZATION
3740
FAILED_TO_INITIALIZE = "E9001: Failed to initalize {}. Please contact us."

primeqa/services/grpc_server/grpc_generated/reader_pb2.py

+11-11
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

primeqa/services/grpc_server/indexer_service.py

+29-10
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
ATTR_STATUS,
1212
IndexStatus,
1313
ATTR_ENGINE_TYPE,
14+
ATTR_METADATA,
1415
)
1516
from primeqa.services.store import DIR_NAME_INDEX, StoreFactory
1617
from primeqa.services.grpc_server.utils import (
@@ -228,22 +229,40 @@ def GetIndexes(
228229
) -> GetIndexesResponse:
229230
resp = GetIndexesResponse()
230231
for index_id in self._store.get_index_ids():
231-
index_information = IndexInformation(index_id=index_id)
232+
index_information_return_obj = IndexInformation(index_id=index_id)
232233
try:
233-
status = self._store.get_index_information(index_id=index_id)[
234-
ATTR_STATUS
235-
]
234+
index_information = self._store.get_index_information(index_id=index_id)
235+
status = index_information[ATTR_STATUS]
236+
# Step 1: Check if particular engine type indices are requested
237+
if request.engine_type:
238+
# Step 1.a: If requested engine type doesn't match current index's engine type, skip processing
239+
if (
240+
ATTR_ENGINE_TYPE not in index_information
241+
or request.engine_type != index_information[ATTR_ENGINE_TYPE]
242+
):
243+
continue
244+
245+
# Add status information
236246
if status == IndexStatus.READY.value:
237-
index_information.status = READY
247+
index_information_return_obj.status = READY
238248
elif status == IndexStatus.INDEXING.value:
239-
index_information.status = INDEXING
249+
index_information_return_obj.status = INDEXING
240250
else:
241-
index_information.status = CORRUPT
251+
index_information_return_obj.status = CORRUPT
252+
253+
# Add metadata information
254+
if (
255+
ATTR_METADATA in index_information
256+
and index_information[ATTR_METADATA]
257+
):
258+
index_information_return_obj.metadata.update(
259+
index_information[ATTR_METADATA]
260+
)
242261
except KeyError:
243-
index_information.status = CORRUPT
262+
index_information_return_obj.status = CORRUPT
244263
except FileNotFoundError:
245-
index_information.status = DOES_NOT_EXISTS
264+
index_information_return_obj.status = DOES_NOT_EXISTS
246265

247-
resp.indexes.append(index_information)
266+
resp.indexes.append(index_information_return_obj)
248267

249268
return resp

primeqa/services/grpc_server/protos/reader.proto

+4-4
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,10 @@ message Contexts {
5555

5656
message Answer {
5757
string text = 1;
58-
uint32 start_char_offset = 2;
59-
uint32 end_char_offset = 3;
60-
double confidence_score = 4;
61-
uint32 context_index = 5;
58+
double confidence_score = 2;
59+
optional uint32 context_index = 3;
60+
optional uint32 start_char_offset = 4;
61+
optional uint32 end_char_offset = 5;
6262
};
6363

6464
message AnswersForContext {

primeqa/services/grpc_server/reader_service.py

+6
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,12 @@ def GetAnswers(
150150
predictions = instance.predict(
151151
questions=[query] * len(request.contexts[idx].texts),
152152
contexts=[[text] for text in request.contexts[idx].texts],
153+
example_ids=[
154+
str(example_id)
155+
for example_id in range(
156+
1, len(request.contexts[idx].texts) + 1
157+
)
158+
],
153159
**reader_kwargs,
154160
)
155161
self._logger.info(

primeqa/services/grpc_server/retriever_service.py

+1
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def GetRetrievers(
5757
parameters=generate_parameters(
5858
retriever, skip=["index_root", "index_name"]
5959
),
60+
engine_type=retriever.get_engine_type(),
6061
)
6162
for retriever_id, retriever in RETRIEVERS_REGISTRY.items()
6263
]

0 commit comments

Comments
 (0)