ENH: add normalize to rerank model (#2509)
Co-authored-by: libing <[email protected]>
Co-authored-by: codingl2k1 <[email protected]>
3 people authored Nov 15, 2024
1 parent 042eb5b commit 7a0bb60
Showing 4 changed files with 20 additions and 13 deletions.
12 changes: 6 additions & 6 deletions xinference/api/restful_api.py
@@ -115,6 +115,7 @@ class RerankRequest(BaseModel):
     return_documents: Optional[bool] = False
     return_len: Optional[bool] = False
     max_chunks_per_doc: Optional[int] = None
+    kwargs: Optional[str] = None
 
 
 class TextToImageRequest(BaseModel):
@@ -1315,11 +1316,6 @@ async def rerank(self, request: Request) -> Response:
         payload = await request.json()
         body = RerankRequest.parse_obj(payload)
         model_uid = body.model
-        kwargs = {
-            key: value
-            for key, value in payload.items()
-            if key not in RerankRequest.__annotations__.keys()
-        }
 
         try:
             model = await (await self._get_supervisor_ref()).get_model(model_uid)
@@ -1333,14 +1329,18 @@
             raise HTTPException(status_code=500, detail=str(e))
 
         try:
+            if body.kwargs is not None:
+                parsed_kwargs = json.loads(body.kwargs)
+            else:
+                parsed_kwargs = {}
             scores = await model.rerank(
                 body.documents,
                 body.query,
                 top_n=body.top_n,
                 max_chunks_per_doc=body.max_chunks_per_doc,
                 return_documents=body.return_documents,
                 return_len=body.return_len,
-                **kwargs,
+                **parsed_kwargs,
             )
             return Response(scores, media_type="application/json")
         except RuntimeError as re:
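With this change, the server no longer scrapes extra parameters out of unrecognized top-level JSON keys; they arrive in the new `kwargs` field as a JSON-encoded string, which the handler decodes with `json.loads` before forwarding to the model. A minimal sketch of a raw HTTP call, assuming a local Xinference server on port 9997 and an already-launched rerank model with UID `bge-reranker-base` (both assumptions, not part of this commit):

import json

import requests

# Hypothetical host and model UID; adjust to your deployment.
url = "http://127.0.0.1:9997/v1/rerank"
payload = {
    "model": "bge-reranker-base",
    "query": "A man is eating pasta.",
    "documents": [
        "A man is eating food.",
        "The girl is carrying a baby.",
    ],
    # Extra model parameters now travel as a JSON-encoded string.
    "kwargs": json.dumps({"normalize": True}),
}
print(requests.post(url, json=payload).json())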
1 change: 1 addition & 0 deletions xinference/client/restful/restful_client.py
@@ -174,6 +174,7 @@ def rerank(
             "max_chunks_per_doc": max_chunks_per_doc,
             "return_documents": return_documents,
             "return_len": return_len,
+            "kwargs": json.dumps(kwargs),
         }
         request_body.update(kwargs)
         response = requests.post(url, json=request_body, headers=self.auth_headers)
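Callers of the Python client are unaffected: `rerank` still accepts arbitrary keyword arguments, and now also serializes them into the `kwargs` field so the server can recover them without guessing at unknown JSON keys. A usage sketch, assuming a running server and a launched rerank model (the UID here is hypothetical):

from xinference.client import Client

client = Client("http://127.0.0.1:9997")
model = client.get_model("my-rerank-model")  # hypothetical model UID

# `normalize` rides along in the JSON-encoded `kwargs` field.
result = model.rerank(
    ["A man is eating food.", "The girl is carrying a baby."],
    "A man is eating pasta.",
    normalize=True,
)
print(result["results"])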
15 changes: 11 additions & 4 deletions xinference/model/rerank/core.py
@@ -179,6 +179,7 @@ def _auto_detect_type(model_path):
         return rerank_type
 
     def load(self):
+        logger.info("Loading rerank model: %s", self._model_path)
         flash_attn_installed = importlib.util.find_spec("flash_attn") is not None
         if (
             self._auto_detect_type(self._model_path) != "normal"
@@ -189,6 +190,7 @@ def load(self):
                 "will force set `use_fp16` to True"
             )
             self._use_fp16 = True
+
         if self._model_spec.type == "normal":
             try:
                 import sentence_transformers
@@ -250,22 +252,27 @@ def rerank(
         **kwargs,
     ) -> Rerank:
         assert self._model is not None
-        if kwargs:
-            raise ValueError("rerank hasn't support extra parameter.")
         if max_chunks_per_doc is not None:
             raise ValueError("rerank hasn't support `max_chunks_per_doc` parameter.")
+        logger.info("Rerank with kwargs: %s, model: %s", kwargs, self._model)
         sentence_combinations = [[query, doc] for doc in documents]
         # reset n tokens
         self._model.model.n_tokens = 0
         if self._model_spec.type == "normal":
             similarity_scores = self._model.predict(
-                sentence_combinations, convert_to_numpy=False, convert_to_tensor=True
+                sentence_combinations,
+                convert_to_numpy=False,
+                convert_to_tensor=True,
+                **kwargs,
             ).cpu()
             if similarity_scores.dtype == torch.bfloat16:
                 similarity_scores = similarity_scores.float()
         else:
             # Related issue: https://github.com/xorbitsai/inference/issues/1775
-            similarity_scores = self._model.compute_score(sentence_combinations)
+            similarity_scores = self._model.compute_score(
+                sentence_combinations, **kwargs
+            )
+
         if not isinstance(similarity_scores, Sequence):
             similarity_scores = [similarity_scores]
         elif (
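The model layer drops its blanket rejection of extra parameters and instead forwards `**kwargs` verbatim, to `CrossEncoder.predict` on the sentence-transformers path or to `compute_score` on the FlagEmbedding path, so whether a given parameter such as `normalize` is honored depends on the backend. For FlagEmbedding rerankers, `normalize=True` is expected to map raw relevance logits into (0, 1) through a sigmoid; a self-contained sketch of that transform (assumed behavior, illustrated with plain torch rather than the real backend):

import torch

def normalize_scores(raw_scores: torch.Tensor) -> torch.Tensor:
    # Squash unbounded relevance logits into (0, 1), mirroring what
    # `normalize=True` is expected to do in the backend library.
    return torch.sigmoid(raw_scores)

logits = torch.tensor([2.3, -0.7, 5.1])
print(normalize_scores(logits))  # tensor([0.9089, 0.3318, 0.9939])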
5 changes: 2 additions & 3 deletions xinference/model/rerank/tests/test_rerank.py
@@ -118,9 +118,8 @@ def test_restful_api(model_name, setup):
     kwargs = {
         "invalid": "invalid",
     }
-    with pytest.raises(RuntimeError) as err:
-        scores = model.rerank(corpus, query, **kwargs)
-    assert "hasn't support" in str(err.value)
+    with pytest.raises(RuntimeError):
+        model.rerank(corpus, query, **kwargs)


 def test_from_local_uri():
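Since unknown kwargs now reach the backend, the exact error text comes from the underlying library and is no longer stable, so the test keeps only the `RuntimeError` assertion. A sketch of a complementary happy-path test one could add alongside it (the `model` fixture and response shape are assumed from the surrounding test, not part of this commit):

def test_rerank_normalize(model):
    # `model` is assumed to be a launched rerank model, as in test_restful_api.
    corpus = ["A man is eating food.", "The girl is carrying a baby."]
    query = "A man is eating pasta."
    scores = model.rerank(corpus, query, normalize=True)
    for item in scores["results"]:
        # Sigmoid-normalized relevance scores should land in [0, 1].
        assert 0.0 <= item["relevance_score"] <= 1.0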
