Skip to content

Commit

Permalink
Sequence classification task enabling (#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
pi314ever authored Aug 26, 2024
1 parent 4334116 commit 1db605e
Show file tree
Hide file tree
Showing 8 changed files with 20 additions and 24 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ Not all features of TEI are currently supported as this is still a work in progr
| Mistral | Embedding | <li>[intfloat/e5-mistral-7b-instruct](https://huggingface.co/intfloat/e5-mistral-7b-instruct)</li><li>[Salesforce/SFR-Embedding-2_R](https://huggingface.co/Salesforce/SFR-Embedding-2_R)</li> |
| GTE | Embedding | <li>[Alibaba-NLP/gte-large-en-v1.5](https://huggingface.co/Alibaba-NLP/gte-large-en-v1.5)</li> |
| JinaBERT | Embedding | <li>[jinaai/jina-embeddings-v2-base-en](https://huggingface.co/jinaai/jina-embeddings-v2-base-en)</li> |
| Roberta | Sequence Classification | <li>[SamLowe/roberta-base-go_emotions](https://huggingface.co/SamLowe/roberta-base-go_emotions)</li> |

> TEI on Habana Gaudi is distributed under the same license as TEI itself: https://github.com/huggingface/text-embeddings-inference/blob/main/LICENSE
>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from transformers.models.bert import BertConfig
from transformers.models.auto.modeling_auto import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES

from text_embeddings_server.models.model import Model
from text_embeddings_server.models.model import Model, B
from text_embeddings_server.models.default_model import DefaultModel
from text_embeddings_server.models.classification_model import ClassificationModel

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

from text_embeddings_server.models import Model
from text_embeddings_server.models.types import PaddedBatch, Embedding, Score
from text_embeddings_server.models.types import PaddedBatch, Score

tracer = trace.get_tracer(__name__)


class ClassificationModel(Model):
def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
if device == torch.device("hpu"):
Expand Down Expand Up @@ -46,15 +47,17 @@ def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
is not None
)

super(ClassificationModel, self).__init__(model=model, dtype=dtype, device=device)
super(ClassificationModel, self).__init__(
model=model, dtype=dtype, device=device
)

@property
def batch_type(self) -> Type[PaddedBatch]:
return PaddedBatch

@tracer.start_as_current_span("embed")
def embed(self, batch: PaddedBatch) -> List[Embedding]:
pass
def embed(self, batch):
raise NotImplementedError(f"Embed is not a valid operation for model type {self.model.config.model_type}")

@tracer.start_as_current_span("predict")
def predict(self, batch: PaddedBatch) -> List[Score]:
Expand All @@ -65,10 +68,5 @@ def predict(self, batch: PaddedBatch) -> List[Score]:
kwargs["position_ids"] = batch.position_ids

output = self.model(**kwargs, return_dict=True)
scores = output.logits.view(-1, ).tolist()
return [
Score(
values=scores[i:i+1]
)
for i in range(len(batch))
]
all_scores = output.logits.tolist()
return [Score(values=scores) for scores in all_scores]
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

from text_embeddings_server.models import Model
from text_embeddings_server.models.types import PaddedBatch, Embedding, Score
from text_embeddings_server.models.types import PaddedBatch, Embedding

tracer = trace.get_tracer(__name__)

Expand Down Expand Up @@ -74,5 +74,5 @@ def embed(self, batch: PaddedBatch) -> List[Embedding]:
]

@tracer.start_as_current_span("predict")
def predict(self, batch: PaddedBatch) -> List[Score]:
pass
def predict(self, batch):
raise NotImplementedError(f"Predict is not a valid operation for model type {self.model.config.model_type}")
4 changes: 2 additions & 2 deletions backends/python/server/text_embeddings_server/models/model.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import torch

from abc import ABC, abstractmethod
from typing import List, TypeVar, Type
from typing import List, TypeVar, Type, Generic

from text_embeddings_server.models.types import Batch, Embedding, Score

B = TypeVar("B", bound=Batch)


class Model(ABC):
class Model(ABC, Generic[B]):
def __init__(
self,
model,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def round_up(number, k):
class Batch(ABC):
@classmethod
@abstractmethod
def from_pb(cls, pb: embed_pb2.EmbedRequest, device: torch.device) -> "Batch":
def from_pb(cls, pb: embed_pb2.EmbedRequest, device: torch.device, *args, **kwargs) -> "Batch":
raise NotImplementedError

@abstractmethod
Expand Down
4 changes: 2 additions & 2 deletions backends/python/server/text_embeddings_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
from pathlib import Path
from typing import Optional

from text_embeddings_server.models import Model, get_model
from text_embeddings_server.models import Model, get_model, B
from text_embeddings_server.pb import embed_pb2_grpc, embed_pb2
from text_embeddings_server.utils.tracing import UDSOpenTelemetryAioServerInterceptor
from text_embeddings_server.utils.interceptor import ExceptionInterceptor


class EmbeddingService(embed_pb2_grpc.EmbeddingServiceServicer):
def __init__(self, model: Model):
def __init__(self, model: Model[B]):
self.model = model
# Force inference mode for the lifetime of EmbeddingService
self._inference_mode_raii_guard = torch._C._InferenceMode(True)
Expand Down
5 changes: 1 addition & 4 deletions backends/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,11 @@ impl PythonBackend {
otlp_service_name: String,
) -> Result<Self, BackendError> {
match model_type {
ModelType::Classifier => {
None
}
ModelType::Classifier => {}
ModelType::Embedding(pool) => {
if pool != Pool::Cls {
return Err(BackendError::Start(format!("{pool:?} is not supported")));
}
Some(pool)
}
};

Expand Down

0 comments on commit 1db605e

Please sign in to comment.