Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added optional input "model_dir" to class SpladeEncoder #79

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 36 additions & 7 deletions pinecone_text/sparse/splade_encoder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import List, Union, Optional

from os import PathLike
import os
import json
try:
import torch
except (OSError, ImportError, ModuleNotFoundError) as e:
Expand All @@ -26,11 +28,12 @@ class SpladeEncoder(BaseSparseEncoder):
Currently only supports inference with naver/splade-cocondenser-ensembledistil
"""

def __init__(self, max_seq_length: int = 256, device: Optional[str] = None):
def __init__(self, max_seq_length: int = 256, device: Optional[str] = None, model_dir:Optional[PathLike[str]] = None):
"""
Args:
max_seq_length: Maximum sequence length for the model. Must be between 1 and 512.
device: Device to use for inference. Defaults to GPU if available, otherwise CPU.
model_dir: Directory to download and load model from. Saves time and resources.

Example:

Expand Down Expand Up @@ -61,12 +64,38 @@ def __init__(self, max_seq_length: int = 256, device: Optional[str] = None):

device = device or ("cuda" if torch.cuda.is_available() else "cpu")
self.device = device

model = "naver/splade-cocondenser-ensembledistil"
self.tokenizer = AutoTokenizer.from_pretrained(model)
self.model = AutoModelForMaskedLM.from_pretrained(model).to(self.device)
expected_model_name = "naver/splade-cocondenser-ensembledistil"
if model_dir:
if not self._is_correct_model(model_dir, expected_model_name):
self.tokenizer,self.model=self._download_model(model_dir, expected_model_name)
else:
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
self.model = AutoModelForMaskedLM.from_pretrained(model_dir).to(self.device)
else:
self.tokenizer = AutoTokenizer.from_pretrained(expected_model_name)
self.model = AutoModelForMaskedLM.from_pretrained(expected_model_name).to(self.device)
self.max_seq_length = max_seq_length

def _is_correct_model(self, model_dir, expected_model_name):
# Check for the presence of specific files that indicate the correct model
config_path = os.path.join(model_dir, 'config.json')
if not os.path.exists(config_path):
return False

with open(config_path, 'r') as config_file:
config = json.load(config_file)
return config.get("_name_or_path") == expected_model_name

def _download_model(self, model_dir, model_name):
# Ensure the directory exists
os.makedirs(model_dir, exist_ok=True)

# Download the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.save_pretrained(model_dir)

model = AutoModelForMaskedLM.from_pretrained(model_name)
model.save_pretrained(model_dir)
return tokenizer,model
def encode_documents(
self, texts: Union[str, List[str]]
) -> Union[SparseVector, List[SparseVector]]:
Expand Down