Refactoring to remove duplicates in pyserini.encode and pyserini.search.faiss #2008

Draft · wants to merge 6 commits into master
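In short: this PR moves the duplicated query-encoder classes out of pyserini.search.faiss._searcher and into pyserini.encode, and switches every caller to the new import path. A minimal before/after sketch of the change, assuming this branch is installed:

# Before: encoders were imported from the faiss searcher module.
# from pyserini.search.faiss._searcher import QueryEncoder
# After: encoders come from pyserini.encode; searchers stay in pyserini.search.faiss.
from pyserini.encode import QueryEncoder
from pyserini.search.faiss import FaissSearcher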
2 changes: 1 addition & 1 deletion integrations-optional/dense/test_ance.py
@@ -21,8 +21,8 @@
 import unittest
 
 from integrations.utils import clean_files, run_command, parse_score_qa, parse_score_msmarco
+from pyserini.encode import QueryEncoder
 from pyserini.search import get_topics
-from pyserini.search.faiss._searcher import QueryEncoder
 
 
 class TestAnce(unittest.TestCase):
2 changes: 1 addition & 1 deletion integrations-optional/dense/test_dpr.py
@@ -22,8 +22,8 @@
 import unittest
 
 from integrations.utils import clean_files, run_command, parse_score_qa
+from pyserini.encode import QueryEncoder
 from pyserini.search import get_topics
-from pyserini.search.faiss._searcher import QueryEncoder
 
 
 class TestDpr(unittest.TestCase):
2 changes: 1 addition & 1 deletion integrations-optional/dense/test_tct_colbert.py
@@ -21,8 +21,8 @@
 import unittest
 
 from integrations.utils import clean_files, run_command, parse_score
+from pyserini.encode import QueryEncoder
 from pyserini.search import get_topics
-from pyserini.search.faiss._searcher import QueryEncoder
 
 
 class TestTctColBert(unittest.TestCase):
@@ -23,7 +23,7 @@
 
 class TestLtrMsmarcoDocument(unittest.TestCase):
     def test_reranking(self):
-        if(os.path.isdir('ltr_test')):
+        if os.path.isdir('ltr_test'):
             rmtree('ltr_test')
         os.mkdir('ltr_test')
         inp = 'run.msmarco-pass-doc.bm25.txt'
@@ -20,12 +20,10 @@
 import unittest
 from shutil import rmtree
 
-from pyserini.search.lucene import LuceneSearcher
-
 
 class TestLtrMsmarcoPassage(unittest.TestCase):
     def test_reranking(self):
-        if(os.path.isdir('ltr_test')):
+        if os.path.isdir('ltr_test'):
             rmtree('ltr_test')
         os.mkdir('ltr_test')
         inp = 'run.msmarco-passage.bm25tuned.txt'
5 changes: 3 additions & 2 deletions integrations/clprf/test_trec_covid_r5.py
@@ -14,13 +14,14 @@
 # limitations under the License.
 #
 
+import gzip
+import json
 import os
 import re
 import shutil
 import unittest
-import json
-import gzip
 from random import randint
+
 from pyserini.util import download_url, download_prebuilt_index
 
 
2 changes: 1 addition & 1 deletion integrations/lucenesearcher_anserini_checker.py
@@ -18,8 +18,8 @@
 import os
 from typing import List
 
-from pyserini.util import get_cache_home
 from pyserini.prebuilt_index_info import TF_INDEX_INFO
+from pyserini.util import get_cache_home
 
 
 class LuceneSearcherAnseriniMatchChecker:
2 changes: 1 addition & 1 deletion integrations/sparse/test_lucenesearcher_check_irst.py
@@ -16,8 +16,8 @@
 
 import os
 import unittest
-from shutil import rmtree
 from random import randint
+from shutil import rmtree
 
 from integrations.utils import run_command, parse_score
 
2 changes: 1 addition & 1 deletion integrations/sparse/test_search_pretokenized.py
@@ -17,8 +17,8 @@
 import os
 import shutil
 import unittest
-
 from random import randint
+
 from integrations.lucenesearcher_score_checker import LuceneSearcherScoreChecker
 
 
2 changes: 1 addition & 1 deletion integrations/utils.py
@@ -15,8 +15,8 @@
 #
 
 import os
-import subprocess
 import shutil
+import subprocess
 
 
 def clean_files(files):
1 change: 0 additions & 1 deletion pyserini/analysis/_base.py
@@ -47,7 +47,6 @@
 
 # Wrappers around Anserini classes
 JAnalyzerUtils = autoclass('io.anserini.analysis.AnalyzerUtils')
-JDefaultEnglishAnalyzer = autoclass('io.anserini.analysis.DefaultEnglishAnalyzer')
 JTweetAnalyzer = autoclass('io.anserini.analysis.TweetAnalyzer')
 JHuggingFaceTokenizerAnalyzer = autoclass('io.anserini.analysis.HuggingFaceTokenizerAnalyzer')
 
3 changes: 2 additions & 1 deletion pyserini/demo/dpr.py
@@ -18,8 +18,9 @@
 import json
 import random
 
+from pyserini.encode import DprQueryEncoder
 from pyserini.search import get_topics
-from pyserini.search.faiss import FaissSearcher, DprQueryEncoder
+from pyserini.search.faiss import FaissSearcher
 from pyserini.search.hybrid import HybridSearcher
 from pyserini.search.lucene import LuceneSearcher
 
3 changes: 2 additions & 1 deletion pyserini/demo/msmarco.py
@@ -18,8 +18,9 @@
 import json
 import random
 
+from pyserini.encode import AnceQueryEncoder, TctColBertQueryEncoder
 from pyserini.search import get_topics
-from pyserini.search.faiss import FaissSearcher, TctColBertQueryEncoder, AnceQueryEncoder
+from pyserini.search.faiss import FaissSearcher
 from pyserini.search.hybrid import HybridSearcher
 from pyserini.search.lucene import LuceneSearcher
 
8 changes: 3 additions & 5 deletions pyserini/dsearch.py
@@ -20,10 +20,8 @@
 import os
 import sys
 
-from pyserini.search.faiss import FaissSearcher
-from pyserini.search.faiss._searcher import TctColBertQueryEncoder, BinaryDenseSearcher
-
-__all__ = ['SimpleDenseSearcher', 'BinaryDenseSearcher', 'TctColBertQueryEncoder']
+from pyserini.encode import TctColBertQueryEncoder
+from pyserini.search.faiss import FaissSearcher, BinaryDenseFaissSearcher
 
 
 class SimpleDenseSearcher(FaissSearcher):
@@ -33,7 +31,7 @@ def __new__(cls, *args, **kwargs):
         return super().__new__(cls)
 
 
-class BinaryDenseSearcher(BinaryDenseSearcher):
+class BinaryDenseSearcher(BinaryDenseFaissSearcher):
     def __new__(cls, *args, **kwargs):
         print('pyserini.dsearch.BinaryDenseSearcher class has been deprecated, '
               'please use BinaryDenseSearcher from pyserini.search.faiss instead')
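The shim above keeps the old pyserini.dsearch names importable: each deprecated class subclasses its replacement and overrides __new__ to print a warning before construction proceeds as normal. A minimal self-contained sketch of the same pattern, using hypothetical OldSearcher/NewSearcher names:

class NewSearcher:
    # Stand-in for the class that callers should migrate to.
    def search(self, query: str):
        return []

class OldSearcher(NewSearcher):
    def __new__(cls, *args, **kwargs):
        # Warn at construction time; instances still behave exactly like NewSearcher.
        print('OldSearcher has been deprecated, please use NewSearcher instead')
        return super().__new__(cls)

searcher = OldSearcher()  # prints the notice, then works as a NewSearcher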
2 changes: 2 additions & 0 deletions pyserini/encode/__init__.py
@@ -21,8 +21,10 @@
 from ._aggretriever import AggretrieverDocumentEncoder, AggretrieverQueryEncoder
 from ._ance import AnceEncoder, AnceDocumentEncoder, AnceQueryEncoder
 from ._auto import AutoQueryEncoder, AutoDocumentEncoder
+from ._bpr import BprQueryEncoder
 from ._cached_data import CachedDataQueryEncoder
 from ._cosdpr import CosDprEncoder, CosDprDocumentEncoder, CosDprQueryEncoder
+from ._dkrr import DkrrDprQueryEncoder
 from ._dpr import DprDocumentEncoder, DprQueryEncoder
 from ._openai import OpenAIDocumentEncoder, OpenAIQueryEncoder, OPENAI_API_RETRY_DELAY
 from ._slim import SlimQueryEncoder
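With BprQueryEncoder and DkrrDprQueryEncoder added here, pyserini.encode re-exports all of the query encoders from their private _* modules, so downstream code gets a single import surface. For example (names taken from this __init__.py):

from pyserini.encode import AnceQueryEncoder, BprQueryEncoder, DkrrDprQueryEncoder, DprQueryEncoder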
64 changes: 34 additions & 30 deletions pyserini/encode/_aggretriever.py
@@ -27,7 +27,8 @@
 from transformers import AutoModelForMaskedLM, AutoTokenizer, PreTrainedModel
 from pyserini.encode import DocumentEncoder, QueryEncoder
 
-class BERTAggretrieverEncoder(PreTrainedModel):
+
+class BertAggretrieverEncoder(PreTrainedModel):
     config_class = BertConfig
     base_model_prefix = 'encoder'
     load_tf_weights = None
@@ -120,7 +121,7 @@ def forward(
         return torch.cat((semantic_reps, lexical_reps), -1)
 
 
-class DistlBERTAggretrieverEncoder(BERTAggretrieverEncoder):
+class DistlBertAggretrieverEncoder(BertAggretrieverEncoder):
     config_class = DistilBertConfig
     base_model_prefix = 'encoder'
     load_tf_weights = None
@@ -130,9 +131,9 @@ class AggretrieverDocumentEncoder(DocumentEncoder):
     def __init__(self, model_name: str, tokenizer_name=None, device='cuda:0'):
         self.device = device
         if 'distilbert' in model_name.lower():
-            self.model = DistlBERTAggretrieverEncoder.from_pretrained(model_name)
+            self.model = DistlBertAggretrieverEncoder.from_pretrained(model_name)
         else:
-            self.model = BERTAggretrieverEncoder.from_pretrained(model_name)
+            self.model = BertAggretrieverEncoder.from_pretrained(model_name)
         self.model.to(self.device)
         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or model_name)
 
@@ -160,30 +161,33 @@ def encode(self, texts, titles=None, fp16=False, max_length=512, **kwargs):
 
 
 class AggretrieverQueryEncoder(QueryEncoder):
-    def __init__(self, model_name: str, tokenizer_name=None, device='cuda:0'):
-        self.device = device
-        if 'distilbert' in model_name.lower():
-            self.model = DistlBERTAggretrieverEncoder.from_pretrained(model_name)
-        else:
-            self.model = BERTAggretrieverEncoder.from_pretrained(model_name)
-        self.model.to(self.device)
-        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or model_name)
-
-    def encode(self, texts, fp16=False, max_length=32, **kwargs):
-        texts = [text for text in texts]
-        inputs = self.tokenizer(
-            texts,
-            max_length=max_length,
-            padding="longest",
-            truncation=True,
-            add_special_tokens=True,
-            return_tensors='pt'
-        )
-        inputs.to(self.device)
-        if fp16:
-            with autocast():
-                with torch.no_grad():
-                    outputs = self.model(**inputs)
-        else:
-            outputs = self.model(**inputs)
-        return outputs.detach().cpu().numpy()
+    def __init__(self, encoder_dir: str = None, tokenizer_name: str = None,
+                 encoded_query_dir: str = None, device: str = 'cpu', **kwargs):
+        if encoder_dir:
+            self.device = device
+            if 'distilbert' in encoder_dir.lower():
+                self.model = DistlBertAggretrieverEncoder.from_pretrained(encoder_dir)
+            else:
+                self.model = BertAggretrieverEncoder.from_pretrained(encoder_dir)
+            self.model.to(self.device)
+            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or encoder_dir)
+            self.has_model = True
+        if (not self.has_model) and (not self.has_encoded_query):
+            raise Exception('Neither query encoder model nor encoded queries provided. Please provide at least one')
+
+    def encode(self, query: str, max_length: int=32):
+        if self.has_model:
+            inputs = self.tokenizer(
+                query,
+                max_length=max_length,
+                padding="longest",
+                truncation=True,
+                add_special_tokens=True,
+                return_tensors='pt'
+            )
+            inputs.to(self.device)
+            outputs = self.model(**inputs)
+            embeddings = outputs.detach().cpu().numpy()
+            return embeddings.flatten()
+        else:
+            return super().encode(query)
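After this rewrite, AggretrieverQueryEncoder follows the same constructor contract as the other query encoders: pass encoder_dir to encode queries on the fly, or encoded_query_dir to look up precomputed embeddings, falling back to the parent class when no model is loaded. A usage sketch, assuming this branch is installed; the checkpoint name castorini/aggretriever-distilbert is an illustrative assumption, not part of this diff:

from pyserini.encode import AggretrieverQueryEncoder

# Hypothetical checkpoint; any Aggretriever model directory works the same way.
encoder = AggretrieverQueryEncoder(encoder_dir='castorini/aggretriever-distilbert', device='cpu')
embedding = encoder.encode('what is a lobster roll')  # flattened 1-D numpy array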
71 changes: 53 additions & 18 deletions pyserini/encode/_ance.py
@@ -14,10 +14,10 @@
 # limitations under the License.
 #
 
-from typing import Optional
+from typing import Optional, List
 
 import torch
-from transformers import PreTrainedModel, RobertaConfig, RobertaModel, RobertaTokenizer
+from transformers import PreTrainedModel, RobertaConfig, RobertaModel, RobertaTokenizer, requires_backends
 
 from pyserini.encode import DocumentEncoder, QueryEncoder
 
@@ -30,6 +30,7 @@ class AnceEncoder(PreTrainedModel):
     _keys_to_ignore_on_load_unexpected = [r'pooler', r'classifier']
 
     def __init__(self, config: RobertaConfig):
+        requires_backends(self, 'torch')
         super().__init__(config)
         self.config = config
         self.roberta = RobertaModel(config)
@@ -55,11 +56,7 @@ def init_weights(self):
         self.embeddingHead.apply(self._init_weights)
         self.norm.apply(self._init_weights)
 
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-    ):
+    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
         input_shape = input_ids.size()
         device = input_ids.device
         if attention_mask is None:
@@ -98,22 +95,60 @@ def encode(self, texts, titles=None, max_length=256, **kwargs):
 
 
 class AnceQueryEncoder(QueryEncoder):
+    def __init__(self, encoder_dir: str = None, tokenizer_name: str = None,
+                 encoded_query_dir: str = None, device: str = 'cpu', **kwargs):
+        super().__init__(encoded_query_dir)
+        if encoder_dir:
+            self.device = device
+            self.model = AnceEncoder.from_pretrained(encoder_dir)
+            self.model.to(self.device)
+            self.tokenizer = RobertaTokenizer.from_pretrained(tokenizer_name or encoder_dir)
+            self.has_model = True
+            self.tokenizer.do_lower_case = True
+        if (not self.has_model) and (not self.has_encoded_query):
+            raise Exception('Neither query encoder model nor encoded queries provided. Please provide at least one')
+
+    def encode(self, query: str):
+        if self.has_model:
+            inputs = self.tokenizer(
+                [query],
+                max_length=64,
+                padding='longest',
+                truncation=True,
+                add_special_tokens=True,
+                return_tensors='pt'
+            )
+            inputs.to(self.device)
+            embeddings = self.model(inputs["input_ids"]).detach().cpu().numpy()
+            return embeddings.flatten()
+        else:
+            return super().encode(query)
+
+    def prf_encode(self, query: str):
+        if self.has_model:
+            inputs = self.tokenizer(
+                [query],
+                max_length=512,
+                padding='longest',
+                truncation=True,
+                add_special_tokens=False,
+                return_tensors='pt'
+            )
+            inputs.to(self.device)
+            embeddings = self.model(inputs["input_ids"]).detach().cpu().numpy()
+            return embeddings.flatten()
+        else:
+            return super().encode(query)
 
-    def __init__(self, model_name: str, tokenizer_name: str = None, device: str = 'cpu'):
-        self.device = device
-        self.model = AnceEncoder.from_pretrained(model_name)
-        self.model.to(self.device)
-        self.tokenizer = RobertaTokenizer.from_pretrained(tokenizer_name or model_name)
-
-    def encode(self, query: str, **kwargs):
+    def prf_batch_encode(self, query: List[str]):
         inputs = self.tokenizer(
-            [query],
-            max_length=64,
+            query,
+            max_length=512,
             padding='longest',
             truncation=True,
-            add_special_tokens=True,
+            add_special_tokens=False,
             return_tensors='pt'
         )
         inputs.to(self.device)
         embeddings = self.model(inputs["input_ids"]).detach().cpu().numpy()
-        return embeddings.flatten()
+        return embeddings
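The refactored AnceQueryEncoder now carries three entry points: encode for plain queries (64-token budget, with special tokens), prf_encode for a single PRF-expanded query string (512-token budget, no special tokens), and prf_batch_encode for a list of PRF queries, returning the unflattened batch matrix. A usage sketch, assuming this branch is installed; the checkpoint name castorini/ance-msmarco-passage is an illustrative assumption:

from pyserini.encode import AnceQueryEncoder

encoder = AnceQueryEncoder(encoder_dir='castorini/ance-msmarco-passage', device='cpu')
dense = encoder.encode('what is a lobster roll')              # 1-D query vector
prf = encoder.prf_encode('what is a lobster roll seafood')    # PRF variant, no special tokens
batch = encoder.prf_batch_encode(['query one', 'query two'])  # 2-D array, one row per query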