Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update AToMiC demo page to support dense search #1721

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 155 additions & 44 deletions pyserini/demo/atomic.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,41 +29,134 @@
from typing import Callable, Optional, Tuple, Union

from flask import Flask, render_template, request, flash, jsonify
from pyserini.search import LuceneSearcher, FaissSearcher
from pyserini.search import LuceneSearcher, FaissSearcher, QueryEncoder


RETRIEVER_TO_INDEXES = {
'BM25': [
'atomic_image_v0.2_small_validation',
'atomic_image_v0.2_base',
'atomic_image_v0.2_large',
'atomic_text_v0.2.1_small_validation',
'atomic_text_v0.2.1_base',
'atomic_text_v0.2.1_large',
],
'ViT-L-14.laion2b_s32b_b82k': [
'atomic-v0.2.ViT-L-14.laion2b_s32b_b82k.image.validation',
'atomic-v0.2.ViT-L-14.laion2b_s32b_b82k.image.base',
'atomic-v0.2.ViT-L-14.laion2b_s32b_b82k.image.large',
'atomic-v0.2.1.ViT-L-14.laion2b_s32b_b82k.text.validation',
'atomic-v0.2.1.ViT-L-14.laion2b_s32b_b82k.text.base',
'atomic-v0.2.1.ViT-L-14.laion2b_s32b_b82k.text.large',
],
'ViT-H-14.laion2b_s32b_b79k': [
'atomic-v0.2.ViT-H-14.laion2b_s32b_b79k.image.large',
'atomic-v0.2.1.ViT-H-14.laion2b_s32b_b79k.text.large',
],
'ViT-bigG-14.laion2b_s39b_b160k': [
'atomic-v0.2.ViT-bigG-14.laion2b_s39b_b160k.image.large',
'atomic-v0.2.1.ViT-bigG-14.laion2b_s39b_b160k.text.large',
],
'ViT-B-32.laion2b_e16': [
'atomic-v0.2.ViT-B-32.laion2b_e16.image.large',
'atomic-v0.2.1.ViT-B-32.laion2b_e16.text.large',
],
'ViT-B-32.laion400m_e32': [
'atomic-v0.2.ViT-B-32.laion400m_e32.image.large',
'atomic-v0.2.1.ViT-B-32.laion400m_e32.text.large',
],
'openai.clip-vit-base-patch32': [
'atomic-v0.2.openai.clip-vit-base-patch32.image.large',
'atomic-v0.2.1.openai.clip-vit-base-patch32.text.large',
],
'openai.clip-vit-large-patch14': [
'atomic-v0.2.openai.clip-vit-large-patch14.image.large',
'atomic-v0.2.1.openai.clip-vit-large-patch14.text.large',
],
'Salesforce.blip-itm-base-coco': [
'atomic-v0.2.Salesforce.blip-itm-base-coco.image.large',
'atomic-v0.2.1.Salesforce.blip-itm-base-coco.text.large',
],
'Salesforce.blip-itm-large-coco': [
'atomic-v0.2.Salesforce.blip-itm-large-coco.image.large',
'atomic-v0.2.1.Salesforce.blip-itm-large-coco.text.large',
],
'facebook.flava-full': [
'atomic-v0.2.facebook.flava-full.image.large',
'atomic-v0.2.1.facebook.flava-full.text.large',
],
}

INDEX_TO_ENCODED_QUERIES = {
# 'ViT-L-14.laion2b_s32b_b82k'
'atomic-v0.2.1.ViT-L-14.laion2b_s32b_b82k.text.validation': 'atomic-v0.2-image-ViT-L-14.laion2b_s32b_b82k-validation',
'atomic-v0.2.ViT-L-14.laion2b_s32b_b82k.image.validation': 'atomic-v0.2.1-text-ViT-L-14.laion2b_s32b_b82k-validation',
'atomic-v0.2.1.ViT-L-14.laion2b_s32b_b82k.text.base': 'atomic-v0.2-image-ViT-L-14.laion2b_s32b_b82k-validation',
'atomic-v0.2.ViT-L-14.laion2b_s32b_b82k.image.base': 'atomic-v0.2.1-text-ViT-L-14.laion2b_s32b_b82k-validation',
'atomic-v0.2.1.ViT-L-14.laion2b_s32b_b82k.text.large': 'atomic-v0.2-image-ViT-L-14.laion2b_s32b_b82k-validation',
'atomic-v0.2.ViT-L-14.laion2b_s32b_b82k.image.large': 'atomic-v0.2.1-text-ViT-L-14.laion2b_s32b_b82k-validation',
# ViT-H-14.laion2b_s32b_b79k
'atomic-v0.2.ViT-H-14.laion2b_s32b_b79k.image.large': 'atomic-v0.2.1-text-ViT-H-14.laion2b_s32b_b79k-validation',
'atomic-v0.2.1.ViT-H-14.laion2b_s32b_b79k.text.large': 'atomic-v0.2-image-ViT-H-14.laion2b_s32b_b79k-validation',
# ViT-bigG-14.laion2b_s39b_b160k
'atomic-v0.2.ViT-bigG-14.laion2b_s39b_b160k.image.large': 'atomic-v0.2.1-text-ViT-bigG-14.laion2b_s39b_b160k-validation',
'atomic-v0.2.1.ViT-bigG-14.laion2b_s39b_b160k.text.large': 'atomic-v0.2-image-ViT-bigG-14.laion2b_s39b_b160k-validation',
# ViT-B-32.laion2b_e16
'atomic-v0.2.ViT-B-32.laion2b_e16.image.large': 'atomic-v0.2.1-text-ViT-B-32.laion2b_e16-validation',
'atomic-v0.2.1.ViT-B-32.laion2b_e16.text.large': 'atomic-v0.2-image-ViT-B-32.laion2b_e16-validation',
# ViT-B-32.laion400m_e32
'atomic-v0.2.ViT-B-32.laion400m_e32.image.large': 'atomic-v0.2.1-text-ViT-B-32.laion400m_e32-validation',
'atomic-v0.2.1.ViT-B-32.laion400m_e32.text.large': 'atomic-v0.2-image-ViT-B-32.laion400m_e32-validation',
# openai.clip-vit-base-patch32
'atomic-v0.2.openai.clip-vit-base-patch32.image.large': 'atomic-v0.2.1-text-openai.clip-vit-base-patch32-validation',
'atomic-v0.2.1.openai.clip-vit-base-patch32.text.large': 'atomic-v0.2-image-openai.clip-vit-base-patch32-validation',
# openai.clip-vit-large-patch14
'atomic-v0.2.openai.clip-vit-large-patch14.image.large': 'atomic-v0.2.1-text-openai.clip-vit-large-patch14-validation',
'atomic-v0.2.1.openai.clip-vit-large-patch14.text.large': 'atomic-v0.2-image-openai.clip-vit-large-patch14-validation',
# Salesforce.blip-itm-base-coco
'atomic-v0.2.Salesforce.blip-itm-base-coco.image.large': 'atomic-v0.2.1-text-Salesforce.blip-itm-base-coco-validation',
'atomic-v0.2.1.Salesforce.blip-itm-base-coco.text.large': 'atomic-v0.2-image-Salesforce.blip-itm-base-coco-validation',
# Salesforce.blip-itm-large-coco
'atomic-v0.2.Salesforce.blip-itm-large-coco.image.large': 'atomic-v0.2.1-text-Salesforce.blip-itm-large-coco-validation',
'atomic-v0.2.1.Salesforce.blip-itm-large-coco.text.large': 'atomic-v0.2-image-Salesforce.blip-itm-large-coco-validation',
# facebook.flava-full
'atomic-v0.2.facebook.flava-full.image.large': 'atomic-v0.2.1-text-facebook.flava-full-validation',
'atomic-v0.2.1.facebook.flava-full.text.large': 'atomic-v0.2-image-facebook.flava-full-validation',
}


INDEX_NAMES = (
'atomic_image_v0.2_small_validation',
'atomic_image_v0.2_base',
'atomic_image_v0.2_large',
'atomic_text_v0.2.1_small_validation',
'atomic_text_v0.2.1_base',
'atomic_text_v0.2.1_large',
)
Searcher = Union[FaissSearcher, LuceneSearcher]


def create_app(k: int, load_searcher_fn: Callable[[str], Tuple[Searcher, str]]):
def create_app(k: int, load_searcher_fn: Callable[[str], Searcher]):
app = Flask(__name__)

index_name = INDEX_NAMES[0]
searcher, retriever = load_searcher_fn(index_name=index_name)
# Use BM25 as default retriever upon page load
retriever = "BM25"
index_name = RETRIEVER_TO_INDEXES[retriever][0]
searcher = load_searcher_fn(index_name=index_name)
query_options = [] # for dense search only

@app.route('/')
def index():
nonlocal searcher, retriever
return render_template('atomic.html', index_name=index_name, retriever=retriever)
return render_template(
'atomic.html', index_name=index_name, retriever=retriever, retriever_to_indexes=RETRIEVER_TO_INDEXES
)

@app.route('/search', methods=['GET', 'POST'])
def search():
nonlocal searcher, retriever
query = request.form['q']
if retriever != "BM25":
query = query_options[int(query)]
if not query:
search_results = []
flash('Question is required')
# NOTE: this throws an exception unless we set a secret session key
else:
hits = searcher.search(query, k=k)
try:
hits = searcher.search(query, k=k)
except KeyError:
hits = []
flash('Invalid query given')
docs = [json.loads(searcher.doc(hit.docid).raw()) for hit in hits]
search_results = [
{
Expand All @@ -76,56 +169,74 @@ def search():
for r, hit in enumerate(hits)
]
return render_template(
'atomic.html', index_name=index_name, search_results=search_results, query=query, retriever=retriever
'atomic.html', index_name=index_name, retriever=retriever,
retriever_to_indexes=RETRIEVER_TO_INDEXES, search_results=search_results, query=query,
)

def _change_index(new_index_name):
nonlocal index_name, searcher, query_options
index_name = new_index_name
searcher = load_searcher_fn(index_name=index_name)
if retriever != "BM25":
query_options = {i: option for i, option in enumerate(searcher.query_encoder.embedding.keys())}

@app.route('/retriever', methods=['GET'])
def change_retriever():
nonlocal retriever
new_retriever = request.args.get('new_retriever_name', '', type=str)
if not new_retriever or new_retriever not in list(RETRIEVER_TO_INDEXES.keys()):
return

retriever = new_retriever
_change_index(new_index_name=RETRIEVER_TO_INDEXES[retriever][0])
return jsonify(index_list=RETRIEVER_TO_INDEXES[retriever])

@app.route('/index', methods=['GET'])
def change_index_name():
nonlocal index_name, searcher, retriever
new_index_name = request.args.get('new_index_name', '', type=str)
if not new_index_name or new_index_name not in INDEX_NAMES:
if not new_index_name or new_index_name not in RETRIEVER_TO_INDEXES[retriever]:
return

index_name = new_index_name
searcher, retriever = load_searcher_fn(index_name=index_name)
_change_index(new_index_name)
return jsonify(index_name=index_name)

@app.route('/search_options', methods=['GET'])
def search_options():
query = request.args.get('query', '')

matching_options = {
i: option
for i, option in query_options.items()
if option.lower().startswith(query.lower())
}
return jsonify(matching_options)

return app


def _load_sparse_searcher(index_name, language: str, k1: Optional[float]=None, b: Optional[float]=None) -> (Searcher, str):
searcher = LuceneSearcher.from_prebuilt_index(index_name)
if k1 is not None and b is not None:
searcher.set_bm25(k1, b)
retriever_name = f'BM25 (k1={k1}, b={b})'
def _load_searcher(index_name: str, language: str, k1: Optional[float]=None, b: Optional[float]=None):
if index_name in RETRIEVER_TO_INDEXES['BM25']:
searcher = LuceneSearcher.from_prebuilt_index(index_name)
if k1 is not None and b is not None:
searcher.set_bm25(k1, b)
else:
retriever_name = 'BM25'

return searcher, retriever_name
query_encoder = QueryEncoder.load_encoded_queries(INDEX_TO_ENCODED_QUERIES[index_name])
searcher = FaissSearcher.from_prebuilt_index(
index_name, query_encoder
)
return searcher


def main():
parser = ArgumentParser()

parser.add_argument('--k1', type=float, help='BM25 k1 parameter.')
parser.add_argument('--b', type=float, help='BM25 b parameter.')
parser.add_argument('--hits', type=int, default=10, help='Number of hits returned by the retriever')
parser.add_argument(
'--device',
type=str,
default='cpu',
help='Device to run query encoder, cpu or [cuda:0, cuda:1, ...] (used only when index is based on FAISS)',
)
parser.add_argument(
'--port',
default=8080,
type=int,
help='Web server port',
'--port', default=8080, type=int, help='Web server port',
)

args = parser.parse_args()
load_fn = partial(_load_sparse_searcher, language='en', k1=args.k1, b=args.b)

load_fn = partial(_load_searcher, language='en', k1=args.k1, b=args.b)
app = create_app(args.hits, load_fn)
app.run(host='0.0.0.0', port=args.port)

Expand Down
Loading