FEAT: support sparse vector for bge-m3 #2540

Open · wants to merge 8 commits into base: main
Changes from 7 commits
45 changes: 45 additions & 0 deletions xinference/api/restful_api.py
@@ -484,6 +484,16 @@ async def internal_exception_handler(request: Request, exc: Exception):
else None
),
)
self._router.add_api_route(
"/v1/convert_ids_to_tokens",
self.convert_ids_to_tokens,
methods=["POST"],
dependencies=(
[Security(self._auth_service, scopes=["models:read"])]
if self.is_authenticated()
else None
),
)
self._router.add_api_route(
"/v1/rerank",
self.rerank,
@@ -1307,6 +1317,41 @@ async def create_embedding(self, request: Request) -> Response:
await self._report_error_event(model_uid, str(e))
raise HTTPException(status_code=500, detail=str(e))

async def convert_ids_to_tokens(self, request: Request) -> Response:
payload = await request.json()
body = CreateEmbeddingRequest.parse_obj(payload)
model_uid = body.model
exclude = {
"model",
"input",
"user",
}
kwargs = {key: value for key, value in payload.items() if key not in exclude}

try:
model = await (await self._get_supervisor_ref()).get_model(model_uid)
except ValueError as ve:
logger.error(str(ve), exc_info=True)
await self._report_error_event(model_uid, str(ve))
raise HTTPException(status_code=400, detail=str(ve))
except Exception as e:
logger.error(e, exc_info=True)
await self._report_error_event(model_uid, str(e))
raise HTTPException(status_code=500, detail=str(e))

try:
decoded_texts = await model.convert_ids_to_tokens(body.input, **kwargs)
return Response(decoded_texts, media_type="application/json")
except RuntimeError as re:
logger.error(re, exc_info=True)
await self._report_error_event(model_uid, str(re))
self.handle_request_limit_error(re)
raise HTTPException(status_code=400, detail=str(re))
except Exception as e:
logger.error(e, exc_info=True)
await self._report_error_event(model_uid, str(e))
raise HTTPException(status_code=500, detail=str(e))

async def rerank(self, request: Request) -> Response:
payload = await request.json()
body = RerankRequest.parse_obj(payload)
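For reference, the new endpoint can be exercised directly over HTTP. A minimal sketch, assuming a local Xinference server on the default port and an already-launched bge-m3 model (the model UID and token IDs are illustrative):

import requests

# illustrative token IDs; any IDs produced by the model's tokenizer work
payload = {"model": "bge-m3", "input": [0, 33600, 31, 8999, 2]}
resp = requests.post(
    "http://127.0.0.1:9997/v1/convert_ids_to_tokens",  # assumed host/port
    json=payload,
)
resp.raise_for_status()
print(resp.json())  # e.g. ["<s>", "▁Hello", ",", "▁world", "</s>"]
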
37 changes: 37 additions & 0 deletions xinference/client/restful/restful_client.py
@@ -126,6 +126,43 @@ def create_embedding(self, input: Union[str, List[str]], **kwargs) -> "Embedding
response_data = response.json()
return response_data

def convert_ids_to_tokens(
self, input: Union[List, List[List]], **kwargs
) -> List[str]:
"""
Convert token IDs to human-readable tokens via RESTful APIs.

Parameters
----------
input: Union[List, List[List]]
Input token IDs to convert, can be a single list of token IDs or a list of token ID lists.
To convert multiple sequences in a single request, pass a list of token ID lists.

Returns
-------
list
A list of decoded tokens in human-readable format.

Raises
------
RuntimeError
Raised when token conversion fails; the message carries the server-side error.

"""
url = f"{self._base_url}/v1/convert_ids_to_tokens"
request_body = {
"model": self._model_uid,
"input": input,
}
request_body.update(kwargs)
response = requests.post(url, json=request_body, headers=self.auth_headers)
if response.status_code != 200:
raise RuntimeError(
f"Failed to decode token ids, detail: {_get_error_string(response)}"
)
response_data = response.json()
return response_data


class RESTfulRerankModelHandle(RESTfulModelHandle):
def rerank(
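A usage sketch for the new client method, assuming a running server and a launched model (endpoint and model UID are placeholders):

from xinference.client import Client

client = Client("http://127.0.0.1:9997")  # assumed endpoint
model = client.get_model("bge-m3")  # placeholder model UID
# a single sequence of token IDs
print(model.convert_ids_to_tokens([0, 33600, 31, 2]))
# a batch of sequences
print(model.convert_ids_to_tokens([[0, 33600, 31, 2], [0, 8999, 2]]))
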
13 changes: 13 additions & 0 deletions xinference/core/model.py
@@ -787,6 +787,19 @@ async def create_embedding(self, input: Union[str, List[str]], *args, **kwargs):
f"Model {self._model.model_spec} is not for creating embedding."
)

@request_limit
@log_async(logger=logger)
async def convert_ids_to_tokens(
self, input: Union[List, List[List]], *args, **kwargs
):
kwargs.pop("request_id", None)
if hasattr(self._model, "convert_ids_to_tokens"):
return await self._call_wrapper_json(
self._model.convert_ids_to_tokens, input, *args, **kwargs
)

raise AttributeError(f"Model {self._model.model_spec} cannot convert token ids.")

@request_limit
@log_async(logger=logger)
async def rerank(
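The actor only forwards the call when the underlying model exposes convert_ids_to_tokens, so the contract is duck-typed. A minimal sketch of a model object that would pass the hasattr check above, assuming a HuggingFace-style tokenizer (names are illustrative):

class SomeEmbeddingModel:
    def __init__(self, tokenizer):
        self._tokenizer = tokenizer  # e.g. a transformers tokenizer

    def convert_ids_to_tokens(self, input, **kwargs):
        # delegate to the tokenizer's id-to-token mapping
        return self._tokenizer.convert_ids_to_tokens(input)
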
151 changes: 133 additions & 18 deletions xinference/model/embedding/core.py
@@ -193,6 +193,27 @@ def to(self, *args, **kwargs):
device=self._device,
model_kwargs=model_kwargs,
)
elif (
self._kwargs.get("hybrid_mode")
and "m3" in self._model_spec.model_name.lower()
):
try:
from FlagEmbedding import BGEM3FlagModel
except ImportError:
error_message = "Failed to import module 'BGEM3FlagModel'"
installation_guide = [
"Please make sure 'FlagEmbedding' is installed. ",
"You can install it by `pip install FlagEmbedding`\n",
]
raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")

model_kwargs = {"torch_dtype": torch_dtype} if torch_dtype else None
self._model = BGEM3FlagModel(
self._model_path,
device=self._device,
model_kwargs=model_kwargs,
trust_remote_code=True,
)
else:
model_kwargs = {"torch_dtype": torch_dtype} if torch_dtype else None
self._model = SentenceTransformer(
@@ -203,14 +224,19 @@ def to(self, *args, **kwargs):
)

def create_embedding(self, sentences: Union[str, List[str]], **kwargs):
from FlagEmbedding import BGEM3FlagModel
from sentence_transformers import SentenceTransformer

kwargs.setdefault("normalize_embeddings", True)

if kwargs.get("return_sparse") and "m3" in self._model_spec.model_name.lower():
self._kwargs["hybrid_mode"] = True
Contributor: This looks a bit disruptive to the design; I don't know if there is a more elegant way.

Author: Are you referring only to the if check, or to the subsequent reload part?

Contributor: I mean the reload part.

Contributor: How about loading bge-m3 when hybrid_mode=True is specified? This could be done in load().
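
A minimal sketch of that suggestion, moving the branch into load() so no reload happens at embedding time (this assumes hybrid_mode is passed through self._kwargs at launch; arguments are simplified):

# inside EmbeddingModel.load(), instead of reloading in create_embedding
if self._kwargs.get("hybrid_mode") and "m3" in self._model_spec.model_name.lower():
    from FlagEmbedding import BGEM3FlagModel
    self._model = BGEM3FlagModel(self._model_path, device=self._device)
else:
    from sentence_transformers import SentenceTransformer
    self._model = SentenceTransformer(self._model_path, device=self._device)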

self.load()

# copied from sentence-transformers, and modify it to return tokens num
@no_type_check
def encode(
model: SentenceTransformer,
model: Union[SentenceTransformer, BGEM3FlagModel],
sentences: Union[str, List[str]],
prompt_name: Optional[str] = None,
prompt: Optional[str] = None,
@@ -242,7 +268,8 @@ def encode(
from sentence_transformers.util import batch_to_device
from tqdm.autonotebook import trange

model.eval()
if not isinstance(model, BGEM3FlagModel):
model.eval()
if show_progress_bar is None:
show_progress_bar = (
logger.getEffectiveLevel() == logging.INFO
@@ -271,7 +298,10 @@ def encode(
raise ValueError(
f"Prompt name '{prompt_name}' not found in the configured prompts dictionary with keys {list(model.prompts.keys())!r}."
)
elif model.default_prompt_name is not None:
elif (
not isinstance(model, BGEM3FlagModel)
and model.default_prompt_name is not None
):
prompt = model.prompts.get(model.default_prompt_name, None)
else:
if prompt_name is not None:
@@ -293,7 +323,11 @@
)

if device is None:
device = model._target_device
# same as SentenceTransformer.py
from sentence_transformers.util import get_device_name

device = get_device_name()
logger.info(f"Use pytorch device_name: {device}")

if (
"gte" in self._model_spec.model_name.lower()
@@ -303,9 +337,23 @@

all_embeddings = []
all_token_nums = 0
length_sorted_idx = np.argsort(
[-model._text_length(sen) for sen in sentences]
)

# SentenceTransformer._text_length is not available on other inference engines, so re-implement it here
def _text_length(text):
if isinstance(text, dict): # {key: value} case
return len(next(iter(text.values())))
elif not hasattr(text, "__len__"): # Object has no len() method
return 1
elif len(text) == 0 or isinstance(
text[0], int
): # Empty string or list of ints
return len(text)
else:
return sum(
[len(t) for t in text]
) # Sum of length of individual strings
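# Illustrative behavior of the helper above (not part of the original code):
#   _text_length("hello")         -> 5  (sum of per-character lengths)
#   _text_length([1, 2, 3])       -> 3  (a single tokenized sequence)
#   _text_length({"ids": [1, 2]}) -> 2  (length of the first dict value)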

length_sorted_idx = np.argsort([-_text_length(sen) for sen in sentences])
sentences_sorted = [sentences[idx] for idx in length_sorted_idx]

for start_index in trange(
@@ -318,15 +366,21 @@
sentences_batch = sentences_sorted[
start_index : start_index + batch_size
]
features = model.tokenize(sentences_batch)
features = batch_to_device(features, device)
features.update(extra_features)
# when batching, the attention mask 1 means there is a token
# thus we just sum up it to get the total number of tokens
all_token_nums += features["attention_mask"].sum().item()
if not isinstance(model, BGEM3FlagModel):
features = model.tokenize(sentences_batch)
features = batch_to_device(features, device)
features.update(extra_features)
# when batching, the attention mask 1 means there is a token
# thus we just sum up it to get the total number of tokens
all_token_nums += features["attention_mask"].sum().item()

with torch.no_grad():
out_features = model.forward(features, **kwargs)
# In hybrid mode with return_sparse=True, return the sparse embedding.
# Only the bge-m3 model is supported for now.
if isinstance(model, BGEM3FlagModel):
out_features = model.encode(sentences_batch, **kwargs)
else:
out_features = model.forward(features, **kwargs)

if output_value == "token_embeddings":
embeddings = []
@@ -348,6 +402,16 @@
for name in out_features
}
embeddings.append(row)
# for sparse embedding
elif output_value == "sentence_embedding" and isinstance(
model, BGEM3FlagModel
):
if kwargs.get("return_sparse"):
embeddings = out_features["lexical_weights"]
else:
embeddings = out_features["dense_vecs"]
# BGEM3FlagModel outputs are numpy arrays or dicts, which have no .cpu()
if convert_to_numpy and hasattr(embeddings, "cpu"):
embeddings = embeddings.cpu()
else: # Sentence embeddings
embeddings = out_features[output_value]
embeddings = embeddings.detach()
@@ -401,14 +465,30 @@ def encode(
all_embeddings = [all_embeddings]
embedding_list = []
for index, data in enumerate(all_embeddings):
embedding_list.append(
EmbeddingData(index=index, object="embedding", embedding=data.tolist())
)
if kwargs.get("return_sparse") and isinstance(self._model, BGEM3FlagModel):
embedding_list.append(
EmbeddingData(
index=index,
object="embedding",
embedding={k: float(v) for k, v in data.items()},
)
)
else:
embedding_list.append(
EmbeddingData(
index=index, object="embedding", embedding=data.tolist()
)
)
usage = EmbeddingUsage(
prompt_tokens=all_token_nums, total_tokens=all_token_nums
)
result = Embedding(
object="list",
object=(
"list" # type: ignore
if not isinstance(self._model, BGEM3FlagModel)
and not kwargs.get("return_sparse")
else "dict"
),
model=self._model_uid,
data=embedding_list,
usage=usage,
@@ -430,6 +510,41 @@

return result

def convert_ids_to_tokens(
self,
batch_token_ids: Union[List[Union[int, str]], List[List[Union[int, str]]]],
**kwargs,
) -> List[str]:
batch_decoded_texts: List[str] = []

if self._model is None:
self.load()
assert self._model is not None

if isinstance(batch_token_ids, (int, str)):
return self._model.tokenizer.convert_ids_to_tokens(
[int(str(batch_token_ids))]
)[0]

# check if it's a nested list
if (
isinstance(batch_token_ids, list)
and batch_token_ids
and isinstance(batch_token_ids[0], list)
):
for token_ids in batch_token_ids:
token_ids = [int(token_id) for token_id in token_ids]
batch_decoded_texts.append(
self._model.tokenizer.convert_ids_to_tokens(token_ids)
)
else:
batch_token_ids = [int(token_id) for token_id in batch_token_ids]
batch_decoded_texts = self._model.tokenizer.convert_ids_to_tokens(
batch_token_ids
)
return batch_decoded_texts


def match_embedding(
model_name: str,
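Taken together, a hedged end-to-end sketch of the sparse path (endpoint, model UID, and the exact response shape are assumptions based on the diff above):

from xinference.client import Client

client = Client("http://127.0.0.1:9997")  # assumed endpoint
model = client.get_model("bge-m3")  # placeholder model UID

# sparse mode returns a {token_id: weight} mapping per input
emb = model.create_embedding("What is BGE M3?", return_sparse=True)
sparse = emb["data"][0]["embedding"]

# map the token IDs back to readable tokens via the new endpoint
tokens = model.convert_ids_to_tokens(list(sparse.keys()))
print(dict(zip(tokens, sparse.values())))
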
3 changes: 2 additions & 1 deletion xinference/types.py
@@ -71,7 +71,8 @@ class EmbeddingUsage(TypedDict):
class EmbeddingData(TypedDict):
index: int
object: str
embedding: List[float]
# support sparse embedding
embedding: Union[List[float], Dict[str, float]]


class Embedding(TypedDict):
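For illustration, both embedding shapes now type-check against EmbeddingData (values are made up):

from xinference.types import EmbeddingData

dense: EmbeddingData = {"index": 0, "object": "embedding", "embedding": [0.12, -0.34]}
sparse: EmbeddingData = {"index": 0, "object": "embedding", "embedding": {"33600": 0.25, "8999": 0.18}}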