From a2aedfb3acf7678bcfa4210699e0281ac4bdc033 Mon Sep 17 00:00:00 2001
From: pengjunfeng11 <179464367@qq.com>
Date: Sun, 10 Nov 2024 18:46:19 +0800
Subject: [PATCH 1/7] Support generating sparse vectors with the bge-m3 model
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 xinference/model/embedding/core.py | 86 ++++++++++++++++++++++++------
 xinference/types.py                |  3 +-
 2 files changed, 72 insertions(+), 17 deletions(-)

diff --git a/xinference/model/embedding/core.py b/xinference/model/embedding/core.py
index ae66b945b2..70ab51ab53 100644
--- a/xinference/model/embedding/core.py
+++ b/xinference/model/embedding/core.py
@@ -193,6 +193,24 @@ def to(self, *args, **kwargs):
                 device=self._device,
                 model_kwargs=model_kwargs,
             )
+        elif self._kwargs.get("hybird_mode") and "m3" in self._model_spec.model_name.lower():
+            try:
+                from FlagEmbedding import BGEM3FlagModel
+            except ImportError:
+                error_message = "Failed to import module 'BGEM3FlagModel'"
+                installation_guide = [
+                    "Please make sure 'FlagEmbedding' is installed. ",
+                    "You can install it by `pip install FlagEmbedding`\n",
+                ]
+                raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+
+            model_kwargs = {"torch_dtype": torch_dtype} if torch_dtype else None
+            self._model = BGEM3FlagModel(
+                self._model_path,
+                device=self._device,
+                model_kwargs=model_kwargs,
+                trust_remote_code=True,
+            )
         else:
             model_kwargs = {"torch_dtype": torch_dtype} if torch_dtype else None
             self._model = SentenceTransformer(
@@ -204,13 +222,14 @@ def to(self, *args, **kwargs):
         )

     def create_embedding(self, sentences: Union[str, List[str]], **kwargs):
         from sentence_transformers import SentenceTransformer
+        from FlagEmbedding import BGEM3FlagModel

         kwargs.setdefault("normalize_embeddings", True)

         # copied from sentence-transformers, and modify it to return tokens num
         @no_type_check
         def encode(
-            model: SentenceTransformer,
+            model: Union[SentenceTransformer, BGEM3FlagModel],
             sentences: Union[str, List[str]],
             prompt_name: Optional[str] = None,
             prompt: Optional[str] = None,
@@ -242,7 +261,8 @@ def encode(
             from sentence_transformers.util import batch_to_device
             from tqdm.autonotebook import trange

-            model.eval()
+            if not isinstance(model, BGEM3FlagModel):
+                model.eval()
             if show_progress_bar is None:
                 show_progress_bar = (
                     logger.getEffectiveLevel() == logging.INFO
@@ -271,7 +291,7 @@ def encode(
                     raise ValueError(
                         f"Prompt name '{prompt_name}' not found in the configured prompts dictionary with keys {list(model.prompts.keys())!r}."
) - elif model.default_prompt_name is not None: + elif not isinstance(model, BGEM3FlagModel) and model.default_prompt_name is not None: prompt = model.prompts.get(model.default_prompt_name, None) else: if prompt_name is not None: @@ -293,7 +313,10 @@ def encode( ) if device is None: - device = model._target_device + # same as SentenceTransformer.py + from sentence_transformers.util import get_device_name + device = get_device_name() + logger.info(f"Use pytorch device_name: {device}") if ( "gte" in self._model_spec.model_name.lower() @@ -303,8 +326,20 @@ def encode( all_embeddings = [] all_token_nums = 0 + + # 原有的写法不支持其他推理引擎 + def _text_length(text): + if isinstance(text, dict): # {key: value} case + return len(next(iter(text.values()))) + elif not hasattr(text, "__len__"): # Object has no len() method + return 1 + elif len(text) == 0 or isinstance(text[0], int): # Empty string or list of ints + return len(text) + else: + return sum([len(t) for t in text]) # Sum of length of individual strings + length_sorted_idx = np.argsort( - [-model._text_length(sen) for sen in sentences] + [-_text_length(sen) for sen in sentences] ) sentences_sorted = [sentences[idx] for idx in length_sorted_idx] @@ -318,15 +353,21 @@ def encode( sentences_batch = sentences_sorted[ start_index : start_index + batch_size ] - features = model.tokenize(sentences_batch) - features = batch_to_device(features, device) - features.update(extra_features) - # when batching, the attention mask 1 means there is a token - # thus we just sum up it to get the total number of tokens - all_token_nums += features["attention_mask"].sum().item() + if not isinstance(model, BGEM3FlagModel): + features = model.tokenize(sentences_batch) + features = batch_to_device(features, device) + features.update(extra_features) + # when batching, the attention mask 1 means there is a token + # thus we just sum up it to get the total number of tokens + all_token_nums += features["attention_mask"].sum().item() with torch.no_grad(): - out_features = model.forward(features, **kwargs) + # if use hybird mode and setting return_sparse==true, return sparse embedding + # only support bge-m3 model now + if isinstance(model, BGEM3FlagModel): + out_features = model.encode(sentences_batch, **kwargs) + else: + out_features = model.forward(features, **kwargs) if output_value == "token_embeddings": embeddings = [] @@ -348,6 +389,14 @@ def encode( for name in out_features } embeddings.append(row) + # for sparse embedding + elif output_value == "sentence_embedding" and isinstance(model, BGEM3FlagModel): + if kwargs.get("return_sparse"): + embeddings = out_features['lexical_weights'] + else: + embeddings = out_features['dense_vecs'] + if convert_to_numpy: + embeddings = embeddings.cpu() else: # Sentence embeddings embeddings = out_features[output_value] embeddings = embeddings.detach() @@ -401,14 +450,19 @@ def encode( all_embeddings = [all_embeddings] embedding_list = [] for index, data in enumerate(all_embeddings): - embedding_list.append( - EmbeddingData(index=index, object="embedding", embedding=data.tolist()) - ) + if kwargs.get("return_sparse") and isinstance(self._model, BGEM3FlagModel): + embedding_list.append( + EmbeddingData(index=index, object="embedding", embedding={k: float(v) for k, v in data.items()}) + ) + else: + embedding_list.append( + EmbeddingData(index=index, object="embedding", embedding=data.tolist()) + ) usage = EmbeddingUsage( prompt_tokens=all_token_nums, total_tokens=all_token_nums ) result = Embedding( - object="list", + object="list" if not 
isinstance(self._model, BGEM3FlagModel) and not kwargs.get("return_sparse") else "dict",
             model=self._model_uid,
             data=embedding_list,
             usage=usage,
diff --git a/xinference/types.py b/xinference/types.py
index 613d8709bb..27895edcd4 100644
--- a/xinference/types.py
+++ b/xinference/types.py
@@ -71,7 +71,8 @@ class EmbeddingUsage(TypedDict):
 class EmbeddingData(TypedDict):
     index: int
     object: str
-    embedding: List[float]
+    # support sparse embedding
+    embedding: List[float] | Dict[str, float]


 class Embedding(TypedDict):

From 46196d04e95a95f6245cc2b08add54c96c37dbbe Mon Sep 17 00:00:00 2001
From: pengjunfeng11 <179464367@qq.com>
Date: Mon, 11 Nov 2024 15:50:19 +0800
Subject: [PATCH 2/7] Add a convert_ids_to_tokens method to embedding models
 for converting token ids back to tokens
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 xinference/api/restful_api.py               | 44 +++++++++++++++++++++
 xinference/client/restful/restful_client.py | 35 ++++++++++++++++
 xinference/core/model.py                    | 13 ++++++
 xinference/model/embedding/core.py          | 17 ++++++++
 4 files changed, 109 insertions(+)

diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py
index ed3a2eab90..8780c4bd01 100644
--- a/xinference/api/restful_api.py
+++ b/xinference/api/restful_api.py
@@ -484,6 +484,16 @@ async def internal_exception_handler(request: Request, exc: Exception):
                 else None
             ),
         )
+        self._router.add_api_route(
+            "/v1/convert_ids_to_tokens",
+            self.convert_ids_to_tokens,
+            methods=["POST"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:read"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
         self._router.add_api_route(
             "/v1/rerank",
             self.rerank,
@@ -1306,6 +1316,40 @@ async def create_embedding(self, request: Request) -> Response:
             logger.error(e, exc_info=True)
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
+    async def convert_ids_to_tokens(self, request: Request) -> Response:
+        payload = await request.json()
+        body = CreateEmbeddingRequest.parse_obj(payload)
+        model_uid = body.model
+        exclude = {
+            "model",
+            "input",
+            "user",
+        }
+        kwargs = {key: value for key, value in payload.items() if key not in exclude}
+
+        try:
+            model = await (await self._get_supervisor_ref()).get_model(model_uid)
+        except ValueError as ve:
+            logger.error(str(ve), exc_info=True)
+            await self._report_error_event(model_uid, str(ve))
+            raise HTTPException(status_code=400, detail=str(ve))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
+        try:
+            decoded_texts = await model.convert_ids_to_tokens(body.input, **kwargs)
+            return Response(decoded_texts, media_type="application/json")
+        except RuntimeError as re:
+            logger.error(re, exc_info=True)
+            await self._report_error_event(model_uid, str(re))
+            self.handle_request_limit_error(re)
+            raise HTTPException(status_code=400, detail=str(re))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))

     async def rerank(self, request: Request) -> Response:
         payload = await request.json()
diff --git a/xinference/client/restful/restful_client.py b/xinference/client/restful/restful_client.py
index dd5e3f1146..77798b18c0 100644
--- 
a/xinference/client/restful/restful_client.py +++ b/xinference/client/restful/restful_client.py @@ -125,6 +125,41 @@ def create_embedding(self, input: Union[str, List[str]], **kwargs) -> "Embedding response_data = response.json() return response_data + + def convert_ids_to_tokens(self, input: Union[List, List[List]], **kwargs) -> List[str]: + """ + Convert token IDs to human readable tokens via RESTful APIs. + + Parameters + ---------- + input: Union[List, List[List]] + Input token IDs to convert, can be a single list of token IDs or a list of token ID lists. + To convert multiple sequences in a single request, pass a list of token ID lists. + + Returns + ------- + list + A list of decoded tokens in human readable format. + + Raises + ------ + RuntimeError + Report the failure of token conversion and provide the error message. + + """ + url = f"{self._base_url}/v1/convert_ids_to_tokens" + request_body = { + "model": self._model_uid, + "input": input, + } + request_body.update(kwargs) + response = requests.post(url, json=request_body, headers=self.auth_headers) + if response.status_code != 200: + raise RuntimeError( + f"Failed to decode token ids, detail: {_get_error_string(response)}" + ) + response_data = response.json() + return response_data class RESTfulRerankModelHandle(RESTfulModelHandle): diff --git a/xinference/core/model.py b/xinference/core/model.py index e911c71e6d..4f5df50ca8 100644 --- a/xinference/core/model.py +++ b/xinference/core/model.py @@ -786,6 +786,19 @@ async def create_embedding(self, input: Union[str, List[str]], *args, **kwargs): raise AttributeError( f"Model {self._model.model_spec} is not for creating embedding." ) + + @request_limit + @log_async(logger=logger) + async def convert_ids_to_tokens(self, input: Union[List, List[List]], *args, **kwargs): + kwargs.pop("request_id", None) + if hasattr(self._model, "convert_ids_to_tokens"): + return await self._call_wrapper_json( + self._model.convert_ids_to_tokens, input, *args, **kwargs + ) + + raise AttributeError( + f"Model {self._model.model_spec} can convert token id." 
+        )

     @request_limit
     @log_async(logger=logger)
diff --git a/xinference/model/embedding/core.py b/xinference/model/embedding/core.py
index 70ab51ab53..731995c111 100644
--- a/xinference/model/embedding/core.py
+++ b/xinference/model/embedding/core.py
@@ -484,6 +484,23 @@ def _text_length(text):

         return result

+    def convert_ids_to_tokens(self, token_ids: Union[List, List[List]], **kwargs):
+        from FlagEmbedding import BGEM3FlagModel
+        from sentence_transformers import SentenceTransformer
+
+        decoded_texts = []
+        if isinstance(token_ids):
+            for idx, snetence_token_ids in enumerate(token_ids):
+                decoded_texts.append(self._model.tokenizer.decode(snetence_token_ids))
+        else:
+            decoded_texts = self._model.tokenizer.decode(token_ids)
+        # if isinstance(self._model, BGEM3FlagModel):
+        #     pass
+
+        # if isinstance(self._model, SentenceTransformer):
+        #     sentence = self._model.tokenizer.decode(token_ids)
+

 def match_embedding(
     model_name: str,

From 94b0cb70c777cb6d553e34bd0bf9c1d65ddf69b7 Mon Sep 17 00:00:00 2001
From: pengjunfeng11 <179464367@qq.com>
Date: Mon, 11 Nov 2024 17:12:31 +0800
Subject: [PATCH 3/7] Decide whether to call FlagEmbedding based on the
 arguments when create_embedding is called
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 xinference/model/embedding/core.py | 100 ++++++++++++++++++++---------
 1 file changed, 71 insertions(+), 29 deletions(-)

diff --git a/xinference/model/embedding/core.py b/xinference/model/embedding/core.py
index 731995c111..fc772ad91a 100644
--- a/xinference/model/embedding/core.py
+++ b/xinference/model/embedding/core.py
@@ -193,7 +193,10 @@ def to(self, *args, **kwargs):
             device=self._device,
             model_kwargs=model_kwargs,
         )
-        elif self._kwargs.get("hybird_mode") and "m3" in self._model_spec.model_name.lower():
+        elif (
+            self._kwargs.get("hybird_mode")
+            and "m3" in self._model_spec.model_name.lower()
+        ):
             try:
                 from FlagEmbedding import BGEM3FlagModel
             except ImportError:
@@ -203,7 +206,7 @@ def to(self, *args, **kwargs):
                 "You can install it by `pip install FlagEmbedding`\n",
             ]
             raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
-
+
             model_kwargs = {"torch_dtype": torch_dtype} if torch_dtype else None
             self._model = BGEM3FlagModel(
                 self._model_path,
@@ -221,11 +224,15 @@ def to(self, *args, **kwargs):
         )

     def create_embedding(self, sentences: Union[str, List[str]], **kwargs):
-        from sentence_transformers import SentenceTransformer
         from FlagEmbedding import BGEM3FlagModel
+        from sentence_transformers import SentenceTransformer

         kwargs.setdefault("normalize_embeddings", True)

+        if kwargs.get("return_sparse") and "m3" in self._model_spec.model_name.lower():
+            self._kwargs["return_sparse"] = True
+            self.load()
+
         # copied from sentence-transformers, and modify it to return tokens num
         @no_type_check
         def encode(
@@ -291,7 +298,10 @@ def encode(
                     raise ValueError(
                         f"Prompt name '{prompt_name}' not found in the configured prompts dictionary with keys {list(model.prompts.keys())!r}."
) - elif not isinstance(model, BGEM3FlagModel) and model.default_prompt_name is not None: + elif ( + not isinstance(model, BGEM3FlagModel) + and model.default_prompt_name is not None + ): prompt = model.prompts.get(model.default_prompt_name, None) else: if prompt_name is not None: @@ -315,6 +325,7 @@ def encode( if device is None: # same as SentenceTransformer.py from sentence_transformers.util import get_device_name + device = get_device_name() logger.info(f"Use pytorch device_name: {device}") @@ -333,14 +344,16 @@ def _text_length(text): return len(next(iter(text.values()))) elif not hasattr(text, "__len__"): # Object has no len() method return 1 - elif len(text) == 0 or isinstance(text[0], int): # Empty string or list of ints + elif len(text) == 0 or isinstance( + text[0], int + ): # Empty string or list of ints return len(text) else: - return sum([len(t) for t in text]) # Sum of length of individual strings - - length_sorted_idx = np.argsort( - [-_text_length(sen) for sen in sentences] - ) + return sum( + [len(t) for t in text] + ) # Sum of length of individual strings + + length_sorted_idx = np.argsort([-_text_length(sen) for sen in sentences]) sentences_sorted = [sentences[idx] for idx in length_sorted_idx] for start_index in trange( @@ -390,11 +403,13 @@ def _text_length(text): } embeddings.append(row) # for sparse embedding - elif output_value == "sentence_embedding" and isinstance(model, BGEM3FlagModel): + elif output_value == "sentence_embedding" and isinstance( + model, BGEM3FlagModel + ): if kwargs.get("return_sparse"): - embeddings = out_features['lexical_weights'] + embeddings = out_features["lexical_weights"] else: - embeddings = out_features['dense_vecs'] + embeddings = out_features["dense_vecs"] if convert_to_numpy: embeddings = embeddings.cpu() else: # Sentence embeddings @@ -452,17 +467,26 @@ def _text_length(text): for index, data in enumerate(all_embeddings): if kwargs.get("return_sparse") and isinstance(self._model, BGEM3FlagModel): embedding_list.append( - EmbeddingData(index=index, object="embedding", embedding={k: float(v) for k, v in data.items()}) + EmbeddingData( + index=index, + object="embedding", + embedding={k: float(v) for k, v in data.items()}, + ) ) else: embedding_list.append( - EmbeddingData(index=index, object="embedding", embedding=data.tolist()) + EmbeddingData( + index=index, object="embedding", embedding=data.tolist() + ) ) usage = EmbeddingUsage( prompt_tokens=all_token_nums, total_tokens=all_token_nums ) result = Embedding( - object="list" if not isinstance(self._model, BGEM3FlagModel) and not kwargs.get("return_sparse") else "dict", + object="list" + if not isinstance(self._model, BGEM3FlagModel) + and not kwargs.get("return_sparse") + else "dict", model=self._model_uid, data=embedding_list, usage=usage, @@ -484,22 +508,40 @@ def _text_length(text): return result - def convert_ids_to_tokens(self, token_ids: Union[List, List[List]], **kwargs): - from FlagEmbedding import BGEM3FlagModel - from sentence_transformers import SentenceTransformer + def convert_ids_to_tokens( + self, + batch_token_ids: Union[List[Union[int, str]], List[List[Union[int, str]]]], + **kwargs, + ) -> Union[List[str]]: + batch_decoded_texts: List[str] = [] - decoded_texts = [] - if isinstance(token_ids): - for idx, snetence_token_ids in enumerate(token_ids): - decoded_texts.append(self._model.tokenizer.decode(snetence_token_ids)) - else: - decoded_texts = self._model.tokenizer.decode(token_ids) - # if isinstance(self._model, BGEM3FlagModel): - # pass + if self._model is None: 
+ self.load() + assert self._model is not None - # if isinstance(self._model, SentenceTransformer): - # sentence = self._model.tokenizer.decode(token_ids) + if isinstance(batch_token_ids, (int, str)): + return self._model.tokenizer.convert_ids_to_tokens( + [int(str(batch_token_ids))] + )[0] + # check if it's a nested list + if ( + isinstance(batch_token_ids, list) + and batch_token_ids + and isinstance(batch_token_ids[0], list) + ): + for token_ids in batch_token_ids: + token_ids = [int(token_id) for token_id in token_ids] + batch_decoded_texts.append( + self._model.tokenizer.convert_ids_to_tokens(token_ids) + ) + else: + batch_token_ids = [int(token_id) for token_id in batch_token_ids] + batch_decoded_texts = self._model.tokenizer.convert_ids_to_tokens( + batch_token_ids + ) + print(batch_decoded_texts) + return batch_decoded_texts def match_embedding( From fa8ae35eb462a8da8f2799f854b537fdc6f82e64 Mon Sep 17 00:00:00 2001 From: pengjunfeng11 <179464367@qq.com> Date: Mon, 11 Nov 2024 17:24:34 +0800 Subject: [PATCH 4/7] sparse vector support --- xinference/api/restful_api.py | 1 + xinference/client/restful/restful_client.py | 6 ++++-- xinference/core/model.py | 10 +++++----- xinference/model/embedding/core.py | 10 ++++++---- xinference/types.py | 2 +- 5 files changed, 17 insertions(+), 12 deletions(-) diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py index 8780c4bd01..b7c732af3b 100644 --- a/xinference/api/restful_api.py +++ b/xinference/api/restful_api.py @@ -1316,6 +1316,7 @@ async def create_embedding(self, request: Request) -> Response: logger.error(e, exc_info=True) await self._report_error_event(model_uid, str(e)) raise HTTPException(status_code=500, detail=str(e)) + async def convert_ids_to_tokens(self, request: Request) -> Response: payload = await request.json() body = CreateEmbeddingRequest.parse_obj(payload) diff --git a/xinference/client/restful/restful_client.py b/xinference/client/restful/restful_client.py index 77798b18c0..bc2aa11fcd 100644 --- a/xinference/client/restful/restful_client.py +++ b/xinference/client/restful/restful_client.py @@ -125,8 +125,10 @@ def create_embedding(self, input: Union[str, List[str]], **kwargs) -> "Embedding response_data = response.json() return response_data - - def convert_ids_to_tokens(self, input: Union[List, List[List]], **kwargs) -> List[str]: + + def convert_ids_to_tokens( + self, input: Union[List, List[List]], **kwargs + ) -> List[str]: """ Convert token IDs to human readable tokens via RESTful APIs. diff --git a/xinference/core/model.py b/xinference/core/model.py index 4f5df50ca8..af0c396fda 100644 --- a/xinference/core/model.py +++ b/xinference/core/model.py @@ -786,19 +786,19 @@ async def create_embedding(self, input: Union[str, List[str]], *args, **kwargs): raise AttributeError( f"Model {self._model.model_spec} is not for creating embedding." ) - + @request_limit @log_async(logger=logger) - async def convert_ids_to_tokens(self, input: Union[List, List[List]], *args, **kwargs): + async def convert_ids_to_tokens( + self, input: Union[List, List[List]], *args, **kwargs + ): kwargs.pop("request_id", None) if hasattr(self._model, "convert_ids_to_tokens"): return await self._call_wrapper_json( self._model.convert_ids_to_tokens, input, *args, **kwargs ) - raise AttributeError( - f"Model {self._model.model_spec} can convert token id." 
-        )
+        raise AttributeError(f"Model {self._model.model_spec} does not support converting token ids.")

     @request_limit
     @log_async(logger=logger)
diff --git a/xinference/model/embedding/core.py b/xinference/model/embedding/core.py
index fc772ad91a..c72194a25e 100644
--- a/xinference/model/embedding/core.py
+++ b/xinference/model/embedding/core.py
@@ -483,10 +483,12 @@ def _text_length(text):
             prompt_tokens=all_token_nums, total_tokens=all_token_nums
         )
         result = Embedding(
-            object="list"
-            if not isinstance(self._model, BGEM3FlagModel)
-            and not kwargs.get("return_sparse")
-            else "dict",
+            object=(
+                "list"  # type: ignore
+                if not isinstance(self._model, BGEM3FlagModel)
+                and not kwargs.get("return_sparse")
+                else "dict"
+            ),
             model=self._model_uid,
             data=embedding_list,
             usage=usage,
diff --git a/xinference/types.py b/xinference/types.py
index 27895edcd4..759cf0b7c4 100644
--- a/xinference/types.py
+++ b/xinference/types.py
@@ -72,7 +72,7 @@ class EmbeddingData(TypedDict):
     index: int
     object: str
     # support sparse embedding
-    embedding: List[float] | Dict[str, float]
+    embedding: Union[List[float], Dict[str, float]]


 class Embedding(TypedDict):

From da0395888b200d7c55d016efd6fdb9f436629c83 Mon Sep 17 00:00:00 2001
From: pengjunfeng11 <179464367@qq.com>
Date: Mon, 11 Nov 2024 20:43:21 +0800
Subject: [PATCH 5/7] FEATURE: bge-m3 embedding model generate sparse vector
 support

---
 xinference/model/embedding/core.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/xinference/model/embedding/core.py b/xinference/model/embedding/core.py
index c72194a25e..3feb35fa41 100644
--- a/xinference/model/embedding/core.py
+++ b/xinference/model/embedding/core.py
@@ -194,7 +194,7 @@ def to(self, *args, **kwargs):
             model_kwargs=model_kwargs,
         )
         elif (
-            self._kwargs.get("hybird_mode")
+            self._kwargs.get("hybrid_mode")
             and "m3" in self._model_spec.model_name.lower()
         ):
             try:
@@ -230,7 +230,7 @@ def create_embedding(self, sentences: Union[str, List[str]], **kwargs):
         kwargs.setdefault("normalize_embeddings", True)

         if kwargs.get("return_sparse") and "m3" in self._model_spec.model_name.lower():
-            self._kwargs["return_sparse"] = True
+            self._kwargs["hybrid_mode"] = True
             self.load()

         # copied from sentence-transformers, and modify it to return tokens num
@@ -375,7 +375,7 @@ def _text_length(text):
                 with torch.no_grad():
-                    # if use hybird mode and setting return_sparse==true, return sparse embedding
+                    # if use hybrid mode and setting return_sparse==true, return sparse embedding
                     # only support bge-m3 model now
                     if isinstance(model, BGEM3FlagModel):
                         out_features = model.encode(sentences_batch, **kwargs)

From 78735d0c011a728c208b900b7726e1b9e2338263 Mon Sep 17 00:00:00 2001
From: pengjunfeng11 <34857167+pengjunfeng11@users.noreply.github.com>
Date: Fri, 15 Nov 2024 13:31:02 +0800
Subject: [PATCH 6/7] Update core.py

---
 xinference/model/embedding/core.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/xinference/model/embedding/core.py b/xinference/model/embedding/core.py
index 3feb35fa41..7868793ea0 100644
--- a/xinference/model/embedding/core.py
+++ b/xinference/model/embedding/core.py
@@ -338,7 +338,7 @@ def encode(
             all_embeddings = []
             all_token_nums = 0

-            # 原有的写法不支持其他推理引擎
+            # The original code does not support other inference engines
             def _text_length(text):
                 if isinstance(text, dict):  # {key: value} case
                     return len(next(iter(text.values())))

From 7e113bb54be4866a8df0942814515390d0994a4d Mon Sep 17 00:00:00 2001
From: pengjunfeng11 
<34857167+pengjunfeng11@users.noreply.github.com> Date: Fri, 15 Nov 2024 14:33:43 +0800 Subject: [PATCH 7/7] Update core.py --- xinference/model/embedding/core.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xinference/model/embedding/core.py b/xinference/model/embedding/core.py index 7868793ea0..34a6ebe7af 100644 --- a/xinference/model/embedding/core.py +++ b/xinference/model/embedding/core.py @@ -542,7 +542,6 @@ def convert_ids_to_tokens( batch_decoded_texts = self._model.tokenizer.convert_ids_to_tokens( batch_token_ids ) - print(batch_decoded_texts) return batch_decoded_texts
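
---

For reference, a minimal end-to-end usage sketch of what this series adds. It is not itself one of the patches; the server URL and model uid are assumptions and will differ per deployment. Per the code above, a dense embedding keeps `embedding` as a list of floats, `return_sparse=True` turns it into a `{token_id: weight}` dict of lexical weights, and the new `/v1/convert_ids_to_tokens` endpoint maps those ids back to readable tokens.

from xinference.client import Client

# Assumed deployment: a running Xinference server and a launched bge-m3
# embedding model; adjust the URL and model uid to your setup.
client = Client("http://localhost:9997")
model = client.get_model("bge-m3")

# Dense embedding (unchanged behaviour): data[0]["embedding"] is List[float].
dense = model.create_embedding("What is BGE M3?")

# Sparse embedding: return_sparse=True reloads the model through
# FlagEmbedding's BGEM3FlagModel, and data[0]["embedding"] becomes a
# {token_id: weight} dict of lexical weights.
sparse = model.create_embedding("What is BGE M3?", return_sparse=True)
lexical_weights = sparse["data"][0]["embedding"]

# Map the sparse token ids back to readable tokens via the new endpoint.
tokens = model.convert_ids_to_tokens(list(lexical_weights.keys()))
print(dict(zip(tokens, lexical_weights.values())))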