Add index pointer of each token within doc_freq #80

Open · wants to merge 1 commit into main

25 changes: 16 additions & 9 deletions pinecone_text/sparse/bm25_encoder.py
@@ -6,7 +6,7 @@
from tqdm.auto import tqdm
import wget
from typing import List, Optional, Dict, Union, Tuple
-from collections import Counter
+from collections import Counter, OrderedDict

from pinecone_text.sparse import SparseVector
from pinecone_text.sparse.base_sparse_encoder import BaseSparseEncoder
@@ -26,6 +26,7 @@ def __init__(
remove_stopwords: bool = True,
stem: bool = True,
language: str = "english",
+indptrs: bool = False,
):
"""
OKapi BM25 with mmh3 hashing
@@ -38,6 +39,7 @@
remove_stopwords: Whether to remove stopwords tokens
stem: Whether to stem the tokens (using SnowballStemmer)
language: The language of the text (used for stopwords and stemmer)
+indptrs: Whether to also return each token's position within doc_freq, so encoded documents can be assembled into a scipy.sparse array

Example:

@@ -55,6 +57,7 @@
# Fixed params
self.b: float = b
self.k1: float = k1
+self.indptrs: bool = indptrs

self._tokenizer = BM25Tokenizer(
lower_case=lower_case,
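
As a usage note for the new constructor flag: a minimal sketch, assuming the PR as written; the corpus below is made up for illustration.

```python
from pinecone_text.sparse import BM25Encoder

# Made-up corpus, for illustration only.
corpus = ["the quick brown fox", "jumped over the lazy dog"]

encoder = BM25Encoder(indptrs=True)  # new flag added by this PR
encoder.fit(corpus)

doc_vec = encoder.encode_documents(corpus[0])
print(list(doc_vec))  # ['indptrs', 'indices', 'values'] when indptrs=True
```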
@@ -118,17 +121,20 @@ def encode_documents(
raise ValueError("texts must be a string or list of strings")

def _encode_single_document(self, text: str) -> SparseVector:
-indices, doc_tf = self._tf(text)
+indptrs, indices, doc_tf = self._tf(text)
tf = np.array(doc_tf)
tf_sum = sum(tf)

tf_normed = tf / (
self.k1 * (1.0 - self.b + self.b * (tf_sum / self.avgdl)) + tf
)
-return {
-    "indices": indices,
-    "values": tf_normed.tolist(),
-}

+encoded_document = OrderedDict()
+if self.indptrs:
+    encoded_document["indptrs"] = indptrs
+encoded_document["indices"] = indices
+encoded_document["values"] = tf_normed.tolist()
+return encoded_document

def encode_queries(
self, texts: Union[str, List[str]]
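
An aside on the context lines above: tf_normed is the document-side BM25 weight, tf / (k1 * (1 - b + b * dl / avgdl) + tf), where dl is the document length (tf_sum). A standalone sketch with made-up values:

```python
import numpy as np

# Made-up parameters and counts, for illustration only.
k1, b, avgdl = 1.2, 0.75, 50.0   # BM25 parameters and fitted average doc length
tf = np.array([3.0, 1.0, 2.0])   # per-term counts, as produced by _tf
dl = tf.sum()                    # document length (tf_sum above)

# Document-side BM25 weight, matching the tf_normed expression in the diff:
tf_normed = tf / (k1 * (1.0 - b + b * (dl / avgdl)) + tf)
print(tf_normed)  # each weight lies in (0, 1) and saturates as tf grows
```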
@@ -267,18 +273,19 @@ def _hash_text(token: str) -> int:
"""Use mmh3 to hash text to 32-bit unsigned integer"""
return mmh3.hash(token, signed=False)

-def _tf(self, text: str) -> Tuple[List[int], List[int]]:
+def _tf(self, text: str) -> Tuple[List[int], List[int], List[int]]:
"""
Calculate term frequency for a given text

Args:
text: a document to calculate term frequency for

-Returns: a tuple of two lists:
+Returns: a tuple of three lists:
+indptrs: list of each token's position within doc_freq
indices: list of term indices
values: list of term frequencies
"""
counts = Counter((self._hash_text(token) for token in self._tokenizer(text)))

items = list(counts.items())
-return [idx for idx, _ in items], [val for _, val in items]
+keys = list(self.doc_freq)  # doc_freq is a dict, which has no .index(); use its key order
+return [keys.index(idx) for idx, _ in items], [idx for idx, _ in items], [val for _, val in items]
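
Finally, a sketch of what the new indptrs field enables: assembling encoded documents into the scipy.sparse array mentioned in the docstring. csr_matrix is standard SciPy; the corpus and variable names are illustrative, and the encoder is assumed to be fitted with indptrs=True as in the sketch above.

```python
from scipy.sparse import csr_matrix
from pinecone_text.sparse import BM25Encoder

# Made-up corpus; assumes the PR as written.
corpus = ["the quick brown fox", "jumped over the lazy dog"]
encoder = BM25Encoder(indptrs=True)
encoder.fit(corpus)
encoded = encoder.encode_documents(corpus)  # list of dicts, each with "indptrs"

rows, cols, vals = [], [], []
for i, doc in enumerate(encoded):
    rows.extend([i] * len(doc["indptrs"]))
    cols.extend(doc["indptrs"])  # column = token's position within doc_freq
    vals.extend(doc["values"])

# One row per document, one column per token seen at fit time.
matrix = csr_matrix((vals, (rows, cols)),
                    shape=(len(encoded), len(encoder.doc_freq)))
```

Using positions within doc_freq as columns, rather than the raw 32-bit mmh3 hashes, keeps the column space compact (len(doc_freq) columns instead of 2**32), which is what makes a dense-shaped sparse matrix practical.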