Gsk 2559 add tabular classification pipeline #28

Draft · wants to merge 5 commits into base: main
70 changes: 51 additions & 19 deletions giskard_cicd/loaders/huggingface_loader.py
@@ -12,6 +12,9 @@
from giskard.models.base import BaseModel
from giskard.models.huggingface import HuggingFaceModel
from transformers.pipelines import TextClassificationPipeline
from .tabular_classification_pipeline import TabularClassificationPipeline
from .tabular_regression_pipeline import TabularRegressionPipeline
from .tabular_pipeline import TabularPipeline
import requests

from .huggingface_inf_model import classification_model_from_inference_api
@@ -74,17 +77,20 @@ def load_giskard_model_dataset(

# Check that the dataset has the good feature names for the task.
logger.debug("Retrieving feature mapping")
if manual_feature_mapping is None:
if manual_feature_mapping is None and isinstance(hf_model, TextClassificationPipeline):
feature_mapping = self._get_feature_mapping(hf_model, hf_dataset)
logger.warn(
f'Feature mapping is not provided, using extracted "{feature_mapping}"'
)
else:
feature_mapping = manual_feature_mapping

df = hf_dataset.to_pandas().rename(
columns={v: k for k, v in feature_mapping.items()}
)
if feature_mapping is not None:
df = hf_dataset.to_pandas().rename(
columns={v: k for k, v in feature_mapping.items()}
)
else:
df = hf_dataset.to_pandas()

# remove the rows that have multiple labels
# this is a hacky way to do it
@@ -96,7 +102,7 @@

# @TODO: currently for classification models only.
logger.debug("Retrieving classification label mapping")
if classification_label_mapping is None:
if classification_label_mapping is None and isinstance(hf_model, TextClassificationPipeline):
id2label = hf_model.model.config.id2label
logger.warn(f'Label mapping is not provided, using "{id2label}" from model')
else:
@@ -106,30 +112,42 @@
# need to include all labels
# rewrite this lambda function to include all labels
df.label = df.label.apply(lambda x: id2label[x[0]])
else:
elif getattr(df, "label", None) is not None:
# TODO: when the label for test is not provided, what do we do?
df["label"] = df.label.apply(lambda x: id2label[x] if x >= 0 else "-1")

# map the list of label ids to the list of labels
# df["label"] = df.label.apply(lambda x: [id2label[i] for i in x])
logger.debug("Wrapping dataset")

gsk_dataset = gsk.Dataset(
df,
name=f"HF {dataset}[{dataset_config}]({dataset_split}) for {model} model",
target="label",
column_types={"text": "text"},
validation=False,
)


logger.debug("Wrapping model")

gsk_model = self._get_gsk_model(
hf_model,
[id2label[i] for i in range(len(id2label))],
features=feature_mapping,
inference_type=inference_type,
device=self.device,
hf_token=inference_api_token,
)
if id2label is None and isinstance(hf_model, TabularPipeline):
gsk_model = gsk.Model(
lambda data: hf_model.predict(data),
model_type=hf_model._model_type,
name=f"{hf_model.model_id} HF pipeline",
feature_names=hf_model.config["features"],
classification_labels=hf_model.config["target_mapping"].values(),
)
else:
gsk_model = self._get_gsk_model(
hf_model,
[id2label[i] for i in range(len(id2label))],
features=feature_mapping,
inference_type=inference_type,
device=self.device,
hf_token=inference_api_token,
)
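
Note: the tabular branch above relies on the pipeline's config exposing a "features" list and a "target_mapping" dict (filled from the model's config.json in TabularPipeline.__init__). A minimal sketch of that assumption, with illustrative values only:

# Illustrative shape of hf_model.config for a tabular classifier (values are hypothetical).
config = {
    "features": ["age", "fare", "sex"],
    "target_mapping": {"0": "not_survived", "1": "survived"},
}
# gsk.Model then receives feature_names=config["features"] and
# classification_labels=config["target_mapping"].values(), as in the diff above.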

# Optimize batch size
if self.device.startswith("cuda"):
@@ -170,8 +188,19 @@ def load_dataset(

def load_model(self, model_id):
from transformers import pipeline

info = huggingface_hub.model_info(model_id)
task = info.pipeline_tag
tags = info.tags
serialization = None
if not task and ("tabular" in tags and "classification" in tags):
task = "tabular-classification"

if "skops" in tags:
serialization = "skops"

if task == "tabular-classification":
return TabularClassificationPipeline(task=task, model=model_id, model_id=model_id, serialization=serialization)
if task == "tabular-regression":
return TabularRegressionPipeline(task=task, model=model_id, model_id=model_id, serialization=serialization)

return pipeline(task=task, model=model_id, device=self.device)
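
For reference, a rough usage sketch of the new load_model path, assuming the loader class in this file is named HuggingFaceLoader and using a placeholder model id (not a real Hub model):

# Hypothetical usage; "someuser/titanic-skops" is a placeholder model id.
loader = HuggingFaceLoader()
pipe = loader.load_model("someuser/titanic-skops")
# If the Hub metadata reports pipeline_tag == "tabular-classification", pipe is a
# TabularClassificationPipeline wrapping the downloaded estimator; any other task
# still falls through to transformers.pipeline as before.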

@@ -236,6 +265,8 @@ def _flatten_hf_dataset(self, hf_dataset, data_split=None):
Flatten the dataset to a pandas dataframe
"""
flat_dataset = pd.DataFrame()
if isinstance(hf_dataset, datasets.Dataset):
return hf_dataset
if isinstance(hf_dataset, datasets.DatasetDict):
keys = list(hf_dataset.keys())
for k in keys:
@@ -245,16 +276,15 @@ def _flatten_hf_dataset(self, hf_dataset, data_split=None):
break

# Otherwise infer one data split
if k.startswith("train"):
continue
elif k.startswith(data_split):
if k.startswith(data_split):
# TODO: only support one split for now
# Maybe we can merge all the datasets into one
flat_dataset = hf_dataset[k]
break
elif k.startswith("train"):
continue
else:
flat_dataset = hf_dataset[k]

# If there are only train datasets
if isinstance(flat_dataset, pd.DataFrame) and flat_dataset.empty:
flat_dataset = hf_dataset[keys[0]]
@@ -264,6 +294,8 @@ def _get_feature_mapping(self, hf_model, hf_dataset):
def _get_feature_mapping(self, hf_model, hf_dataset):
if isinstance(hf_model, TextClassificationPipeline):
task_features = {"text": "string", "label": "class_label"}
elif "tabular" in hf_model.pipeline_tag:
raise NotImplementedError("Tabular model features cannot be auto-mapped.")
else:
msg = "Unsupported model type."
raise NotImplementedError(msg)
29 changes: 29 additions & 0 deletions giskard_cicd/loaders/tabular_classification_pipeline.py
@@ -0,0 +1,29 @@
from .tabular_pipeline import TabularPipeline


class TabularClassificationPipeline(TabularPipeline):
    def __init__(self, *args, **kwargs):
        self._model_type = "classification"
        self._check_model_type(self._model_type)
        self.pipeline_tag = "tabular-classification"
        # get model parameter from args
        super().__init__(*args, **kwargs)

    def _sanitize_parameters(self, **kwargs):
        kwargs = super()._sanitize_parameters(**kwargs)
        return kwargs

    def _check_model_type(self, model_type):
        if model_type != self._model_type:
            raise ValueError(
                f"Pipeline is not of type {self._model_type} but {model_type}"
            )

    def _forward(self, *args, **kwargs):
        return super()._forward(*args, **kwargs)

    def preprocess(self, *args, **kwargs):
        return super().preprocess(*args, **kwargs)

    def postprocess(self, *args, **kwargs):
        return super().postprocess(*args, **kwargs)

57 changes: 57 additions & 0 deletions giskard_cicd/loaders/tabular_pipeline.py
@@ -0,0 +1,57 @@
from typing import Any, Dict

from transformers import Pipeline, AutoModel, AutoConfig
import keras
import huggingface_hub
import joblib
import os
import json
from skops.io import load
from transformers.pipelines.base import GenericTensor
from transformers.utils import ModelOutput


class TabularPipeline(Pipeline):
    def __init__(self, *args, **kwargs):
        self._num_workers = 0
        # get model parameter from args
        self.model_id = kwargs.pop("model", None)
        self.model_dir = huggingface_hub.snapshot_download(self.model_id)
        serialization = kwargs.pop("serialization", None)
        for f in os.listdir(self.model_dir):
            if serialization == "skops" and ".pkl" in f:
                self.model = load(self.model_dir + "/" + f)
            if ".joblib" in f:  # joblib
                self.model = joblib.load(self.model_dir + "/" + f)
            if "config.json" in f:
                config_file = json.load(open(self.model_dir + "/" + f))
                if "sklearn" in config_file.keys():
                    self.config = config_file["sklearn"]
                    if "columns" in self.config.keys():
                        self.config["features"] = self.config["columns"]
                else:
                    self.config = config_file
            if ".pt" in f:  # pytorch
                self.model = AutoModel.from_pretrained(self.model_dir)
                if "config.json" in f:
                    self.model.config = AutoConfig.from_pretrained(self.model_dir)
            if "model.pb" in f:  # keras
                self.model = keras.models.load_model(self.model_dir)
            if "modelRun.json" in f:
                raise ValueError(
                    "MLConsole models are not supported."
                )

    def predict(self, *args, **kwargs):
        return self.model.predict(*args, **kwargs)

    def _sanitize_parameters(self, **kwargs):
        kwargs = super()._sanitize_parameters(**kwargs)
        return kwargs

    def _forward(self, input_tensors: Dict[str, GenericTensor], **forward_parameters: Dict) -> ModelOutput:
        return super()._forward(input_tensors, **forward_parameters)

    def preprocess(self, *args, **kwargs):
        return super().preprocess(*args, **kwargs)

    def postprocess(self, model_outputs: ModelOutput, **postprocess_parameters: Dict) -> Any:
        return super().postprocess(model_outputs, **postprocess_parameters)

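For context, a sketch of the config.json that TabularPipeline.__init__ expects to find next to a skops-serialized model; the keys are modelled on the parsing logic above and the values are illustrative, not taken from a real model card:

# Illustrative config.json content, shown as a Python dict (field values are hypothetical).
config_file = {
    "sklearn": {
        "columns": ["age", "fare", "sex"],
        "example_input": {"age": [22], "fare": [7.25], "sex": ["male"]},
    }
}
# __init__ stores config_file["sklearn"] as self.config and aliases
# self.config["features"] = self.config["columns"], which is what the loader reads
# when it builds the Giskard model.
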
29 changes: 29 additions & 0 deletions giskard_cicd/loaders/tabular_regression_pipeline.py
@@ -0,0 +1,29 @@
from .tabular_pipeline import TabularPipeline


class TabularRegressionPipeline(TabularPipeline):
    def __init__(self, *args, **kwargs):
        self._model_type = "regression"
        self._check_model_type(self._model_type)
        self.pipeline_tag = "tabular-regression"
        # get model parameter from args
        super().__init__(*args, **kwargs)

    def _sanitize_parameters(self, **kwargs):
        kwargs = super()._sanitize_parameters(**kwargs)
        return kwargs

    def _check_model_type(self, model_type):
        if model_type != self._model_type:
            raise ValueError(
                f"Pipeline is not of type {self._model_type} but {model_type}"
            )

    def _forward(self, *args, **kwargs):
        return super()._forward(*args, **kwargs)

    def preprocess(self, *args, **kwargs):
        return super().preprocess(*args, **kwargs)

    def postprocess(self, *args, **kwargs):
        return super().postprocess(*args, **kwargs)