From 7532c7f3661fcd6ec6fa845918643ced307dd9b5 Mon Sep 17 00:00:00 2001
From: lintool
Date: Mon, 7 Oct 2024 23:25:53 -0400
Subject: [PATCH 1/6] Refactoring pyserini.encode and pyserini.search.faiss

---
 integrations-optional/dense/test_ance.py    |   2 +-
 integrations-optional/dense/test_dpr.py     |   2 +-
 .../dense/test_tct_colbert.py               |   2 +-
 pyserini/encode/__init__.py                 |   1 +
 pyserini/encode/_ance.py                    |  71 ++++--
 pyserini/encode/_base.py                    |  43 +++-
 pyserini/encode/_bpr.py                     |  67 ++++++
 pyserini/encode/_dpr.py                     |  29 ++-
 pyserini/encode/_tct_colbert.py             |  47 ++--
 pyserini/search/faiss/__init__.py           |   8 +-
 pyserini/search/faiss/__main__.py           |  12 +-
 pyserini/search/faiss/_model.py             |  77 -------
 pyserini/search/faiss/_searcher.py          | 206 +-----------------
 tests-optional/test_load_encoded_queries.py |   2 +-
 14 files changed, 230 insertions(+), 339 deletions(-)
 create mode 100644 pyserini/encode/_bpr.py
 delete mode 100644 pyserini/search/faiss/_model.py

diff --git a/integrations-optional/dense/test_ance.py b/integrations-optional/dense/test_ance.py
index 7dcb5174d..bce4066d9 100644
--- a/integrations-optional/dense/test_ance.py
+++ b/integrations-optional/dense/test_ance.py
@@ -21,8 +21,8 @@
 import unittest
 
 from integrations.utils import clean_files, run_command, parse_score_qa, parse_score_msmarco
+from pyserini.encode import QueryEncoder
 from pyserini.search import get_topics
-from pyserini.search.faiss._searcher import QueryEncoder
 
 
 class TestAnce(unittest.TestCase):

diff --git a/integrations-optional/dense/test_dpr.py b/integrations-optional/dense/test_dpr.py
index f06900251..56c011604 100644
--- a/integrations-optional/dense/test_dpr.py
+++ b/integrations-optional/dense/test_dpr.py
@@ -22,8 +22,8 @@
 import unittest
 
 from integrations.utils import clean_files, run_command, parse_score_qa
+from pyserini.encode import QueryEncoder
 from pyserini.search import get_topics
-from pyserini.search.faiss._searcher import QueryEncoder
 
 
 class TestDpr(unittest.TestCase):

diff --git a/integrations-optional/dense/test_tct_colbert.py b/integrations-optional/dense/test_tct_colbert.py
index f103e1a76..5984aea10 100644
--- a/integrations-optional/dense/test_tct_colbert.py
+++ b/integrations-optional/dense/test_tct_colbert.py
@@ -21,8 +21,8 @@
 import unittest
 
 from integrations.utils import clean_files, run_command, parse_score
+from pyserini.encode import QueryEncoder
 from pyserini.search import get_topics
-from pyserini.search.faiss._searcher import QueryEncoder
 
 
 class TestTctColBert(unittest.TestCase):

diff --git a/pyserini/encode/__init__.py b/pyserini/encode/__init__.py
index 841e34800..3e1a992a3 100644
--- a/pyserini/encode/__init__.py
+++ b/pyserini/encode/__init__.py
@@ -21,6 +21,7 @@
 from ._aggretriever import AggretrieverDocumentEncoder, AggretrieverQueryEncoder
 from ._ance import AnceEncoder, AnceDocumentEncoder, AnceQueryEncoder
 from ._auto import AutoQueryEncoder, AutoDocumentEncoder
+from ._bpr import BprQueryEncoder
 from ._cached_data import CachedDataQueryEncoder
 from ._cosdpr import CosDprEncoder, CosDprDocumentEncoder, CosDprQueryEncoder
 from ._dpr import DprDocumentEncoder, DprQueryEncoder

diff --git a/pyserini/encode/_ance.py b/pyserini/encode/_ance.py
index 3976689dc..f51f4e16d 100644
--- a/pyserini/encode/_ance.py
+++ b/pyserini/encode/_ance.py
@@ -14,10 +14,10 @@
 # limitations under the License.
 #
 
-from typing import Optional
+from typing import Optional, List
 
 import torch
-from transformers import PreTrainedModel, RobertaConfig, RobertaModel, RobertaTokenizer
+from transformers import PreTrainedModel, RobertaConfig, RobertaModel, RobertaTokenizer, requires_backends
 
 from pyserini.encode import DocumentEncoder, QueryEncoder
@@ -30,6 +30,7 @@ class AnceEncoder(PreTrainedModel):
     _keys_to_ignore_on_load_unexpected = [r'pooler', r'classifier']
 
     def __init__(self, config: RobertaConfig):
+        requires_backends(self, 'torch')
         super().__init__(config)
         self.config = config
         self.roberta = RobertaModel(config)
@@ -55,11 +56,7 @@ def init_weights(self):
         self.embeddingHead.apply(self._init_weights)
         self.norm.apply(self._init_weights)
 
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-    ):
+    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
         input_shape = input_ids.size()
         device = input_ids.device
         if attention_mask is None:
@@ -98,22 +95,60 @@ def encode(self, texts, titles=None, max_length=256, **kwargs):
 
 
 class AnceQueryEncoder(QueryEncoder):
+    def __init__(self, encoder_dir: str = None, tokenizer_name: str = None,
+                 encoded_query_dir: str = None, device: str = 'cpu', **kwargs):
+        super().__init__(encoded_query_dir)
+        if encoder_dir:
+            self.device = device
+            self.model = AnceEncoder.from_pretrained(encoder_dir)
+            self.model.to(self.device)
+            self.tokenizer = RobertaTokenizer.from_pretrained(tokenizer_name or encoder_dir)
+            self.has_model = True
+            self.tokenizer.do_lower_case = True
+        if (not self.has_model) and (not self.has_encoded_query):
+            raise Exception('Neither query encoder model nor encoded queries provided. Please provide at least one')
+
+    def encode(self, query: str):
+        if self.has_model:
+            inputs = self.tokenizer(
+                [query],
+                max_length=64,
+                padding='longest',
+                truncation=True,
+                add_special_tokens=True,
+                return_tensors='pt'
+            )
+            inputs.to(self.device)
+            embeddings = self.model(inputs["input_ids"]).detach().cpu().numpy()
+            return embeddings.flatten()
+        else:
+            return super().encode(query)
+
+    def prf_encode(self, query: str):
+        if self.has_model:
+            inputs = self.tokenizer(
+                [query],
+                max_length=512,
+                padding='longest',
+                truncation=True,
+                add_special_tokens=False,
+                return_tensors='pt'
+            )
+            inputs.to(self.device)
+            embeddings = self.model(inputs["input_ids"]).detach().cpu().numpy()
+            return embeddings.flatten()
+        else:
+            return super().encode(query)
 
-    def __init__(self, model_name: str, tokenizer_name: str = None, device: str = 'cpu'):
-        self.device = device
-        self.model = AnceEncoder.from_pretrained(model_name)
-        self.model.to(self.device)
-        self.tokenizer = RobertaTokenizer.from_pretrained(tokenizer_name or model_name)
-
-    def encode(self, query: str, **kwargs):
+    def prf_batch_encode(self, query: List[str]):
         inputs = self.tokenizer(
-            [query],
-            max_length=64,
+            query,
+            max_length=512,
             padding='longest',
             truncation=True,
-            add_special_tokens=True,
+            add_special_tokens=False,
             return_tensors='pt'
         )
         inputs.to(self.device)
         embeddings = self.model(inputs["input_ids"]).detach().cpu().numpy()
-        return embeddings.flatten()
+        return embeddings

diff --git a/pyserini/encode/_base.py b/pyserini/encode/_base.py
index 47457ace4..ed0d1aa2a 100644
--- a/pyserini/encode/_base.py
+++ b/pyserini/encode/_base.py
@@ -18,9 +18,12 @@
 import os
 
 import numpy as np
+import pandas as pd
 import torch
 from tqdm import tqdm
 
+from pyserini.util import download_encoded_queries
+
 
 class DocumentEncoder:
     def encode(self, texts, **kwargs):
@@ -36,8 +39,44 @@ def _mean_pooling(last_hidden_state, attention_mask):
 
 
 class QueryEncoder:
-    def encode(self, text, **kwargs):
-        pass
+    def __init__(self, encoded_query_dir: str = None):
+        self.has_model = False
+        self.has_encoded_query = False
+        if encoded_query_dir:
+            self.embedding = self._load_embeddings(encoded_query_dir)
+            self.has_encoded_query = True
+
+    def encode(self, query: str):
+        return self.embedding[query]
+
+    @classmethod
+    def load_encoded_queries(cls, encoded_query_name: str):
+        """Build a query encoder from a pre-encoded query; download the encoded queries if necessary.
+
+        Parameters
+        ----------
+        encoded_query_name : str
+            pre encoded query name.
+
+        Returns
+        -------
+        QueryEncoder
+            Encoder built from the pre encoded queries.
+        """
+        print(f'Attempting to initialize pre-encoded queries {encoded_query_name}.')
+        try:
+            query_dir = download_encoded_queries(encoded_query_name)
+        except ValueError as e:
+            print(str(e))
+            return None
+
+        print(f'Initializing {encoded_query_name}...')
+        return cls(encoded_query_dir=query_dir)
+
+    @staticmethod
+    def _load_embeddings(encoded_query_dir):
+        df = pd.read_pickle(os.path.join(encoded_query_dir, 'embedding.pkl'))
+        return dict(zip(df['text'].tolist(), df['embedding'].tolist()))
 
 
 class JsonlCollectionIterator:

diff --git a/pyserini/encode/_bpr.py b/pyserini/encode/_bpr.py
new file mode 100644
index 000000000..978d71dbd
--- /dev/null
+++ b/pyserini/encode/_bpr.py
@@ -0,0 +1,67 @@
+#
+# Pyserini: Reproducible IR research with sparse and dense representations
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+
+import pandas as pd
+import torch
+from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer
+
+from pyserini.encode import QueryEncoder
+
+
+class BprQueryEncoder(QueryEncoder):
+    def __init__(self, encoder_dir: str = None, tokenizer_name: str = None,
+                 encoded_query_dir: str = None, device: str = 'cpu', **kwargs):
+        self.has_model = False
+        self.has_encoded_query = False
+
+        if encoded_query_dir:
+            self.embedding = self._load_embeddings(encoded_query_dir)
+            self.has_encoded_query = True
+
+        if encoder_dir:
+            self.device = device
+            self.model = DPRQuestionEncoder.from_pretrained(encoder_dir)
+            self.model.to(self.device)
+            self.tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(tokenizer_name or encoder_dir)
+            self.has_model = True
+
+        if (not self.has_model) and (not self.has_encoded_query):
+            raise Exception('Neither query encoder model nor encoded queries provided. Please provide at least one')
+
+    def encode(self, query: str):
+        if self.has_model:
+            input_ids = self.tokenizer(query, return_tensors='pt')
+            input_ids.to(self.device)
+            embeddings = self.model(input_ids["input_ids"]).pooler_output.detach().cpu()
+            dense_embeddings = embeddings.numpy()
+            sparse_embeddings = self.convert_to_binary_code(embeddings).numpy()
+            return {'dense': dense_embeddings.flatten(), 'sparse': sparse_embeddings.flatten()}
+        else:
+            return super().encode(query)
+
+    def convert_to_binary_code(self, input_repr: torch.Tensor):
+        return input_repr.new_ones(input_repr.size()).masked_fill_(input_repr < 0, -1.0)
+
+    @staticmethod
+    def _load_embeddings(encoded_query_dir):
+        df = pd.read_pickle(os.path.join(encoded_query_dir, 'embedding.pkl'))
+        ret = {}
+        for text, dense, sparse in zip(df['text'].tolist(), df['dense_embedding'].tolist(),
+                                       df['sparse_embedding'].tolist()):
+            ret[text] = {'dense': dense, 'sparse': sparse}
+        return ret

diff --git a/pyserini/encode/_dpr.py b/pyserini/encode/_dpr.py
index 9e19a387c..9d4930897 100644
--- a/pyserini/encode/_dpr.py
+++ b/pyserini/encode/_dpr.py
@@ -51,14 +51,23 @@ def encode(self, texts, titles=None, max_length=256, **kwargs):
 
 
 class DprQueryEncoder(QueryEncoder):
-    def __init__(self, model_name: str, tokenizer_name: str = None, device: str = 'cpu'):
-        self.device = device
-        self.model = DPRQuestionEncoder.from_pretrained(model_name)
-        self.model.to(self.device)
-        self.tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(tokenizer_name or model_name)
+    def __init__(self, encoder_dir: str = None, tokenizer_name: str = None,
+                 encoded_query_dir: str = None, device: str = 'cpu', **kwargs):
+        super().__init__(encoded_query_dir)
+        if encoder_dir:
+            self.device = device
+            self.model = DPRQuestionEncoder.from_pretrained(encoder_dir)
+            self.model.to(self.device)
+            self.tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(tokenizer_name or encoder_dir)
+            self.has_model = True
+        if (not self.has_model) and (not self.has_encoded_query):
+            raise Exception('Neither query encoder model nor encoded queries provided. Please provide at least one')
 
-    def encode(self, query: str, **kwargs):
-        input_ids = self.tokenizer(query, return_tensors='pt')
-        input_ids.to(self.device)
-        embeddings = self.model(input_ids["input_ids"]).pooler_output.detach().cpu().numpy()
-        return embeddings.flatten()
+    def encode(self, query: str):
+        if self.has_model:
+            input_ids = self.tokenizer(query, return_tensors='pt')
+            input_ids.to(self.device)
+            embeddings = self.model(input_ids["input_ids"]).pooler_output.detach().cpu().numpy()
+            return embeddings.flatten()
+        else:
+            return super().encode(query)

diff --git a/pyserini/encode/_tct_colbert.py b/pyserini/encode/_tct_colbert.py
index f7ef51bdc..1f2d16549 100644
--- a/pyserini/encode/_tct_colbert.py
+++ b/pyserini/encode/_tct_colbert.py
@@ -16,8 +16,10 @@
 
 import numpy as np
 import torch
+
 if torch.cuda.is_available():
     from torch.cuda.amp import autocast
+
 from transformers import BertModel, BertTokenizer, BertTokenizerFast
 
 from pyserini.encode import DocumentEncoder, QueryEncoder
@@ -70,22 +72,31 @@ def encode(self, texts, titles=None, fp16=False, max_length=512, **kwargs):
 
 
 class TctColBertQueryEncoder(QueryEncoder):
-    def __init__(self, model_name: str, tokenizer_name: str = None, device: str = 'cpu'):
-        self.device = device
-        self.model = BertModel.from_pretrained(model_name)
-        self.model.to(self.device)
-        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name or model_name)
+    def __init__(self, encoder_dir: str = None, tokenizer_name: str = None,
+                 encoded_query_dir: str = None, device: str = 'cpu', **kwargs):
+        super().__init__(encoded_query_dir)
+        if encoder_dir:
+            self.device = device
+            self.model = BertModel.from_pretrained(encoder_dir)
+            self.model.to(self.device)
+            self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name or encoder_dir)
+            self.has_model = True
+        if (not self.has_model) and (not self.has_encoded_query):
+            raise Exception('Neither query encoder model nor encoded queries provided. Please provide at least one')
 
-    def encode(self, query: str, **kwargs):
-        max_length = 36  # hardcode for now
-        inputs = self.tokenizer(
-            '[CLS] [Q] ' + query + '[MASK]' * max_length,
-            max_length=max_length,
-            truncation=True,
-            add_special_tokens=False,
-            return_tensors='pt'
-        )
-        inputs.to(self.device)
-        outputs = self.model(**inputs)
-        embeddings = outputs.last_hidden_state.detach().cpu().numpy()
-        return np.average(embeddings[:, 4:, :], axis=-2).flatten()
+    def encode(self, query: str):
+        if self.has_model:
+            max_length = 36  # hardcode for now
+            inputs = self.tokenizer(
+                '[CLS] [Q] ' + query + '[MASK]' * max_length,
+                max_length=max_length,
+                truncation=True,
+                add_special_tokens=False,
+                return_tensors='pt'
+            )
+            inputs.to(self.device)
+            outputs = self.model(**inputs)
+            embeddings = outputs.last_hidden_state.detach().cpu().numpy()
+            return np.average(embeddings[:, 4:, :], axis=-2).flatten()
+        else:
+            return super().encode(query)

diff --git a/pyserini/search/faiss/__init__.py b/pyserini/search/faiss/__init__.py
index a95c54181..5cc156131 100644
--- a/pyserini/search/faiss/__init__.py
+++ b/pyserini/search/faiss/__init__.py
@@ -14,10 +14,4 @@
 # limitations under the License.
 #
 
-from ._searcher import FaissSearcher, DenseSearchResult
-
-# from ._prf import DenseVectorAveragePrf, DenseVectorRocchioPrf, DenseVectorAncePrf, PRFDenseSearchResult
-# from ._searcher import DenseSearchResult, FaissSearcher, BinaryDenseSearcher, QueryEncoder, \
-#     DprQueryEncoder, BprQueryEncoder, DkrrDprQueryEncoder, TctColBertQueryEncoder, AnceQueryEncoder, \
-#     AggretrieverQueryEncoder, OpenAIQueryEncoder, \
-#     AutoQueryEncoder, ClipQueryEncoder
+from ._searcher import FaissSearcher, BinaryDenseSearcher, DenseSearchResult

diff --git a/pyserini/search/faiss/__main__.py b/pyserini/search/faiss/__main__.py
index d067038ee..34f9f6133 100644
--- a/pyserini/search/faiss/__main__.py
+++ b/pyserini/search/faiss/__main__.py
@@ -20,16 +20,22 @@
 import numpy as np
 from tqdm import tqdm
 
+from pyserini.encode import QueryEncoder
+
+from pyserini.encode import AnceQueryEncoder
+from pyserini.encode import BprQueryEncoder
 from pyserini.encode import CosDprQueryEncoder
+from pyserini.encode import DprQueryEncoder
+from pyserini.encode import TctColBertQueryEncoder
+
 from pyserini.encode._pca import PcaEncoder
 from pyserini.output_writer import get_output_writer, OutputFormat
 from pyserini.query_iterator import get_query_iterator, TopicsFormat
 from pyserini.search.faiss._searcher import (AutoQueryEncoder, AggretrieverQueryEncoder, OpenAIQueryEncoder,
-                                             QueryEncoder, AnceQueryEncoder, BinaryDenseSearcher, BprQueryEncoder,
-                                             DprQueryEncoder, DkrrDprQueryEncoder, ClipQueryEncoder, TctColBertQueryEncoder)
+                                             DkrrDprQueryEncoder, ClipQueryEncoder)
 from pyserini.search.lucene import LuceneSearcher
 from ._prf import DenseVectorAveragePrf, DenseVectorRocchioPrf, DenseVectorAncePrf
-from ._searcher import FaissSearcher
+from pyserini.search.faiss import FaissSearcher, BinaryDenseSearcher
 
 # Fixes this error: "OMP: Error #15: Initializing libomp.a, but found libomp.dylib already initialized."
 # https://stackoverflow.com/questions/53014306/error-15-initializing-libiomp5-dylib-but-found-libiomp5-dylib-already-initial

diff --git a/pyserini/search/faiss/_model.py b/pyserini/search/faiss/_model.py
deleted file mode 100644
index 3c6e3e5fc..000000000
--- a/pyserini/search/faiss/_model.py
+++ /dev/null
@@ -1,77 +0,0 @@
-#
-# Pyserini: Reproducible IR research with sparse and dense representations
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-from typing import Optional
-
-from transformers import PreTrainedModel, RobertaConfig, RobertaModel
-from transformers.file_utils import is_torch_available, requires_backends
-
-if is_torch_available():
-    import torch
-
-
-class AnceEncoder(PreTrainedModel):
-    config_class = RobertaConfig
-    base_model_prefix = 'ance_encoder'
-    load_tf_weights = None
-    _keys_to_ignore_on_load_missing = [r'position_ids']
-    _keys_to_ignore_on_load_unexpected = [r'pooler', r'classifier']
-
-    def __init__(self, config: RobertaConfig):
-        requires_backends(self, 'torch')
-        super().__init__(config)
-        self.config = config
-        self.roberta = RobertaModel(config)
-        self.embeddingHead = torch.nn.Linear(config.hidden_size, 768)
-        self.norm = torch.nn.LayerNorm(768)
-        self.init_weights()
-
-    # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
-    def _init_weights(self, module):
-        """ Initialize the weights """
-        if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)):
-            # Slightly different from the TF version which uses truncated_normal for initialization
-            # cf https://github.com/pytorch/pytorch/pull/5617
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-        elif isinstance(module, torch.nn.LayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
-        if isinstance(module, torch.nn.Linear) and module.bias is not None:
-            module.bias.data.zero_()
-
-    def init_weights(self):
-        self.roberta.init_weights()
-        self.embeddingHead.apply(self._init_weights)
-        self.norm.apply(self._init_weights)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-    ):
-        input_shape = input_ids.size()
-        device = input_ids.device
-        if attention_mask is None:
-            attention_mask = (
-                torch.ones(input_shape, device=device)
-                if input_ids is None
-                else (input_ids != self.roberta.config.pad_token_id)
-            )
-        outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
-        sequence_output = outputs.last_hidden_state
-        pooled_output = sequence_output[:, 0, :]
-        pooled_output = self.norm(self.embeddingHead(pooled_output))
-        return pooled_output

diff --git a/pyserini/search/faiss/_searcher.py b/pyserini/search/faiss/_searcher.py
index f24e59932..2dc5f932b 100644
--- a/pyserini/search/faiss/_searcher.py
+++ b/pyserini/search/faiss/_searcher.py
@@ -32,12 +32,17 @@
                           DPRQuestionEncoder, DPRQuestionEncoderTokenizer, RobertaTokenizer)
 from transformers.file_utils import is_faiss_available, requires_backends
 
+from pyserini.encode import QueryEncoder
+from pyserini.encode import AnceQueryEncoder
+from pyserini.encode import BprQueryEncoder
+from pyserini.encode import DprQueryEncoder
+from pyserini.encode import TctColBertQueryEncoder
+
 from pyserini.encode._clip import ClipEncoder
 from pyserini.index import Document
 from pyserini.search.lucene import LuceneSearcher
 from pyserini.util import (download_encoded_queries, download_prebuilt_index,
                            get_dense_indexes_info, get_sparse_index)
-from ._model import AnceEncoder
 from ...encode._aggretriever import BERTAggretrieverEncoder, DistlBERTAggretrieverEncoder
 
 from ._prf import DenseVectorAveragePrf, DenseVectorRocchioPrf, DenseVectorAncePrf, PRFDenseSearchResult
@@ -46,46 +51,6 @@
 import faiss
 
 
-class QueryEncoder:
-    def __init__(self, encoded_query_dir: str = None):
-        self.has_model = False
-        self.has_encoded_query = False
-        if encoded_query_dir:
-            self.embedding = self._load_embeddings(encoded_query_dir)
-            self.has_encoded_query = True
-
-    def encode(self, query: str):
-        return self.embedding[query]
-
-    @classmethod
-    def load_encoded_queries(cls, encoded_query_name: str):
-        """Build a query encoder from a pre-encoded query; download the encoded queries if necessary.
-
-        Parameters
-        ----------
-        encoded_query_name : str
-            pre encoded query name.
-
-        Returns
-        -------
-        QueryEncoder
-            Encoder built from the pre encoded queries.
-        """
-        print(f'Attempting to initialize pre-encoded queries {encoded_query_name}.')
-        try:
-            query_dir = download_encoded_queries(encoded_query_name)
-        except ValueError as e:
-            print(str(e))
-            return None
-
-        print(f'Initializing {encoded_query_name}...')
-        return cls(encoded_query_dir=query_dir)
-
-    @staticmethod
-    def _load_embeddings(encoded_query_dir):
-        df = pd.read_pickle(os.path.join(encoded_query_dir, 'embedding.pkl'))
-        return dict(zip(df['text'].tolist(), df['embedding'].tolist()))
-
-
 class ClipQueryEncoder(QueryEncoder):
     """Encodes queries using a CLIP model, supporting both images and texts."""
     def __init__(self,
@@ -141,105 +106,6 @@ def encode(self, query: str, max_length: int=32):
             return embeddings.flatten()
         else:
             return super().encode(query)
-
-
-class TctColBertQueryEncoder(QueryEncoder):
-
-    def __init__(self, encoder_dir: str = None, tokenizer_name: str = None,
-                 encoded_query_dir: str = None, device: str = 'cpu', **kwargs):
-        super().__init__(encoded_query_dir)
-        if encoder_dir:
-            self.device = device
-            self.model = BertModel.from_pretrained(encoder_dir)
-            self.model.to(self.device)
-            self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name or encoder_dir)
-            self.has_model = True
-        if (not self.has_model) and (not self.has_encoded_query):
-            raise Exception('Neither query encoder model nor encoded queries provided. Please provide at least one')
-
-    def encode(self, query: str):
-        if self.has_model:
-            max_length = 36  # hardcode for now
-            inputs = self.tokenizer(
-                '[CLS] [Q] ' + query + '[MASK]' * max_length,
-                max_length=max_length,
-                truncation=True,
-                add_special_tokens=False,
-                return_tensors='pt'
-            )
-            inputs.to(self.device)
-            outputs = self.model(**inputs)
-            embeddings = outputs.last_hidden_state.detach().cpu().numpy()
-            return np.average(embeddings[:, 4:, :], axis=-2).flatten()
-        else:
-            return super().encode(query)
-
-
-class DprQueryEncoder(QueryEncoder):
-
-    def __init__(self, encoder_dir: str = None, tokenizer_name: str = None,
-                 encoded_query_dir: str = None, device: str = 'cpu', **kwargs):
-        super().__init__(encoded_query_dir)
-        if encoder_dir:
-            self.device = device
-            self.model = DPRQuestionEncoder.from_pretrained(encoder_dir)
-            self.model.to(self.device)
-            self.tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(tokenizer_name or encoder_dir)
-            self.has_model = True
-        if (not self.has_model) and (not self.has_encoded_query):
-            raise Exception('Neither query encoder model nor encoded queries provided. Please provide at least one')
-
-    def encode(self, query: str):
-        if self.has_model:
-            input_ids = self.tokenizer(query, return_tensors='pt')
-            input_ids.to(self.device)
-            embeddings = self.model(input_ids["input_ids"]).pooler_output.detach().cpu().numpy()
-            return embeddings.flatten()
-        else:
-            return super().encode(query)
-
-
-class BprQueryEncoder(QueryEncoder):
-
-    def __init__(self, encoder_dir: str = None, tokenizer_name: str = None,
-                 encoded_query_dir: str = None, device: str = 'cpu', **kwargs):
-        self.has_model = False
-        self.has_encoded_query = False
-        if encoded_query_dir:
-            self.embedding = self._load_embeddings(encoded_query_dir)
-            self.has_encoded_query = True
-
-        if encoder_dir:
-            self.device = device
-            self.model = DPRQuestionEncoder.from_pretrained(encoder_dir)
-            self.model.to(self.device)
-            self.tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(tokenizer_name or encoder_dir)
-            self.has_model = True
-        if (not self.has_model) and (not self.has_encoded_query):
-            raise Exception('Neither query encoder model nor encoded queries provided. Please provide at least one')
-
-    def encode(self, query: str):
-        if self.has_model:
-            input_ids = self.tokenizer(query, return_tensors='pt')
-            input_ids.to(self.device)
-            embeddings = self.model(input_ids["input_ids"]).pooler_output.detach().cpu()
-            dense_embeddings = embeddings.numpy()
-            sparse_embeddings = self.convert_to_binary_code(embeddings).numpy()
-            return {'dense': dense_embeddings.flatten(), 'sparse': sparse_embeddings.flatten()}
-        else:
-            return super().encode(query)
-
-    def convert_to_binary_code(self, input_repr: torch.Tensor):
-        return input_repr.new_ones(input_repr.size()).masked_fill_(input_repr < 0, -1.0)
-
-    @staticmethod
-    def _load_embeddings(encoded_query_dir):
-        df = pd.read_pickle(os.path.join(encoded_query_dir, 'embedding.pkl'))
-        ret = {}
-        for text, dense, sparse in zip(df['text'].tolist(), df['dense_embedding'].tolist(),
-                                       df['sparse_embedding'].tolist()):
-            ret[text] = {'dense': dense, 'sparse': sparse}
-        return ret
 
 
 class DkrrDprQueryEncoder(QueryEncoder):
@@ -274,66 +140,6 @@ def encode(self, query: str):
             return super().encode(query)
 
 
-class AnceQueryEncoder(QueryEncoder):
-
-    def __init__(self, encoder_dir: str = None, tokenizer_name: str = None,
-                 encoded_query_dir: str = None, device: str = 'cpu', **kwargs):
-        super().__init__(encoded_query_dir)
-        if encoder_dir:
-            self.device = device
-            self.model = AnceEncoder.from_pretrained(encoder_dir)
-            self.model.to(self.device)
-            self.tokenizer = RobertaTokenizer.from_pretrained(tokenizer_name or encoder_dir)
-            self.has_model = True
-            self.tokenizer.do_lower_case = True
-        if (not self.has_model) and (not self.has_encoded_query):
-            raise Exception('Neither query encoder model nor encoded queries provided. Please provide at least one')
-
-    def encode(self, query: str):
-        if self.has_model:
-            inputs = self.tokenizer(
-                [query],
-                max_length=64,
-                padding='longest',
-                truncation=True,
-                add_special_tokens=True,
-                return_tensors='pt'
-            )
-            inputs.to(self.device)
-            embeddings = self.model(inputs["input_ids"]).detach().cpu().numpy()
-            return embeddings.flatten()
-        else:
-            return super().encode(query)
-
-    def prf_encode(self, query: str):
-        if self.has_model:
-            inputs = self.tokenizer(
-                [query],
-                max_length=512,
-                padding='longest',
-                truncation=True,
-                add_special_tokens=False,
-                return_tensors='pt'
-            )
-            inputs.to(self.device)
-            embeddings = self.model(inputs["input_ids"]).detach().cpu().numpy()
-            return embeddings.flatten()
-        else:
-            return super().encode(query)
-
-    def prf_batch_encode(self, query: List[str]):
-        inputs = self.tokenizer(
-            query,
-            max_length=512,
-            padding='longest',
-            truncation=True,
-            add_special_tokens=False,
-            return_tensors='pt'
-        )
-        inputs.to(self.device)
-        embeddings = self.model(inputs["input_ids"]).detach().cpu().numpy()
-        return embeddings
-
-
 class OpenAIQueryEncoder(QueryEncoder):
     from pyserini.encode._openai import retry_with_delay

diff --git a/tests-optional/test_load_encoded_queries.py b/tests-optional/test_load_encoded_queries.py
index 09cde0f28..baf49991a 100644
--- a/tests-optional/test_load_encoded_queries.py
+++ b/tests-optional/test_load_encoded_queries.py
@@ -19,7 +19,7 @@
 import unittest
 
 from pyserini.search import get_topics
-from pyserini.search.faiss._searcher import QueryEncoder
+from pyserini.encode import QueryEncoder
 
 
 class TestLoadEncodedQueries(unittest.TestCase):

From 7aba21c57019c37888ac2fa05a2a00ec2788eaa3 Mon Sep 17 00:00:00 2001
From: lintool
Date: Tue, 8 Oct 2024 10:11:33 -0400
Subject: [PATCH 2/6] Clean up everything except for auto.
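With this commit, all query encoders except AutoQueryEncoder live in pyserini.encode,
and pyserini.search.faiss only exports searcher classes. A sketch of the intended
usage after this refactoring (the prebuilt index name below is illustrative, not
part of this change):

    from pyserini.encode import TctColBertQueryEncoder
    from pyserini.search.faiss import FaissSearcher

    # Build the query encoder from a model checkpoint...
    encoder = TctColBertQueryEncoder(encoder_dir='castorini/tct_colbert-msmarco')
    # ...or from pre-encoded queries via the shared base-class helper:
    # encoder = TctColBertQueryEncoder.load_encoded_queries('tct_colbert-msmarco-passage-dev-subset')

    searcher = FaissSearcher.from_prebuilt_index('msmarco-passage-tct_colbert-hnsw', encoder)
    hits = searcher.search('what is a lobster roll', k=10)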
---
 pyserini/encode/__init__.py        |  1 +
 pyserini/encode/_aggretriever.py   | 64 ++++++++++---------
 pyserini/encode/_dkrr.py           | 52 ++++++++++++++++
 pyserini/search/faiss/__init__.py  |  3 +-
 pyserini/search/faiss/__main__.py  | 20 +++---
 pyserini/search/faiss/_prf.py      | 38 ++++++------
 pyserini/search/faiss/_searcher.py | 98 ++++--------------------
 7 files changed, 126 insertions(+), 150 deletions(-)
 create mode 100644 pyserini/encode/_dkrr.py

diff --git a/pyserini/encode/__init__.py b/pyserini/encode/__init__.py
index 3e1a992a3..5849a37ad 100644
--- a/pyserini/encode/__init__.py
+++ b/pyserini/encode/__init__.py
@@ -24,6 +24,7 @@
 from ._bpr import BprQueryEncoder
 from ._cached_data import CachedDataQueryEncoder
 from ._cosdpr import CosDprEncoder, CosDprDocumentEncoder, CosDprQueryEncoder
+from ._dkrr import DkrrDprQueryEncoder
 from ._dpr import DprDocumentEncoder, DprQueryEncoder
 from ._openai import OpenAIDocumentEncoder, OpenAIQueryEncoder, OPENAI_API_RETRY_DELAY
 from ._slim import SlimQueryEncoder

diff --git a/pyserini/encode/_aggretriever.py b/pyserini/encode/_aggretriever.py
index 37268bca5..c16c510d6 100644
--- a/pyserini/encode/_aggretriever.py
+++ b/pyserini/encode/_aggretriever.py
@@ -27,7 +27,8 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer, PreTrainedModel
 
 from pyserini.encode import DocumentEncoder, QueryEncoder
 
-class BERTAggretrieverEncoder(PreTrainedModel):
+
+class BertAggretrieverEncoder(PreTrainedModel):
     config_class = BertConfig
     base_model_prefix = 'encoder'
     load_tf_weights = None
@@ -120,7 +121,7 @@ def forward(
         return torch.cat((semantic_reps, lexical_reps), -1)
 
 
-class DistlBERTAggretrieverEncoder(BERTAggretrieverEncoder):
+class DistlBertAggretrieverEncoder(BertAggretrieverEncoder):
     config_class = DistilBertConfig
     base_model_prefix = 'encoder'
     load_tf_weights = None
@@ -130,9 +131,9 @@ class AggretrieverDocumentEncoder(DocumentEncoder):
     def __init__(self, model_name: str, tokenizer_name=None, device='cuda:0'):
         self.device = device
         if 'distilbert' in model_name.lower():
-            self.model = DistlBERTAggretrieverEncoder.from_pretrained(model_name)
+            self.model = DistlBertAggretrieverEncoder.from_pretrained(model_name)
         else:
-            self.model = BERTAggretrieverEncoder.from_pretrained(model_name)
+            self.model = BertAggretrieverEncoder.from_pretrained(model_name)
         self.model.to(self.device)
         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or model_name)
@@ -160,30 +161,33 @@ def encode(self, texts, titles=None, fp16=False, max_length=512, **kwargs):
 
 
 class AggretrieverQueryEncoder(QueryEncoder):
-    def __init__(self, model_name: str, tokenizer_name=None, device='cuda:0'):
-        self.device = device
-        if 'distilbert' in model_name.lower():
-            self.model = DistlBERTAggretrieverEncoder.from_pretrained(model_name)
-        else:
-            self.model = BERTAggretrieverEncoder.from_pretrained(model_name)
-        self.model.to(self.device)
-        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or model_name)
-
-    def encode(self, texts, fp16=False, max_length=32, **kwargs):
-        texts = [text for text in texts]
-        inputs = self.tokenizer(
-            texts,
-            max_length=max_length,
-            padding="longest",
-            truncation=True,
-            add_special_tokens=True,
-            return_tensors='pt'
-        )
-        inputs.to(self.device)
-        if fp16:
-            with autocast():
-                with torch.no_grad():
-                    outputs = self.model(**inputs)
-        else:
+    def __init__(self, encoder_dir: str = None, tokenizer_name: str = None,
+                 encoded_query_dir: str = None, device: str = 'cpu', **kwargs):
+        if encoder_dir:
+            self.device = device
+            if 'distilbert' in encoder_dir.lower():
+                self.model = DistlBertAggretrieverEncoder.from_pretrained(encoder_dir)
+            else:
+                self.model = BertAggretrieverEncoder.from_pretrained(encoder_dir)
+            self.model.to(self.device)
+            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or encoder_dir)
+            self.has_model = True
+        if (not self.has_model) and (not self.has_encoded_query):
+            raise Exception('Neither query encoder model nor encoded queries provided. Please provide at least one')
+
+    def encode(self, query: str, max_length: int=32):
+        if self.has_model:
+            inputs = self.tokenizer(
+                query,
+                max_length=max_length,
+                padding="longest",
+                truncation=True,
+                add_special_tokens=True,
+                return_tensors='pt'
+            )
+            inputs.to(self.device)
             outputs = self.model(**inputs)
-        return outputs.detach().cpu().numpy()
\ No newline at end of file
+            embeddings = outputs.detach().cpu().numpy()
+            return embeddings.flatten()
+        else:
+            return super().encode(query)

diff --git a/pyserini/encode/_dkrr.py b/pyserini/encode/_dkrr.py
new file mode 100644
index 000000000..1bf82707e
--- /dev/null
+++ b/pyserini/encode/_dkrr.py
@@ -0,0 +1,52 @@
+#
+# Pyserini: Reproducible IR research with sparse and dense representations
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import torch
+from transformers import (BertModel, BertTokenizerFast)
+
+from pyserini.encode import QueryEncoder
+
+
+class DkrrDprQueryEncoder(QueryEncoder):
+    def __init__(self, encoder_dir: str = None, encoded_query_dir: str = None, device: str = 'cpu',
+                 prefix: str = "question:", **kwargs):
+        super().__init__(encoded_query_dir)
+        self.device = device
+        self.model = BertModel.from_pretrained(encoder_dir)
+        self.model.to(self.device)
+        self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
+        self.has_model = True
+        self.prefix = prefix
+
+    @staticmethod
+    def _mean_pooling(model_output, attention_mask):
+        model_output = model_output[0].masked_fill(attention_mask[:, :, None] == 0, 0.)
+        model_output = torch.sum(model_output, dim=1) / torch.clamp(torch.sum(attention_mask, dim=1), min=1e-9)[:, None]
+        return model_output.flatten()
+
+    def encode(self, query: str):
+        if self.has_model:
+            if self.prefix:
+                query = f'{self.prefix} {query}'
+            inputs = self.tokenizer(query, return_tensors='pt', max_length=40, padding="max_length")
+            inputs.to(self.device)
+            outputs = self.model(input_ids=inputs["input_ids"],
+                                 attention_mask=inputs["attention_mask"])
+            embeddings = self._mean_pooling(outputs, inputs['attention_mask']).detach().cpu().numpy()
+            return embeddings.flatten()
+        else:
+            return super().encode(query)

diff --git a/pyserini/search/faiss/__init__.py b/pyserini/search/faiss/__init__.py
index 5cc156131..ad9df6f75 100644
--- a/pyserini/search/faiss/__init__.py
+++ b/pyserini/search/faiss/__init__.py
@@ -14,4 +14,5 @@
 # limitations under the License.
 #
 
-from ._searcher import FaissSearcher, BinaryDenseSearcher, DenseSearchResult
+from ._searcher import FaissSearcher, BinaryDenseFaissSearcher, DenseSearchResult
+from ._prf import DenseVectorAveragePrf, DenseVectorRocchioPrf, DenseVectorAncePrf, PrfDenseSearchResult

diff --git a/pyserini/search/faiss/__main__.py b/pyserini/search/faiss/__main__.py
index 34f9f6133..13245b958 100644
--- a/pyserini/search/faiss/__main__.py
+++ b/pyserini/search/faiss/__main__.py
@@ -21,21 +21,15 @@
 from tqdm import tqdm
 
 from pyserini.encode import QueryEncoder
-
-from pyserini.encode import AnceQueryEncoder
-from pyserini.encode import BprQueryEncoder
-from pyserini.encode import CosDprQueryEncoder
-from pyserini.encode import DprQueryEncoder
-from pyserini.encode import TctColBertQueryEncoder
-
+from pyserini.encode import AggretrieverQueryEncoder, AnceQueryEncoder, BprQueryEncoder, CosDprQueryEncoder, \
+    DkrrDprQueryEncoder, DprQueryEncoder, TctColBertQueryEncoder
 from pyserini.encode._pca import PcaEncoder
 from pyserini.output_writer import get_output_writer, OutputFormat
 from pyserini.query_iterator import get_query_iterator, TopicsFormat
-from pyserini.search.faiss._searcher import (AutoQueryEncoder, AggretrieverQueryEncoder, OpenAIQueryEncoder,
-                                             DkrrDprQueryEncoder, ClipQueryEncoder)
+from pyserini.search.faiss import FaissSearcher, BinaryDenseFaissSearcher
+from pyserini.search.faiss import DenseVectorAveragePrf, DenseVectorRocchioPrf, DenseVectorAncePrf
+from pyserini.search.faiss._searcher import AutoQueryEncoder, OpenAIQueryEncoder, ClipQueryEncoder
 from pyserini.search.lucene import LuceneSearcher
-from ._prf import DenseVectorAveragePrf, DenseVectorRocchioPrf, DenseVectorAncePrf
-from pyserini.search.faiss import FaissSearcher, BinaryDenseSearcher
 
 # Fixes this error: "OMP: Error #15: Initializing libomp.a, but found libomp.dylib already initialized."
 # https://stackoverflow.com/questions/53014306/error-15-initializing-libiomp5-dylib-but-found-libiomp5-dylib-already-initial
@@ -215,14 +209,14 @@ def init_query_encoder(encoder, encoder_class, tokenizer_name, topics_name, enco
         # create searcher from index directory
         if args.searcher.lower() == 'bpr':
             kwargs = dict(binary_k=args.binary_hits, rerank=args.rerank)
-            searcher = BinaryDenseSearcher(args.index, query_encoder)
+            searcher = BinaryDenseFaissSearcher(args.index, query_encoder)
         else:
            searcher = FaissSearcher(args.index, query_encoder)
     else:
         # create searcher from prebuilt index name
         if args.searcher.lower() == 'bpr':
             kwargs = dict(binary_k=args.binary_hits, rerank=args.rerank)
-            searcher = BinaryDenseSearcher.from_prebuilt_index(args.index, query_encoder)
+            searcher = BinaryDenseFaissSearcher.from_prebuilt_index(args.index, query_encoder)
         else:
             searcher = FaissSearcher.from_prebuilt_index(args.index, query_encoder)

diff --git a/pyserini/search/faiss/_prf.py b/pyserini/search/faiss/_prf.py
index 3efbd48f7..976c4aa71 100644
--- a/pyserini/search/faiss/_prf.py
+++ b/pyserini/search/faiss/_prf.py
@@ -25,7 +25,7 @@
 @dataclass
-class PRFDenseSearchResult:
+class PrfDenseSearchResult:
     docid: str
     score: float
     vectors: [float]
@@ -43,15 +43,15 @@ def get_batch_prf_q_emb(self, **kwargs):
 
 
 class DenseVectorAveragePrf(DenseVectorPrf):
-    def get_prf_q_emb(self, emb_qs: np.ndarray = None, prf_candidates: List[PRFDenseSearchResult] = None):
+    def get_prf_q_emb(self, emb_qs: np.ndarray = None, prf_candidates: List[PrfDenseSearchResult] = None):
        """Perform Average PRF with Dense Vectors

        Parameters
        ----------
        emb_qs : np.ndarray
            Query embedding
-        prf_candidates : List[PRFDenseSearchResult]
-            List of PRFDenseSearchResult, contains document embeddings.
+        prf_candidates : List[PrfDenseSearchResult]
+            List of PrfDenseSearchResult, contains document embeddings.

        Returns
        -------
@@ -63,7 +63,7 @@ def get_prf_q_emb(self, emb_qs: np.ndarray = None, prf_candidates: List[PRFDense
         return new_emb_qs
 
     def get_batch_prf_q_emb(self, topic_ids: List[str] = None, emb_qs: np.ndarray = None,
-                            prf_candidates: Dict[str, List[PRFDenseSearchResult]] = None):
+                            prf_candidates: Dict[str, List[PrfDenseSearchResult]] = None):
        """Perform Average PRF with Dense Vectors

        Parameters
@@ -72,8 +72,8 @@ def get_batch_prf_q_emb(self, topic_ids: List[str] = None, emb_qs: np.ndarray =
            List of topic ids.
        emb_qs : np.ndarray
            Query embeddings
-        prf_candidates : List[PRFDenseSearchResult]
-            List of PRFDenseSearchResult, contains document embeddings.
+        prf_candidates : List[PrfDenseSearchResult]
+            List of PrfDenseSearchResult, contains document embeddings.

        Returns
        -------
@@ -113,15 +113,15 @@ def __init__(self, alpha: float, beta: float, gamma: float, topk: int, bottomk:
         self.topk = topk
         self.bottomk = bottomk
 
-    def get_prf_q_emb(self, emb_qs: np.ndarray = None, prf_candidates: List[PRFDenseSearchResult] = None):
+    def get_prf_q_emb(self, emb_qs: np.ndarray = None, prf_candidates: List[PrfDenseSearchResult] = None):
        """Perform Rocchio PRF with Dense Vectors

        Parameters
        ----------
        emb_qs : np.ndarray
            query embedding
-        prf_candidates : List[PRFDenseSearchResult]
-            List of PRFDenseSearchResult, contains document embeddings.
+        prf_candidates : List[PrfDenseSearchResult]
+            List of PrfDenseSearchResult, contains document embeddings.

        Returns
        -------
@@ -139,7 +139,7 @@ def get_prf_q_emb(self, emb_qs: np.ndarray = None, prf_candidates: List[PRFDense
         return new_emb_q
 
     def get_batch_prf_q_emb(self, topic_ids: List[str] = None, emb_qs: np.ndarray = None,
-                            prf_candidates: Dict[str, List[PRFDenseSearchResult]] = None):
+                            prf_candidates: Dict[str, List[PrfDenseSearchResult]] = None):
        """Perform Rocchio PRF with Dense Vectors

        Parameters
@@ -148,8 +148,8 @@ def get_batch_prf_q_emb(self, topic_ids: List[str] = None, emb_qs: np.ndarray =
            List of topic ids.
        emb_qs : np.ndarray
            Query embeddings
-        prf_candidates : List[PRFDenseSearchResult]
-            List of PRFDenseSearchResult, contains document embeddings.
+        prf_candidates : List[PrfDenseSearchResult]
+            List of PrfDenseSearchResult, contains document embeddings.

        Returns
        -------
@@ -179,15 +179,15 @@ def __init__(self, encoder: AnceQueryEncoder, sparse_searcher: LuceneSearcher):
         self.encoder = encoder
         self.sparse_searcher = sparse_searcher
 
-    def get_prf_q_emb(self, query: str = None, prf_candidates: List[PRFDenseSearchResult] = None):
+    def get_prf_q_emb(self, query: str = None, prf_candidates: List[PrfDenseSearchResult] = None):
        """Perform single ANCE-PRF with Dense Vectors

        Parameters
        ----------
        query : str
            query text
-        prf_candidates : List[PRFDenseSearchResult]
-            List of PRFDenseSearchResult, contains document embeddings.
+        prf_candidates : List[PrfDenseSearchResult]
+            List of PrfDenseSearchResult, contains document embeddings.

        Returns
        -------
@@ -204,7 +204,7 @@ def get_prf_q_emb(self, query: str = None, prf_candidates: List[PRFDenseSearchRe
         return emb_q
 
     def get_batch_prf_q_emb(self, topics: List[str], topic_ids: List[str],
-                            prf_candidates: Dict[str, List[PRFDenseSearchResult]]) -> np.ndarray:
+                            prf_candidates: Dict[str, List[PrfDenseSearchResult]]) -> np.ndarray:
        """Perform batch ANCE-PRF with Dense Vectors

        Parameters
@@ -213,8 +213,8 @@ def get_batch_prf_q_emb(self, topics: List[str], topic_ids: List[str],
            List of query texts.
        topic_ids: List[str]
            List of topic ids.
-        prf_candidates : List[PRFDenseSearchResult]
-            List of PRFDenseSearchResult, contains document embeddings.
+        prf_candidates : List[PrfDenseSearchResult]
+            List of PrfDenseSearchResult, contains document embeddings.

        Returns
        -------

diff --git a/pyserini/search/faiss/_searcher.py b/pyserini/search/faiss/_searcher.py
index 2dc5f932b..a0348a6a8 100644
--- a/pyserini/search/faiss/_searcher.py
+++ b/pyserini/search/faiss/_searcher.py
@@ -23,32 +23,21 @@
 from dataclasses import dataclass
 from typing import Dict, List, Union, Optional, Tuple
 
+import faiss
 import numpy as np
 import openai
-import pandas as pd
 import tiktoken
 import torch
-from transformers import (AutoModel, AutoTokenizer, BertModel, BertTokenizer, BertTokenizerFast,
-                          DPRQuestionEncoder, DPRQuestionEncoderTokenizer, RobertaTokenizer)
-from transformers.file_utils import is_faiss_available, requires_backends
+from transformers import AutoModel, AutoTokenizer
+from transformers.file_utils import requires_backends
 
 from pyserini.encode import QueryEncoder
-from pyserini.encode import AnceQueryEncoder
-from pyserini.encode import BprQueryEncoder
-from pyserini.encode import DprQueryEncoder
-from pyserini.encode import TctColBertQueryEncoder
-
+from pyserini.encode import AnceQueryEncoder, BprQueryEncoder, DprQueryEncoder, TctColBertQueryEncoder
 from pyserini.encode._clip import ClipEncoder
 from pyserini.index import Document
 from pyserini.search.lucene import LuceneSearcher
-from pyserini.util import (download_encoded_queries, download_prebuilt_index,
-                           get_dense_indexes_info, get_sparse_index)
-from ...encode._aggretriever import BERTAggretrieverEncoder, DistlBERTAggretrieverEncoder
-
-from ._prf import DenseVectorAveragePrf, DenseVectorRocchioPrf, DenseVectorAncePrf, PRFDenseSearchResult
-
-#if is_faiss_available():
-import faiss
+from pyserini.util import download_prebuilt_index, get_dense_indexes_info, get_sparse_index
+from pyserini.search.faiss import PrfDenseSearchResult
@@ -75,71 +64,6 @@ def encode(self, query: str):
         return self.encoder.encode(query).flatten()
 
 
-class AggretrieverQueryEncoder(QueryEncoder):
-    def __init__(self, encoder_dir: str = None, tokenizer_name: str = None,
-                 encoded_query_dir: str = None, device: str = 'cpu', **kwargs):
-        if encoder_dir:
-            self.device = device
-            if 'distilbert' in encoder_dir.lower():
-                self.model = DistlBERTAggretrieverEncoder.from_pretrained(encoder_dir)
-            else:
-                self.model = BERTAggretrieverEncoder.from_pretrained(encoder_dir)
-            self.model.to(self.device)
-            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or encoder_dir)
-            self.has_model = True
-        if (not self.has_model) and (not self.has_encoded_query):
-            raise Exception('Neither query encoder model nor encoded queries provided. Please provide at least one')
-
-    def encode(self, query: str, max_length: int=32):
-        if self.has_model:
-            inputs = self.tokenizer(
-                query,
-                max_length=max_length,
-                padding="longest",
-                truncation=True,
-                add_special_tokens=True,
-                return_tensors='pt'
-            )
-            inputs.to(self.device)
-            outputs = self.model(**inputs)
-            embeddings = outputs.detach().cpu().numpy()
-            return embeddings.flatten()
-        else:
-            return super().encode(query)
-
-
-class DkrrDprQueryEncoder(QueryEncoder):
-
-    def __init__(self, encoder_dir: str = None, encoded_query_dir: str = None, device: str = 'cpu',
-                 prefix: str = "question:", **kwargs):
-        super().__init__(encoded_query_dir)
-        self.device = device
-        self.model = BertModel.from_pretrained(encoder_dir)
-        self.model.to(self.device)
-        self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
-        self.has_model = True
-        self.prefix = prefix
-
-    @staticmethod
-    def _mean_pooling(model_output, attention_mask):
-        model_output = model_output[0].masked_fill(attention_mask[:, :, None] == 0, 0.)
-        model_output = torch.sum(model_output, dim=1) / torch.clamp(torch.sum(attention_mask, dim=1), min=1e-9)[:, None]
-        return model_output.flatten()
-
-    def encode(self, query: str):
-        if self.has_model:
-            if self.prefix:
-                query = f'{self.prefix} {query}'
-            inputs = self.tokenizer(query, return_tensors='pt', max_length=40, padding="max_length")
-            inputs.to(self.device)
-            outputs = self.model(input_ids=inputs["input_ids"],
-                                 attention_mask=inputs["attention_mask"])
-            embeddings = self._mean_pooling(outputs, inputs['attention_mask']).detach().cpu().numpy()
-            return embeddings.flatten()
-        else:
-            return super().encode(query)
-
-
 class OpenAIQueryEncoder(QueryEncoder):
     from pyserini.encode._openai import retry_with_delay
@@ -290,7 +214,7 @@ def list_prebuilt_indexes():
         get_dense_indexes_info()
 
     def search(self, query: Union[str, np.ndarray], k: int = 10, threads: int = 1, remove_dups: bool = False, return_vector: bool = False) \
-            -> Union[List[DenseSearchResult], Tuple[np.ndarray, List[PRFDenseSearchResult]]]:
+            -> Union[List[DenseSearchResult], Tuple[np.ndarray, List[PrfDenseSearchResult]]]:
        """Search the collection.

        Parameters
@@ -323,7 +247,7 @@ def search(self, query: Union[str, np.ndarray], k: int = 10, threads: int = 1, r
             vectors = vectors[0]
             distances = distances.flat
             indexes = indexes.flat
-            return emb_q, [PRFDenseSearchResult(self.docids[idx], score, vector)
+            return emb_q, [PrfDenseSearchResult(self.docids[idx], score, vector)
                            for score, idx, vector in zip(distances, indexes, vectors) if idx != -1]
         else:
             distances, indexes = self.index.search(emb_q, k)
@@ -342,7 +266,7 @@ def search(self, query: Union[str, np.ndarray], k: int = 10, threads: int = 1, r
 
     def batch_search(self, queries: Union[List[str], np.ndarray], q_ids: List[str], k: int = 10,
                      threads: int = 1, return_vector: bool = False) \
-            -> Union[Dict[str, List[DenseSearchResult]], Tuple[np.ndarray, Dict[str, List[PRFDenseSearchResult]]]]:
+            -> Union[Dict[str, List[DenseSearchResult]], Tuple[np.ndarray, Dict[str, List[PrfDenseSearchResult]]]]:
        """

        Parameters
@@ -374,7 +298,7 @@ def batch_search(self, queries: Union[List[str], np.ndarray], q_ids: List[str],
         faiss.omp_set_num_threads(threads)
         if return_vector:
             D, I, V = self.index.search_and_reconstruct(q_embs, k)
-            return q_embs, {key: [PRFDenseSearchResult(self.docids[idx], score, vector)
+            return q_embs, {key: [PrfDenseSearchResult(self.docids[idx], score, vector)
                                   for score, idx, vector in zip(distances, indexes, vectors) if idx != -1]
                             for key, distances, indexes, vectors in zip(q_ids, D, I, V)}
         else:
@@ -433,7 +357,7 @@ def set_hnsw_ef_search(self, ef_search: int):
         self.index.hnsw.efSearch = ef_search
 
 
-class BinaryDenseSearcher(FaissSearcher):
+class BinaryDenseFaissSearcher(FaissSearcher):
     """Simple Searcher for binary-dense representation

    Parameters

From 0f6c3a403c2befce2fb9195a3acd63b3f817f4a8 Mon Sep 17 00:00:00 2001
From: lintool
Date: Tue, 8 Oct 2024 12:28:24 -0400
Subject: [PATCH 3/6] Reconciled auto.
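AutoQueryEncoder now follows the same convention as the other query encoders:
build it from a model checkpoint (encoder_dir) or fall back to pre-encoded
queries (encoded_query_dir). A sketch under that assumption (the checkpoint
name below is illustrative):

    from pyserini.encode import AutoQueryEncoder

    # From a Hugging Face checkpoint; mean pooling with L2 normalization.
    encoder = AutoQueryEncoder(encoder_dir='sentence-transformers/msmarco-distilbert-base-v3',
                               pooling='mean', l2_norm=True)
    emb = encoder.encode('what is a lobster roll')  # flat numpy vector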
--- pyserini/encode/_auto.py | 68 ++++++++++++++++++------------ pyserini/search/faiss/__main__.py | 8 ++-- pyserini/search/faiss/_searcher.py | 57 +------------------------ 3 files changed, 46 insertions(+), 87 deletions(-) diff --git a/pyserini/encode/_auto.py b/pyserini/encode/_auto.py index 7ee256e5a..3d2e24251 100644 --- a/pyserini/encode/_auto.py +++ b/pyserini/encode/_auto.py @@ -70,33 +70,45 @@ def encode(self, texts, titles=None, max_length=256, add_sep=False, **kwargs): class AutoQueryEncoder(QueryEncoder): - def __init__(self, encoder_dir: str, tokenizer_name: str = None, device: str = 'cpu', - pooling: str = 'cls', l2_norm: bool = False, prefix=None): - self.device = device - self.model = AutoModel.from_pretrained(encoder_dir) - self.model.to(self.device) - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or encoder_dir) - self.pooling = pooling - self.l2_norm = l2_norm - self.prefix = prefix + def __init__(self, encoder_dir: str = None, tokenizer_name: str = None, + encoded_query_dir: str = None, device: str = 'cpu', + pooling: str = 'cls', l2_norm: bool = False, prefix=None, **kwargs): + super().__init__(encoded_query_dir) + if encoder_dir: + self.device = device + self.model = AutoModel.from_pretrained(encoder_dir) + self.model.to(self.device) + try: + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or encoder_dir) + except: + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or encoder_dir, use_fast=False) + self.has_model = True + self.pooling = pooling + self.l2_norm = l2_norm + self.prefix = prefix + if (not self.has_model) and (not self.has_encoded_query): + raise Exception('Neither query encoder model nor encoded queries provided. Please provide at least one') - def encode(self, query: str, **kwargs): - if self.prefix: - query = f'{self.prefix} {query}' - inputs = self.tokenizer( - query, - add_special_tokens=True, - return_tensors='pt', - truncation='only_first', - padding='longest', - return_token_type_ids=False, - ) - inputs.to(self.device) - outputs = self.model(**inputs)[0].detach().cpu().numpy() - if self.pooling == "mean": - embeddings = np.average(outputs, axis=-2) + def encode(self, query: str): + if self.has_model: + if self.prefix: + query = f'{self.prefix} {query}' + inputs = self.tokenizer( + query, + add_special_tokens=True, + return_tensors='pt', + truncation='only_first', + padding='longest', + return_token_type_ids=False, + ) + inputs.to(self.device) + outputs = self.model(**inputs)[0].detach().cpu().numpy() + if self.pooling == "mean": + embeddings = np.average(outputs, axis=-2) + else: + embeddings = outputs[:, 0, :] + if self.l2_norm: + embeddings = normalize(embeddings, norm='l2') + return embeddings.flatten() else: - embeddings = outputs[:, 0, :] - if self.l2_norm: - embeddings = normalize(embeddings, norm='l2') - return embeddings.flatten() + return super().encode(query) diff --git a/pyserini/search/faiss/__main__.py b/pyserini/search/faiss/__main__.py index 13245b958..2798977f6 100644 --- a/pyserini/search/faiss/__main__.py +++ b/pyserini/search/faiss/__main__.py @@ -20,16 +20,16 @@ import numpy as np from tqdm import tqdm -from pyserini.encode import QueryEncoder +from pyserini.encode import QueryEncoder, AutoQueryEncoder from pyserini.encode import AggretrieverQueryEncoder, AnceQueryEncoder, BprQueryEncoder, CosDprQueryEncoder, \ DkrrDprQueryEncoder, DprQueryEncoder, TctColBertQueryEncoder from pyserini.encode._pca import PcaEncoder from pyserini.output_writer import get_output_writer, OutputFormat from 
pyserini.query_iterator import get_query_iterator, TopicsFormat -from pyserini.search.faiss import FaissSearcher, BinaryDenseFaissSearcher -from pyserini.search.faiss import DenseVectorAveragePrf, DenseVectorRocchioPrf, DenseVectorAncePrf -from pyserini.search.faiss._searcher import AutoQueryEncoder, OpenAIQueryEncoder, ClipQueryEncoder from pyserini.search.lucene import LuceneSearcher +from ._prf import DenseVectorAveragePrf, DenseVectorRocchioPrf, DenseVectorAncePrf +from ._searcher import OpenAIQueryEncoder, ClipQueryEncoder +from ._searcher import FaissSearcher, BinaryDenseFaissSearcher # Fixes this error: "OMP: Error #15: Initializing libomp.a, but found libomp.dylib already initialized." # https://stackoverflow.com/questions/53014306/error-15-initializing-libiomp5-dylib-but-found-libiomp5-dylib-already-initial diff --git a/pyserini/search/faiss/_searcher.py b/pyserini/search/faiss/_searcher.py index a0348a6a8..8f4918d84 100644 --- a/pyserini/search/faiss/_searcher.py +++ b/pyserini/search/faiss/_searcher.py @@ -27,17 +27,15 @@ import numpy as np import openai import tiktoken -import torch -from transformers import AutoModel, AutoTokenizer from transformers.file_utils import requires_backends -from pyserini.encode import QueryEncoder +from pyserini.encode import QueryEncoder, AutoQueryEncoder from pyserini.encode import AnceQueryEncoder, BprQueryEncoder, DprQueryEncoder, TctColBertQueryEncoder from pyserini.encode._clip import ClipEncoder from pyserini.index import Document +from pyserini.search.faiss._prf import PrfDenseSearchResult from pyserini.search.lucene import LuceneSearcher from pyserini.util import download_prebuilt_index, get_dense_indexes_info, get_sparse_index -from pyserini.search.faiss import PrfDenseSearchResult class ClipQueryEncoder(QueryEncoder): @@ -92,57 +90,6 @@ def encode(self, query: str, **kwargs): else: return super().encode(query) -class AutoQueryEncoder(QueryEncoder): - - def __init__(self, encoder_dir: str = None, tokenizer_name: str = None, - encoded_query_dir: str = None, device: str = 'cpu', - pooling: str = 'cls', l2_norm: bool = False, **kwargs): - super().__init__(encoded_query_dir) - if encoder_dir: - self.device = device - self.model = AutoModel.from_pretrained(encoder_dir) - self.model.to(self.device) - try: - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or encoder_dir) - except: - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or encoder_dir, use_fast=False) - self.has_model = True - self.pooling = pooling - self.l2_norm = l2_norm - if (not self.has_model) and (not self.has_encoded_query): - raise Exception('Neither query encoder model nor encoded queries provided. 
Please provide at least one') - - @staticmethod - def _mean_pooling(model_output, attention_mask): - token_embeddings = model_output[0] # First element of model_output contains all token embeddings - input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() - sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) - sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9) - return sum_embeddings / sum_mask - - def encode(self, query: str): - if self.has_model: - inputs = self.tokenizer( - query, - add_special_tokens=True, - return_tensors='pt', - truncation='only_first', - padding='longest', - return_token_type_ids=False, - ) - - inputs.to(self.device) - outputs = self.model(**inputs) - if self.pooling == "mean": - embeddings = self._mean_pooling(outputs, inputs['attention_mask']).detach().cpu().numpy() - else: - embeddings = outputs[0][:, 0, :].detach().cpu().numpy() - if self.l2_norm: - faiss.normalize_L2(embeddings) - return embeddings.flatten() - else: - return super().encode(query) - @dataclass class DenseSearchResult: From ee1faab6933c853c17b331dc0f00f7a3fb987551 Mon Sep 17 00:00:00 2001 From: lintool Date: Tue, 8 Oct 2024 20:38:34 -0400 Subject: [PATCH 4/6] Refactoring sync. --- .../sparse/test_lucenesearcher_check_ltr_msmarco_document.py | 2 +- .../sparse/test_lucenesearcher_check_ltr_msmarco_passage.py | 4 +--- integrations/clprf/test_trec_covid_r5.py | 5 +++-- integrations/lucenesearcher_anserini_checker.py | 2 +- integrations/sparse/test_lucenesearcher_check_irst.py | 2 +- integrations/sparse/test_search_pretokenized.py | 2 +- integrations/utils.py | 2 +- pyserini/analysis/_base.py | 1 - pyserini/demo/dpr.py | 3 ++- pyserini/demo/msmarco.py | 3 ++- 10 files changed, 13 insertions(+), 13 deletions(-) diff --git a/integrations-optional/sparse/test_lucenesearcher_check_ltr_msmarco_document.py b/integrations-optional/sparse/test_lucenesearcher_check_ltr_msmarco_document.py index 73da5b4af..ec42a5b06 100644 --- a/integrations-optional/sparse/test_lucenesearcher_check_ltr_msmarco_document.py +++ b/integrations-optional/sparse/test_lucenesearcher_check_ltr_msmarco_document.py @@ -23,7 +23,7 @@ class TestLtrMsmarcoDocument(unittest.TestCase): def test_reranking(self): - if(os.path.isdir('ltr_test')): + if os.path.isdir('ltr_test'): rmtree('ltr_test') os.mkdir('ltr_test') inp = 'run.msmarco-pass-doc.bm25.txt' diff --git a/integrations-optional/sparse/test_lucenesearcher_check_ltr_msmarco_passage.py b/integrations-optional/sparse/test_lucenesearcher_check_ltr_msmarco_passage.py index def199306..84fc51cac 100644 --- a/integrations-optional/sparse/test_lucenesearcher_check_ltr_msmarco_passage.py +++ b/integrations-optional/sparse/test_lucenesearcher_check_ltr_msmarco_passage.py @@ -20,12 +20,10 @@ import unittest from shutil import rmtree -from pyserini.search.lucene import LuceneSearcher - class TestLtrMsmarcoPassage(unittest.TestCase): def test_reranking(self): - if(os.path.isdir('ltr_test')): + if os.path.isdir('ltr_test'): rmtree('ltr_test') os.mkdir('ltr_test') inp = 'run.msmarco-passage.bm25tuned.txt' diff --git a/integrations/clprf/test_trec_covid_r5.py b/integrations/clprf/test_trec_covid_r5.py index 92c1332bd..d1accd759 100644 --- a/integrations/clprf/test_trec_covid_r5.py +++ b/integrations/clprf/test_trec_covid_r5.py @@ -14,13 +14,14 @@ # limitations under the License. 
# +import gzip +import json import os import re import shutil import unittest -import json -import gzip from random import randint + from pyserini.util import download_url, download_prebuilt_index diff --git a/integrations/lucenesearcher_anserini_checker.py b/integrations/lucenesearcher_anserini_checker.py index 159fa5bb4..439527eca 100644 --- a/integrations/lucenesearcher_anserini_checker.py +++ b/integrations/lucenesearcher_anserini_checker.py @@ -18,8 +18,8 @@ import os from typing import List -from pyserini.util import get_cache_home from pyserini.prebuilt_index_info import TF_INDEX_INFO +from pyserini.util import get_cache_home class LuceneSearcherAnseriniMatchChecker: diff --git a/integrations/sparse/test_lucenesearcher_check_irst.py b/integrations/sparse/test_lucenesearcher_check_irst.py index cd2d9bdd0..a1eecb829 100644 --- a/integrations/sparse/test_lucenesearcher_check_irst.py +++ b/integrations/sparse/test_lucenesearcher_check_irst.py @@ -16,8 +16,8 @@ import os import unittest -from shutil import rmtree from random import randint +from shutil import rmtree from integrations.utils import run_command, parse_score diff --git a/integrations/sparse/test_search_pretokenized.py b/integrations/sparse/test_search_pretokenized.py index 3c2dafa34..757a431b1 100644 --- a/integrations/sparse/test_search_pretokenized.py +++ b/integrations/sparse/test_search_pretokenized.py @@ -17,8 +17,8 @@ import os import shutil import unittest - from random import randint + from integrations.lucenesearcher_score_checker import LuceneSearcherScoreChecker diff --git a/integrations/utils.py b/integrations/utils.py index de88ee7c0..90cc0850c 100644 --- a/integrations/utils.py +++ b/integrations/utils.py @@ -15,8 +15,8 @@ # import os -import subprocess import shutil +import subprocess def clean_files(files): diff --git a/pyserini/analysis/_base.py b/pyserini/analysis/_base.py index 2c7b4d1e6..2b4a42b32 100644 --- a/pyserini/analysis/_base.py +++ b/pyserini/analysis/_base.py @@ -47,7 +47,6 @@ # Wrappers around Anserini classes JAnalyzerUtils = autoclass('io.anserini.analysis.AnalyzerUtils') -JDefaultEnglishAnalyzer = autoclass('io.anserini.analysis.DefaultEnglishAnalyzer') JTweetAnalyzer = autoclass('io.anserini.analysis.TweetAnalyzer') JHuggingFaceTokenizerAnalyzer = autoclass('io.anserini.analysis.HuggingFaceTokenizerAnalyzer') diff --git a/pyserini/demo/dpr.py b/pyserini/demo/dpr.py index e008fffc1..9a607c7c0 100644 --- a/pyserini/demo/dpr.py +++ b/pyserini/demo/dpr.py @@ -18,8 +18,9 @@ import json import random +from pyserini.encode import DprQueryEncoder from pyserini.search import get_topics -from pyserini.search.faiss import FaissSearcher, DprQueryEncoder +from pyserini.search.faiss import FaissSearcher from pyserini.search.hybrid import HybridSearcher from pyserini.search.lucene import LuceneSearcher diff --git a/pyserini/demo/msmarco.py b/pyserini/demo/msmarco.py index 4482da97a..ead370a9b 100644 --- a/pyserini/demo/msmarco.py +++ b/pyserini/demo/msmarco.py @@ -18,8 +18,9 @@ import json import random +from pyserini.encode import AnceQueryEncoder, TctColBertQueryEncoder from pyserini.search import get_topics -from pyserini.search.faiss import FaissSearcher, TctColBertQueryEncoder, AnceQueryEncoder +from pyserini.search.faiss import FaissSearcher from pyserini.search.hybrid import HybridSearcher from pyserini.search.lucene import LuceneSearcher From dacbd0b31a5b6c5fac40b7950c09190b84d5819c Mon Sep 17 00:00:00 2001 From: lintool Date: Tue, 8 Oct 2024 20:38:10 -0700 Subject: [PATCH 5/6] Refactored tests. 
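
tests/test_encoder.py overlapped with tests-optional/test_encoder.py (see the
"de-dup" comment removed below), and tests-optional/test_load_encoded_queries.py
bundled all of the cached-query checks into one file. The encoder tests are now
split into one file per model under tests/, the duplicated document-encoder
checks are dropped from tests-optional/test_encoder.py, and each cached-query
check moves into the matching per-model file. Every new file follows the same
skeleton: load the prebuilt encoded queries for a topic set and verify that
each topic title has a cached embedding. A minimal sketch of that shared
skeleton (the 'some_model-*' name is a placeholder, not a real set of prebuilt
encoded queries):

    import unittest

    from pyserini.encode import QueryEncoder
    from pyserini.search import get_topics


    class TestEncodeSomeModel(unittest.TestCase):
        def test_encoded_queries(self):
            # Each per-model test file plugs in its own prebuilt
            # encoded-query name and the matching topic set.
            encoded = QueryEncoder.load_encoded_queries('some_model-msmarco-passage-dev-subset')
            topics = get_topics('msmarco-passage-dev-subset')
            for t in topics:
                # Every topic title should have a cached embedding.
                self.assertTrue(topics[t]['title'] in encoded.embedding)


    if __name__ == '__main__':
        unittest.main()
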
--- tests-optional/test_encoder.py | 27 ----- tests-optional/test_load_encoded_queries.py | 106 ------------------ tests/test_encoder.py | 114 -------------------- tests/test_encoder_model_ance.py | 42 ++++++++ tests/test_encoder_model_distilbert_kd.py | 42 ++++++++ tests/test_encoder_model_distilbert_tasb.py | 42 ++++++++ tests/test_encoder_model_dpr.py | 40 +++++++ tests/test_encoder_model_sbert.py | 32 ++++++ tests/test_encoder_model_tct.py | 66 ++++++++++++ tests/test_encoder_model_unicoil.py | 85 +++++++++++++++ 10 files changed, 349 insertions(+), 247 deletions(-) delete mode 100644 tests-optional/test_load_encoded_queries.py delete mode 100644 tests/test_encoder.py create mode 100644 tests/test_encoder_model_ance.py create mode 100644 tests/test_encoder_model_distilbert_kd.py create mode 100644 tests/test_encoder_model_distilbert_tasb.py create mode 100644 tests/test_encoder_model_dpr.py create mode 100644 tests/test_encoder_model_sbert.py create mode 100644 tests/test_encoder_model_tct.py create mode 100644 tests/test_encoder_model_unicoil.py diff --git a/tests-optional/test_encoder.py b/tests-optional/test_encoder.py index b24b1477b..8907db941 100644 --- a/tests-optional/test_encoder.py +++ b/tests-optional/test_encoder.py @@ -25,13 +25,10 @@ import faiss -from pyserini.encode import TctColBertDocumentEncoder, DprDocumentEncoder, UniCoilDocumentEncoder from pyserini.encode._clip import ClipDocumentEncoder from pyserini.search.lucene import LuceneImpactSearcher -## We need to de-dup wrt tests/test_encoder - class TestEncode(unittest.TestCase): @classmethod def setUpClass(cls): @@ -62,30 +59,6 @@ def assertIsFile(path): if not pl.Path(path).resolve().is_file(): raise AssertionError("File does not exist: %s" % str(path)) - def test_dpr_encoder(self): - encoder = DprDocumentEncoder('facebook/dpr-ctx_encoder-multiset-base', device='cpu') - vectors = encoder.encode(self.texts[:3]) - self.assertAlmostEqual(vectors[0][0], -0.59793323, places=4) - self.assertAlmostEqual(vectors[0][-1], -0.13036962, places=4) - self.assertAlmostEqual(vectors[2][0], -0.3044764, places=4) - self.assertAlmostEqual(vectors[2][-1], 0.1516793, places=4) - - def test_tct_colbert_encoder(self): - encoder = TctColBertDocumentEncoder('castorini/tct_colbert-msmarco', device='cpu') - vectors = encoder.encode(self.texts[:3]) - self.assertAlmostEqual(vectors[0][0], -0.01649557, places=4) - self.assertAlmostEqual(vectors[0][-1], -0.05648308, places=4) - self.assertAlmostEqual(vectors[2][0], -0.10293338, places=4) - self.assertAlmostEqual(vectors[2][-1], 0.05549275, places=4) - - def test_unicoil_encoder(self): - encoder = UniCoilDocumentEncoder('castorini/unicoil-msmarco-passage', device='cpu') - vectors = encoder.encode(self.texts[:3]) - self.assertAlmostEqual(vectors[0]['generation'], 2.2441017627716064, places=4) - self.assertAlmostEqual(vectors[0]['normal'], 2.4618067741394043, places=4) - self.assertAlmostEqual(vectors[2]['rounding'], 3.9474332332611084, places=4) - self.assertAlmostEqual(vectors[2]['commercial'], 3.288801670074463, places=4) - def test_clip_encoder(self): encoder = ClipDocumentEncoder('openai/clip-vit-base-patch32', device='cpu') vectors = encoder.encode(self.texts[:3]) diff --git a/tests-optional/test_load_encoded_queries.py b/tests-optional/test_load_encoded_queries.py deleted file mode 100644 index baf49991a..000000000 --- a/tests-optional/test_load_encoded_queries.py +++ /dev/null @@ -1,106 +0,0 @@ -# -# Pyserini: Reproducible IR research with sparse and dense representations -# -# Licensed 
under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -"""Integration tests for DistilBERT KD.""" - -import unittest - -from pyserini.search import get_topics -from pyserini.encode import QueryEncoder - - -class TestLoadEncodedQueries(unittest.TestCase): - def test_ance_encoded_queries(self): - encoded = QueryEncoder.load_encoded_queries('ance-msmarco-passage-dev-subset') - topics = get_topics('msmarco-passage-dev-subset') - for t in topics: - self.assertTrue(topics[t]['title'] in encoded.embedding) - - encoded = QueryEncoder.load_encoded_queries('ance-dl19-passage') - topics = get_topics('dl19-passage') - for t in topics: - self.assertTrue(topics[t]['title'] in encoded.embedding) - - encoded = QueryEncoder.load_encoded_queries('ance-dl20') - topics = get_topics('dl20') - for t in topics: - self.assertTrue(topics[t]['title'] in encoded.embedding) - - def test_distilbert_kd_encoded_queries(self): - encoded = QueryEncoder.load_encoded_queries('distilbert_kd-msmarco-passage-dev-subset') - topics = get_topics('msmarco-passage-dev-subset') - for t in topics: - self.assertTrue(topics[t]['title'] in encoded.embedding) - - encoded = QueryEncoder.load_encoded_queries('distilbert_kd-dl19-passage') - topics = get_topics('dl19-passage') - for t in topics: - self.assertTrue(topics[t]['title'] in encoded.embedding) - - encoded = QueryEncoder.load_encoded_queries('distilbert_kd-dl20') - topics = get_topics('dl20') - for t in topics: - self.assertTrue(topics[t]['title'] in encoded.embedding) - - def test_distilbert_kd_tas_b_encoded_queries(self): - encoded = QueryEncoder.load_encoded_queries('distilbert_tas_b-msmarco-passage-dev-subset') - topics = get_topics('msmarco-passage-dev-subset') - for t in topics: - self.assertTrue(topics[t]['title'] in encoded.embedding) - - encoded = QueryEncoder.load_encoded_queries('distilbert_tas_b-dl19-passage') - topics = get_topics('dl19-passage') - for t in topics: - self.assertTrue(topics[t]['title'] in encoded.embedding) - - encoded = QueryEncoder.load_encoded_queries('distilbert_tas_b-dl20') - topics = get_topics('dl20') - for t in topics: - self.assertTrue(topics[t]['title'] in encoded.embedding) - - def test_msmarco_doc_tct_colbert_encoded_queries(self): - encoded = QueryEncoder.load_encoded_queries('tct_colbert-msmarco-doc-dev') - topics = get_topics('msmarco-doc-dev') - for t in topics: - self.assertTrue(topics[t]['title'] in encoded.embedding) - - def test_msmarco_passage_tct_colbert_v2_encoded_queries(self): - encoded = QueryEncoder.load_encoded_queries('tct_colbert-v2-msmarco-passage-dev-subset') - topics = get_topics('msmarco-passage-dev-subset') - for t in topics: - self.assertTrue(topics[t]['title'] in encoded.embedding) - - def test_msmarco_passage_tct_colbert_v2_hn_encoded_queries(self): - encoded = QueryEncoder.load_encoded_queries('tct_colbert-v2-hn-msmarco-passage-dev-subset') - topics = get_topics('msmarco-passage-dev-subset') - for t in topics: - self.assertTrue(topics[t]['title'] in encoded.embedding) - - def 
test_msmarco_passage_tct_colbert_v2_hnp_encoded_queries(self): - encoded = QueryEncoder.load_encoded_queries('tct_colbert-v2-hnp-msmarco-passage-dev-subset') - topics = get_topics('msmarco-passage-dev-subset') - for t in topics: - self.assertTrue(topics[t]['title'] in encoded.embedding) - - def test_msmarco_passage_sbert_encoded_queries(self): - encoded = QueryEncoder.load_encoded_queries('sbert-msmarco-passage-dev-subset') - topics = get_topics('msmarco-passage-dev-subset') - for t in topics: - self.assertTrue(topics[t]['title'] in encoded.embedding) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_encoder.py b/tests/test_encoder.py deleted file mode 100644 index 852d75b37..000000000 --- a/tests/test_encoder.py +++ /dev/null @@ -1,114 +0,0 @@ -# -# Pyserini: Reproducible IR research with sparse and dense representations -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import json -import os -import pathlib as pl -import shutil -import tarfile -import unittest -from random import randint -from urllib.request import urlretrieve - -from pyserini.encode import TctColBertDocumentEncoder, DprDocumentEncoder, UniCoilDocumentEncoder -from pyserini.search.lucene import LuceneImpactSearcher - - -class TestEncode(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.docids = [] - cls.texts = [] - cls.test_file = 'tests/resources/simple_cacm_corpus.json' - - with open(cls.test_file) as f: - for line in f: - line = json.loads(line) - cls.docids.append(line['id']) - cls.texts.append(line['contents']) - - # LuceneImpactSearcher requires a pre-built index to be initialized - r = randint(0, 10000000) - cls.collection_url = 'https://github.com/castorini/anserini-data/raw/master/CACM/lucene9-index.cacm.tar.gz' - cls.tarball_name = f'lucene-index.cacm-{r}.tar.gz' - cls.index_dir = f'index-{r}/' - - urlretrieve(cls.collection_url, cls.tarball_name) - - tarball = tarfile.open(cls.tarball_name) - tarball.extractall(cls.index_dir) - tarball.close() - - @staticmethod - def assertIsFile(path): - if not pl.Path(path).resolve().is_file(): - raise AssertionError("File does not exist: %s" % str(path)) - - def test_dpr_encoder(self): - encoder = DprDocumentEncoder('facebook/dpr-ctx_encoder-multiset-base', device='cpu') - vectors = encoder.encode(self.texts[:3]) - self.assertAlmostEqual(vectors[0][0], -0.59793323, places=4) - self.assertAlmostEqual(vectors[0][-1], -0.13036962, places=4) - self.assertAlmostEqual(vectors[2][0], -0.3044764, places=4) - self.assertAlmostEqual(vectors[2][-1], 0.1516793, places=4) - - def test_tct_colbert_encoder(self): - encoder = TctColBertDocumentEncoder('castorini/tct_colbert-msmarco', device='cpu') - vectors = encoder.encode(self.texts[:3]) - self.assertAlmostEqual(vectors[0][0], -0.01649557, places=4) - self.assertAlmostEqual(vectors[0][-1], -0.05648308, places=4) - self.assertAlmostEqual(vectors[2][0], -0.10293338, places=4) - self.assertAlmostEqual(vectors[2][-1], 0.05549275, places=4) - - def test_unicoil_encoder(self): - encoder = 
UniCoilDocumentEncoder('castorini/unicoil-msmarco-passage', device='cpu') - vectors = encoder.encode(self.texts[:3]) - self.assertAlmostEqual(vectors[0]['generation'], 2.2441017627716064, places=4) - self.assertAlmostEqual(vectors[0]['normal'], 2.4618067741394043, places=4) - self.assertAlmostEqual(vectors[2]['rounding'], 3.9474332332611084, places=4) - self.assertAlmostEqual(vectors[2]['commercial'], 3.288801670074463, places=4) - - def test_onnx_encode_unicoil(self): - temp_object = LuceneImpactSearcher(f'{self.index_dir}lucene9-index.cacm', 'SpladePlusPlusEnsembleDistil', encoder_type='onnx') - - # this function will never be called in _impact_searcher, here to check quantization correctness - results = temp_object.encode("here is a test") - self.assertEqual(results.get("here"), 156) - self.assertEqual(results.get("a"), 31) - self.assertEqual(results.get("test"), 149) - - temp_object.close() - del temp_object - - temp_object1 = LuceneImpactSearcher(f'{self.index_dir}lucene9-index.cacm', 'naver/splade-cocondenser-ensembledistil') - - # this function will never be called in _impact_searcher, here to check quantization correctness - results = temp_object1.encode("here is a test") - self.assertEqual(results.get("here"), 156) - self.assertEqual(results.get("a"), 31) - self.assertEqual(results.get("test"), 149) - - temp_object1.close() - del temp_object1 - - @classmethod - def tearDownClass(cls): - os.remove(cls.tarball_name) - shutil.rmtree(cls.index_dir) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_encoder_model_ance.py b/tests/test_encoder_model_ance.py new file mode 100644 index 000000000..1c1d65501 --- /dev/null +++ b/tests/test_encoder_model_ance.py @@ -0,0 +1,42 @@ +# +# Pyserini: Reproducible IR research with sparse and dense representations +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import unittest + +from pyserini.encode import QueryEncoder +from pyserini.search import get_topics + + +class TestEncodeAnce(unittest.TestCase): + def test_ance_encoded_queries(self): + encoded = QueryEncoder.load_encoded_queries('ance-msmarco-passage-dev-subset') + topics = get_topics('msmarco-passage-dev-subset') + for t in topics: + self.assertTrue(topics[t]['title'] in encoded.embedding) + + encoded = QueryEncoder.load_encoded_queries('ance-dl19-passage') + topics = get_topics('dl19-passage') + for t in topics: + self.assertTrue(topics[t]['title'] in encoded.embedding) + + encoded = QueryEncoder.load_encoded_queries('ance-dl20') + topics = get_topics('dl20') + for t in topics: + self.assertTrue(topics[t]['title'] in encoded.embedding) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_encoder_model_distilbert_kd.py b/tests/test_encoder_model_distilbert_kd.py new file mode 100644 index 000000000..bbddeacda --- /dev/null +++ b/tests/test_encoder_model_distilbert_kd.py @@ -0,0 +1,42 @@ +# +# Pyserini: Reproducible IR research with sparse and dense representations +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest + +from pyserini.encode import QueryEncoder +from pyserini.search import get_topics + + +class TestEncodeDistilBertKd(unittest.TestCase): + def test_distilbert_kd_encoded_queries(self): + encoded = QueryEncoder.load_encoded_queries('distilbert_kd-msmarco-passage-dev-subset') + topics = get_topics('msmarco-passage-dev-subset') + for t in topics: + self.assertTrue(topics[t]['title'] in encoded.embedding) + + encoded = QueryEncoder.load_encoded_queries('distilbert_kd-dl19-passage') + topics = get_topics('dl19-passage') + for t in topics: + self.assertTrue(topics[t]['title'] in encoded.embedding) + + encoded = QueryEncoder.load_encoded_queries('distilbert_kd-dl20') + topics = get_topics('dl20') + for t in topics: + self.assertTrue(topics[t]['title'] in encoded.embedding) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_encoder_model_distilbert_tasb.py b/tests/test_encoder_model_distilbert_tasb.py new file mode 100644 index 000000000..88b638363 --- /dev/null +++ b/tests/test_encoder_model_distilbert_tasb.py @@ -0,0 +1,42 @@ +# +# Pyserini: Reproducible IR research with sparse and dense representations +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import unittest + +from pyserini.encode import QueryEncoder +from pyserini.search import get_topics + + +class TestEncodeDistilBertTasB(unittest.TestCase): + def test_distilbert_tas_b_encoded_queries(self): + encoded = QueryEncoder.load_encoded_queries('distilbert_tas_b-msmarco-passage-dev-subset') + topics = get_topics('msmarco-passage-dev-subset') + for t in topics: + self.assertTrue(topics[t]['title'] in encoded.embedding) + + encoded = QueryEncoder.load_encoded_queries('distilbert_tas_b-dl19-passage') + topics = get_topics('dl19-passage') + for t in topics: + self.assertTrue(topics[t]['title'] in encoded.embedding) + + encoded = QueryEncoder.load_encoded_queries('distilbert_tas_b-dl20') + topics = get_topics('dl20') + for t in topics: + self.assertTrue(topics[t]['title'] in encoded.embedding) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_encoder_model_dpr.py b/tests/test_encoder_model_dpr.py new file mode 100644 index 000000000..ca6fbe772 --- /dev/null +++ b/tests/test_encoder_model_dpr.py @@ -0,0 +1,40 @@ +# +# Pyserini: Reproducible IR research with sparse and dense representations +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import json +import unittest + +from pyserini.encode import DprDocumentEncoder + + +class TestEncodeDpr(unittest.TestCase): + def test_dpr_encoder(self): + texts = [] + with open('tests/resources/simple_cacm_corpus.json') as f: + for line in f: + line = json.loads(line) + texts.append(line['contents']) + + encoder = DprDocumentEncoder('facebook/dpr-ctx_encoder-multiset-base', device='cpu') + vectors = encoder.encode(texts[:3]) + self.assertAlmostEqual(vectors[0][0], -0.59793323, places=4) + self.assertAlmostEqual(vectors[0][-1], -0.13036962, places=4) + self.assertAlmostEqual(vectors[2][0], -0.3044764, places=4) + self.assertAlmostEqual(vectors[2][-1], 0.1516793, places=4) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_encoder_model_sbert.py b/tests/test_encoder_model_sbert.py new file mode 100644 index 000000000..3ef27cfd0 --- /dev/null +++ b/tests/test_encoder_model_sbert.py @@ -0,0 +1,32 @@ +# +# Pyserini: Reproducible IR research with sparse and dense representations +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import unittest + +from pyserini.encode import QueryEncoder +from pyserini.search import get_topics + + +class TestEncodeSBert(unittest.TestCase): + def test_msmarco_passage_sbert_encoded_queries(self): + encoded = QueryEncoder.load_encoded_queries('sbert-msmarco-passage-dev-subset') + topics = get_topics('msmarco-passage-dev-subset') + for t in topics: + self.assertTrue(topics[t]['title'] in encoded.embedding) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_encoder_model_tct.py b/tests/test_encoder_model_tct.py new file mode 100644 index 000000000..26e87f1bd --- /dev/null +++ b/tests/test_encoder_model_tct.py @@ -0,0 +1,66 @@ +# +# Pyserini: Reproducible IR research with sparse and dense representations +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import json +import unittest + +from pyserini.encode import QueryEncoder +from pyserini.encode import TctColBertDocumentEncoder +from pyserini.search import get_topics + + +class TestEncodeTctColBert(unittest.TestCase): + def test_tct_colbert_encoder(self): + texts = [] + with open('tests/resources/simple_cacm_corpus.json') as f: + for line in f: + line = json.loads(line) + texts.append(line['contents']) + + encoder = TctColBertDocumentEncoder('castorini/tct_colbert-msmarco', device='cpu') + vectors = encoder.encode(texts[:3]) + self.assertAlmostEqual(vectors[0][0], -0.01649557, places=4) + self.assertAlmostEqual(vectors[0][-1], -0.05648308, places=4) + self.assertAlmostEqual(vectors[2][0], -0.10293338, places=4) + self.assertAlmostEqual(vectors[2][-1], 0.05549275, places=4) + + def test_msmarco_doc_tct_colbert_encoded_queries(self): + encoded = QueryEncoder.load_encoded_queries('tct_colbert-msmarco-doc-dev') + topics = get_topics('msmarco-doc-dev') + for t in topics: + self.assertTrue(topics[t]['title'] in encoded.embedding) + + def test_msmarco_passage_tct_colbert_v2_encoded_queries(self): + encoded = QueryEncoder.load_encoded_queries('tct_colbert-v2-msmarco-passage-dev-subset') + topics = get_topics('msmarco-passage-dev-subset') + for t in topics: + self.assertTrue(topics[t]['title'] in encoded.embedding) + + def test_msmarco_passage_tct_colbert_v2_hn_encoded_queries(self): + encoded = QueryEncoder.load_encoded_queries('tct_colbert-v2-hn-msmarco-passage-dev-subset') + topics = get_topics('msmarco-passage-dev-subset') + for t in topics: + self.assertTrue(topics[t]['title'] in encoded.embedding) + + def test_msmarco_passage_tct_colbert_v2_hnp_encoded_queries(self): + encoded = QueryEncoder.load_encoded_queries('tct_colbert-v2-hnp-msmarco-passage-dev-subset') + topics = get_topics('msmarco-passage-dev-subset') + for t in topics: + self.assertTrue(topics[t]['title'] in encoded.embedding) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_encoder_model_unicoil.py b/tests/test_encoder_model_unicoil.py new file mode 100644 index 000000000..cbf1b0251 --- /dev/null +++ b/tests/test_encoder_model_unicoil.py @@ -0,0 +1,85 @@ +# +# Pyserini: Reproducible IR research with sparse and dense 
representations +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import json +import os +import shutil +import tarfile +import unittest +from random import randint +from urllib.request import urlretrieve + +from pyserini.encode import UniCoilDocumentEncoder +from pyserini.search.lucene import LuceneImpactSearcher + + +class TestEncodeUniCoil(unittest.TestCase): + def test_unicoil_encoder(self): + texts = [] + with open('tests/resources/simple_cacm_corpus.json') as f: + for line in f: + line = json.loads(line) + texts.append(line['contents']) + + encoder = UniCoilDocumentEncoder('castorini/unicoil-msmarco-passage', device='cpu') + vectors = encoder.encode(texts[:3]) + self.assertAlmostEqual(vectors[0]['generation'], 2.2441017627716064, places=4) + self.assertAlmostEqual(vectors[0]['normal'], 2.4618067741394043, places=4) + self.assertAlmostEqual(vectors[2]['rounding'], 3.9474332332611084, places=4) + self.assertAlmostEqual(vectors[2]['commercial'], 3.288801670074463, places=4) + + def test_onnx_encode_unicoil(self): + # LuceneImpactSearcher requires a pre-built index to be initialized + r = randint(0, 10000000) + collection_url = 'https://github.com/castorini/anserini-data/raw/master/CACM/lucene9-index.cacm.tar.gz' + tarball_name = f'lucene-index.cacm-{r}.tar.gz' + index_dir = f'index-{r}/' + + urlretrieve(collection_url, tarball_name) + + tarball = tarfile.open(tarball_name) + tarball.extractall(index_dir) + tarball.close() + + searcher1 = LuceneImpactSearcher(f'{index_dir}lucene9-index.cacm', + 'SpladePlusPlusEnsembleDistil', + encoder_type='onnx') + + results = searcher1.encode("here is a test") + self.assertEqual(results.get("here"), 156) + self.assertEqual(results.get("a"), 31) + self.assertEqual(results.get("test"), 149) + + searcher1.close() + del searcher1 + + searcher2 = LuceneImpactSearcher(f'{index_dir}lucene9-index.cacm', + 'naver/splade-cocondenser-ensembledistil') + + results = searcher2.encode("here is a test") + self.assertEqual(results.get("here"), 156) + self.assertEqual(results.get("a"), 31) + self.assertEqual(results.get("test"), 149) + + searcher2.close() + del searcher2 + + os.remove(tarball_name) + shutil.rmtree(index_dir) + + +if __name__ == '__main__': + unittest.main() From a6d9cc2fefa4d03a1239c81a8ec93b55ffcce43a Mon Sep 17 00:00:00 2001 From: lintool Date: Wed, 9 Oct 2024 10:24:49 -0700 Subject: [PATCH 6/6] Fixed dsearch import bug, more tests. 
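
pyserini/dsearch.py still imported TctColBertQueryEncoder and
BinaryDenseSearcher from pyserini.search.faiss._searcher, which the earlier
refactoring removed; the deprecation shims now take TctColBertQueryEncoder
from pyserini.encode and subclass BinaryDenseFaissSearcher. The new encoder
tests compare a freshly encoded query vector against the cached prebuilt
embedding under an L1 tolerance. A sketch of that check, using the ANCE case
from this patch (the other model tests swap in their own encoder class and
cached-query name):

    from itertools import islice

    import numpy as np

    from pyserini.encode import AnceQueryEncoder, QueryEncoder
    from pyserini.search import get_topics

    encoder = AnceQueryEncoder('castorini/ance-msmarco-passage')
    cached = QueryEncoder.load_encoded_queries('ance-dl20')
    topics = get_topics('dl20')

    # Only the first 10 topics, to keep the test fast.
    for t in dict(islice(topics.items(), 10)):
        cached_vector = np.array(cached.encode(topics[t]['title']))
        encoded_vector = np.array(encoder.encode(topics[t]['title']))
        # The L1 distance between cached and live encodings stays tiny.
        assert np.sum(np.abs(cached_vector - encoded_vector)) < 0.0005
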
--- pyserini/dsearch.py | 8 +++--- tests/test_encoder_model_ance.py | 18 +++++++++++++- tests/test_encoder_model_distilbert_kd.py | 18 +++++++++++++- tests/test_encoder_model_distilbert_tasb.py | 18 +++++++++++++- tests/test_encoder_model_dpr.py | 27 +++++++++++++++++++-- 5 files changed, 79 insertions(+), 10 deletions(-) diff --git a/pyserini/dsearch.py b/pyserini/dsearch.py index b65956aba..f06bcee8a 100644 --- a/pyserini/dsearch.py +++ b/pyserini/dsearch.py @@ -20,10 +20,8 @@ import os import sys -from pyserini.search.faiss import FaissSearcher -from pyserini.search.faiss._searcher import TctColBertQueryEncoder, BinaryDenseSearcher - -__all__ = ['SimpleDenseSearcher', 'BinaryDenseSearcher', 'TctColBertQueryEncoder'] +from pyserini.encode import TctColBertQueryEncoder +from pyserini.search.faiss import FaissSearcher, BinaryDenseFaissSearcher class SimpleDenseSearcher(FaissSearcher): @@ -33,7 +31,7 @@ def __new__(cls, *args, **kwargs): return super().__new__(cls) -class BinaryDenseSearcher(BinaryDenseSearcher): +class BinaryDenseSearcher(BinaryDenseFaissSearcher): def __new__(cls, *args, **kwargs): print('pyserini.dsearch.BinaryDenseSearcher class has been deprecated, ' 'please use BinaryDenseSearcher from pyserini.search.faiss instead') diff --git a/tests/test_encoder_model_ance.py b/tests/test_encoder_model_ance.py index 1c1d65501..f96dff7e5 100644 --- a/tests/test_encoder_model_ance.py +++ b/tests/test_encoder_model_ance.py @@ -15,8 +15,11 @@ # import unittest +from itertools import islice -from pyserini.encode import QueryEncoder +import numpy as np + +from pyserini.encode import QueryEncoder, AnceQueryEncoder from pyserini.search import get_topics @@ -37,6 +40,19 @@ def test_ance_encoded_queries(self): for t in topics: self.assertTrue(topics[t]['title'] in encoded.embedding) + def test_ance_encoder(self): + encoder = AnceQueryEncoder('castorini/ance-msmarco-passage') + + cached_encoder = QueryEncoder.load_encoded_queries('ance-dl20') + topics = get_topics('dl20') + # Just test the first 10 topics + for t in dict(islice(topics.items(), 10)): + cached_vector = np.array(cached_encoder.encode(topics[t]['title'])) + encoded_vector = np.array(encoder.encode(topics[t]['title'])) + + l1 = np.sum(np.abs(cached_vector - encoded_vector)) + self.assertTrue(l1 < 0.0005) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_encoder_model_distilbert_kd.py b/tests/test_encoder_model_distilbert_kd.py index bbddeacda..fc7f8efe3 100644 --- a/tests/test_encoder_model_distilbert_kd.py +++ b/tests/test_encoder_model_distilbert_kd.py @@ -15,8 +15,11 @@ # import unittest +from itertools import islice -from pyserini.encode import QueryEncoder +import numpy as np + +from pyserini.encode import QueryEncoder, AutoQueryEncoder from pyserini.search import get_topics @@ -37,6 +40,19 @@ def test_distilbert_kd_encoded_queries(self): for t in topics: self.assertTrue(topics[t]['title'] in encoded.embedding) + def test_distilbert_kd_encoder(self): + encoder = AutoQueryEncoder('sebastian-hofstaetter/distilbert-dot-margin_mse-T2-msmarco') + + cached_encoder = QueryEncoder.load_encoded_queries('distilbert_kd-dl20') + topics = get_topics('dl20') + # Just test the first 10 topics + for t in dict(islice(topics.items(), 10)): + cached_vector = np.array(cached_encoder.encode(topics[t]['title'])) + encoded_vector = np.array(encoder.encode(topics[t]['title'])) + + l1 = np.sum(np.abs(cached_vector - encoded_vector)) + self.assertTrue(l1 < 0.0005) + if __name__ == '__main__': unittest.main() diff --git 
a/tests/test_encoder_model_distilbert_tasb.py b/tests/test_encoder_model_distilbert_tasb.py index 88b638363..33ba6c9f6 100644 --- a/tests/test_encoder_model_distilbert_tasb.py +++ b/tests/test_encoder_model_distilbert_tasb.py @@ -15,8 +15,11 @@ # import unittest +from itertools import islice -from pyserini.encode import QueryEncoder +import numpy as np + +from pyserini.encode import QueryEncoder, AutoQueryEncoder from pyserini.search import get_topics @@ -37,6 +40,19 @@ def test_distilbert_tas_b_encoded_queries(self): for t in topics: self.assertTrue(topics[t]['title'] in encoded.embedding) + def test_distilbert_tas_b_encoder(self): + encoder = AutoQueryEncoder('sebastian-hofstaetter/distilbert-dot-tas_b-b256-msmarco') + + cached_encoder = QueryEncoder.load_encoded_queries('distilbert_tas_b-dl20') + topics = get_topics('dl20') + # Just test the first 10 topics + for t in dict(islice(topics.items(), 10)): + cached_vector = np.array(cached_encoder.encode(topics[t]['title'])) + encoded_vector = np.array(encoder.encode(topics[t]['title'])) + + l1 = np.sum(np.abs(cached_vector - encoded_vector)) + self.assertTrue(l1 < 0.0005) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_encoder_model_dpr.py b/tests/test_encoder_model_dpr.py index ca6fbe772..790fd5708 100644 --- a/tests/test_encoder_model_dpr.py +++ b/tests/test_encoder_model_dpr.py @@ -16,12 +16,16 @@ import json import unittest +from itertools import islice -from pyserini.encode import DprDocumentEncoder +import numpy as np + +from pyserini.encode import DprDocumentEncoder, DprQueryEncoder +from pyserini.search import get_topics class TestEncodeDpr(unittest.TestCase): - def test_dpr_encoder(self): + def test_dpr_doc_encoder(self): texts = [] with open('tests/resources/simple_cacm_corpus.json') as f: for line in f: @@ -35,6 +39,25 @@ def test_dpr_encoder(self): self.assertAlmostEqual(vectors[2][0], -0.3044764, places=4) self.assertAlmostEqual(vectors[2][-1], 0.1516793, places=4) + def test_dpr_encoded_queries(self): + encoded = DprQueryEncoder.load_encoded_queries('dpr_multi-nq-test') + topics = get_topics('dpr-nq-test') + for t in topics: + self.assertTrue(topics[t]['title'] in encoded.embedding) + + def test_dpr_query_encoder(self): + encoder = DprQueryEncoder('facebook/dpr-question_encoder-multiset-base') + + cached_encoder = DprQueryEncoder.load_encoded_queries('dpr_multi-nq-test') + topics = get_topics('dpr-nq-test') + # Just test the first 10 topics + for t in dict(islice(topics.items(), 10)): + cached_vector = np.array(cached_encoder.encode(topics[t]['title'])) + encoded_vector = np.array(encoder.encode(topics[t]['title'])) + + l1 = np.sum(np.abs(cached_vector - encoded_vector)) + self.assertTrue(l1 < 0.0005) + if __name__ == '__main__': unittest.main()
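
For downstream code, the series amounts to one import change: query encoders
now come from pyserini.encode, while pyserini.search.faiss keeps only the
searchers and PRF classes. A minimal usage sketch, assuming
'msmarco-passage-tct_colbert-hnsw' is among the available prebuilt indexes:

    from pyserini.encode import TctColBertQueryEncoder
    from pyserini.search.faiss import FaissSearcher

    # Old location, removed by this series:
    #   from pyserini.search.faiss import FaissSearcher, TctColBertQueryEncoder

    encoder = TctColBertQueryEncoder('castorini/tct_colbert-msmarco')
    searcher = FaissSearcher.from_prebuilt_index('msmarco-passage-tct_colbert-hnsw', encoder)
    hits = searcher.search('what is a lobster roll', k=10)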