Skip to content
This repository has been archived by the owner on Jan 29, 2024. It is now read-only.

Commit

Permalink
add api for classification
Browse files Browse the repository at this point in the history
  • Loading branch information
ixaxaar committed Oct 27, 2023
1 parent ad2670c commit c510066
Showing 1 changed file with 42 additions and 21 deletions.
63 changes: 42 additions & 21 deletions geniusrise_huggingface/classification/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,56 +14,77 @@
# limitations under the License.

import logging
from typing import Dict

from typing import Dict, Any
import torch
import cherrypy
from geniusrise import BatchInput, BatchOutput, State
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from geniusrise_huggingface.base import HuggingFaceAPI

log = logging.getLogger(__file__)


class HuggingFaceClassificationAPI(HuggingFaceAPI):
"""
A class for serving a Hugging Face-based classification model.
Args:
input (BatchInput): The input data.
output (BatchOutput): The output data.
state (State): The state data.
**kwargs: Additional keyword arguments.
Attributes:
model (AutoModelForSequenceClassification): The loaded Hugging Face model.
tokenizer (AutoTokenizer): The loaded Hugging Face tokenizer.
"""

def __init__(
self,
input: BatchInput,
output: BatchOutput,
state: State,
**kwargs,
) -> None:
super().__init__(input=input, output=output, state=state)
log.info("Loading Hugging Face API server")

def load_models(self, model_path: str, tokenizer_path: str) -> None:
"""
Load the model and tokenizer.
Initializes the HuggingFaceClassificationAPI class.
Args:
model_path (str): The path to the saved model.
tokenizer_path (str): The path to the saved tokenizer.
input (BatchInput): The input data.
output (BatchOutput): The output data.
state (State): The state data.
**kwargs: Additional keyword arguments.
"""
log.info(f"Loading model from {model_path}")
self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
log.info(f"Loading tokenizer from {tokenizer_path}")
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
super().__init__(input=input, output=output, state=state)
log.info("Loading Hugging Face API server")

@cherrypy.expose
@cherrypy.tools.json_in()
@cherrypy.tools.json_out()
@cherrypy.tools.allow(methods=["POST"])
def classify(self) -> Dict[str, str]:
def classify(self) -> Dict[str, Any]:
"""
Classify the input text.
Returns:
Dict[str, str]: The classification result.
"""
data = cherrypy.request.json
data: Dict[str, str] = cherrypy.request.json
text = data.get("text", "")

inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
outputs = self.model(**inputs)
label_id = outputs.logits.argmax(-1).item()
label = self.model.config.id2label[label_id]
return {"label": label}

if next(self.model.parameters()).is_cuda:
inputs = {k: v.cuda() for k, v in inputs.items()}

with torch.no_grad():
outputs = self.model(**inputs)
logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
if next(self.model.parameters()).is_cuda:
logits = logits.cpu()
softmax = torch.nn.functional.softmax(logits, dim=-1)
scores = softmax.numpy().tolist() # Convert scores to list

id_to_label = dict(enumerate(self.model.config.id2label.values())) # type: ignore
label_scores = [{id_to_label[label_id]: score} for label_id, score in enumerate(scores[0])]

return {"input": text, "label_scores": label_scores}

0 comments on commit c510066

Please sign in to comment.