Skip to content

Commit

Permalink
Fix ltr searcher with self-built index bug (#1084)
Browse files Browse the repository at this point in the history
+ fix bugs for using self-built index bug in ltr
  • Loading branch information
stephaniewhoo authored Mar 20, 2022
1 parent 317dbba commit eb013ea
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 18 deletions.
2 changes: 1 addition & 1 deletion docs/experiments-ltr-msmarco-passage-reranking.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ wget https://rgw.cs.uwaterloo.ca/JIMMYLIN-bucket0/pyserini-models/model-ltr-msma
tar -xzvf runs/model-ltr-msmarco-passage-mrr-v1.tar.gz -C runs
```

The following command generates our reranking result:
The following command generates our reranking result with our prebuilt index:

```bash
python -m pyserini.search.lucene.ltr
Expand Down
20 changes: 12 additions & 8 deletions pyserini/search/lucene/ltr/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@
"""
Running prediction on candidates
"""
def dev_data_loader(file, format, data, rerank, top=1000):
if (rerank):
def dev_data_loader(file, format, data, rerank, prebuilt, top=1000):
if rerank:
if format == 'tsv':
dev = pd.read_csv(file, sep="\t",
names=['qid', 'pid', 'rank'],
Expand All @@ -56,7 +56,10 @@ def dev_data_loader(file, format, data, rerank, top=1000):
assert dev['rank'].dtype == np.int32
dev = dev[dev['rank']<=top]
else:
bm25search = LuceneSearcher.from_prebuilt_index(args.index)
if prebuilt:
bm25search = LuceneSearcher.from_prebuilt_index(args.index)
else:
bm25search = LuceneSearcher(args.index)
bm25search.set_bm25(0.82, 0.68)
dev_dic = {"qid":[], "pid":[], "rank":[]}
for topic in tqdm(queries.keys()):
Expand Down Expand Up @@ -222,9 +225,9 @@ def output(file, dev_data, format, maxp):
score_tie_counter += 1
score_tie_query.add(qid)
prev_score = t.score
if (maxp):
if maxp:
docid = t.pid.split('#')[0]
if (qid not in results or docid not in results[qid] or t.score > results[qid][docid]):
if qid not in results or docid not in results[qid] or t.score > results[qid][docid]:
results[qid][docid] = t.score
else:
results[qid][t.pid] = t.score
Expand All @@ -235,7 +238,7 @@ def output(file, dev_data, format, maxp):
docid_score = results[qid]
docid_score = sorted(docid_score.items(),key=lambda kv: kv[1], reverse=True)
for docid, score in docid_score:
if (format=='trec'):
if format=='trec':
output_file.write(f"{qid}\tQ0\t{docid}\t{rank}\t{score}\tltr\n")
else:
output_file.write(f"{qid}\t{docid}\t{rank}\n")
Expand All @@ -261,8 +264,9 @@ def output(file, dev_data, format, maxp):
args = parser.parse_args()
queries = query_loader()
print("---------------------loading dev----------------------------------------")
dev, dev_qrel = dev_data_loader(args.input, args.input_format, args.data, args.rerank, args.hits)
searcher = MsmarcoLtrSearcher(args.model, args.ibm_model, args.index, args.data)
prebuilt = args.index == 'msmarco-passage-ltr' or args.index == 'msmarco-doc-per-passage-ltr'
dev, dev_qrel = dev_data_loader(args.input, args.input_format, args.data, args.rerank, prebuilt, args.hits)
searcher = MsmarcoLtrSearcher(args.model, args.ibm_model, args.index, args.data, prebuilt)
searcher.add_fe()
batch_info = searcher.search(dev, queries)
del dev, queries
Expand Down
22 changes: 13 additions & 9 deletions pyserini/search/lucene/ltr/_search_msmarco.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,22 @@
logger = logging.getLogger(__name__)

class MsmarcoLtrSearcher:
def __init__(self, model: str, ibm_model:str, index:str, data: str):
def __init__(self, model: str, ibm_model:str, index:str, data: str, prebuilt: bool):
#msmarco-ltr-passage
self.model = model
self.ibm_model = ibm_model
self.lucene_searcher = LuceneSearcher.from_prebuilt_index(index)
index_directory = os.path.join(get_cache_home(), 'indexes')
if (data == 'passage'):
index_path = os.path.join(index_directory, 'index-msmarco-passage-ltr-20210519-e25e33f.a5de642c268ac1ed5892c069bdc29ae3')
if prebuilt:
self.lucene_searcher = LuceneSearcher.from_prebuilt_index(index)
index_directory = os.path.join(get_cache_home(), 'indexes')
if data == 'passage':
index_path = os.path.join(index_directory, 'index-msmarco-passage-ltr-20210519-e25e33f.a5de642c268ac1ed5892c069bdc29ae3')
else:
index_path = os.path.join(index_directory, 'index-msmarco-doc-per-passage-ltr-20211031-33e4151.bd60e89041b4ebbabc4bf0cfac608a87')
self.index_reader = IndexReader.from_prebuilt_index(index)
else:
index_path = os.path.join(index_directory, 'index-msmarco-doc-per-passage-ltr-20211031-33e4151.bd60e89041b4ebbabc4bf0cfac608a87')
index_path = index
self.index_reader = IndexReader(index)
self.fe = FeatureExtractor(index_path, max(multiprocessing.cpu_count()//2, 1))
self.index_reader = IndexReader.from_prebuilt_index(index)
self.data = data


Expand Down Expand Up @@ -187,8 +191,8 @@ def batch_extract(self, df, queries, fe):
"query_dict": queries[qid]
}
for t in group.reset_index().itertuples():
if (self.data == 'document'):
if (self.index_reader.doc(t.pid) != None):
if self.data == 'document':
if self.index_reader.doc(t.pid) != None:
task["docIds"].append(t.pid)
task_infos.append((qid, t.pid, t.rel))
else:
Expand Down

0 comments on commit eb013ea

Please sign in to comment.