model config manager class #73

Merged
merged 70 commits into from
Feb 28, 2024
Changes from 17 commits
Commits
87016a2
implent a model config manager class
Vela-zz Dec 10, 2023
2bc3b75
add test case for model management
Vela-zz Dec 10, 2023
34e49fc
apply format suggestion
Vela-zz Dec 10, 2023
083c612
apply format suggestion
Vela-zz Dec 10, 2023
2cdf43c
pydoc update & fix test case
Vela-zz Dec 10, 2023
99fe02e
implement a model config manager class
Vela-zz Dec 10, 2023
2843c88
add test case for model management
Vela-zz Dec 10, 2023
49983cd
apply format suggestion
Vela-zz Dec 10, 2023
63aa6e6
pydoc update & fix test case
Vela-zz Dec 10, 2023
66ef0c6
apply format suggestion
Vela-zz Dec 10, 2023
6cd3755
Merge branch 'zh_model_config' of https://github.com/Vela-zz/langchec…
Vela-zz Dec 10, 2023
b89ed30
pydoc update & fix test case
Vela-zz Dec 10, 2023
1a2b720
Merge branch 'zh_model_config' of https://github.com/Vela-zz/langchec…
Vela-zz Dec 23, 2023
2506ece
add model loader
Vela-zz Dec 23, 2023
57ea217
re-implent a model manager class
Vela-zz Dec 25, 2023
3bf196c
add update_metrics_for_model method
Vela-zz Dec 25, 2023
bb70f64
apply format suggestion
Vela-zz Dec 25, 2023
db15318
clean up model loader docstrings
yosukehigashi Dec 27, 2023
ddac3cf
fix format
yosukehigashi Dec 27, 2023
6b2a382
clean up docstrings in model management
yosukehigashi Dec 27, 2023
0e23c39
make self.config not None
yosukehigashi Dec 27, 2023
a1fd972
remove unnecessary noqa tags
yosukehigashi Dec 27, 2023
e80285d
clean up comments
yosukehigashi Dec 27, 2023
d841711
fix fetch_model format
yosukehigashi Dec 27, 2023
2aacbf2
fix ref based and source based format
yosukehigashi Dec 27, 2023
6d0a530
fix format in ref free
yosukehigashi Dec 27, 2023
b2805d6
add dependencies
Vela-zz Dec 27, 2023
bc443a6
fix model manager plot table problem
Vela-zz Dec 27, 2023
4433da1
fix typo mistakes
Vela-zz Dec 28, 2023
c540f12
move and update model manager
Vela-zz Feb 5, 2024
ea7564b
use yaml file for read ease
Vela-zz Feb 5, 2024
a1120db
apply manager changes to zh metric
Vela-zz Feb 5, 2024
b1718d6
update environment settings and delete test case
Vela-zz Feb 5, 2024
da11f46
apply format check suggestions
Vela-zz Feb 5, 2024
a83d6a3
add test case for model loader.
Vela-zz Feb 5, 2024
ac59b6b
add package data to pyproject.toml
yosukehigashi Feb 7, 2024
492396a
Merge branch 'zh_model_config' of https://github.com/vela-zz/langchec…
yosukehigashi Feb 7, 2024
e4091b3
package-data fix
yosukehigashi Feb 7, 2024
37c3884
try again
yosukehigashi Feb 7, 2024
bd3ab62
remove global value in metric
Vela-zz Feb 7, 2024
d8f20bc
clean load_sentence_transformers comments
yosukehigashi Feb 9, 2024
c46645b
clean load_auto_model_for_text_classification docstring
yosukehigashi Feb 9, 2024
ff90381
clean load_auto_model_for_seq2seq docstring
yosukehigashi Feb 9, 2024
e35b66e
clean up model loader imports
yosukehigashi Feb 9, 2024
a959633
clean up __load_config
yosukehigashi Feb 9, 2024
3240129
add comment for fetch_model
yosukehigashi Feb 9, 2024
78b296e
clean __set_model_for_metric
yosukehigashi Feb 9, 2024
8b247fd
minor cleanup
yosukehigashi Feb 9, 2024
42ccc66
clean validate_config
yosukehigashi Feb 9, 2024
f4bf665
remove unused import
yosukehigashi Feb 9, 2024
6f33c0a
Merge branch 'zh_model_config' of https://github.com/Vela-zz/langchec…
Vela-zz Feb 9, 2024
e7d42e9
add test case for model manager class
Vela-zz Feb 16, 2024
1086826
remove global value in metrics
Vela-zz Feb 16, 2024
89d9aaa
apply format suggestions
Vela-zz Feb 16, 2024
a09dbd3
make jp char and zh char show formally in test, not unicode
Vela-zz Feb 16, 2024
daeb706
fix import error in en detoxify raised by pyright
Vela-zz Feb 19, 2024
64f7e95
apply format check suggestion
Vela-zz Feb 19, 2024
6155623
apply format check suggestions and remove useless import
Vela-zz Feb 19, 2024
7f45a8c
remove unused imports
yosukehigashi Feb 20, 2024
0f15528
remove unused import
yosukehigashi Feb 20, 2024
5371f86
cleanup and docstrings
yosukehigashi Feb 20, 2024
57b864f
add tokenizer_revision for fine grained control
Vela-zz Feb 26, 2024
160a0d3
apply tokenizer_revision update to test case
Vela-zz Feb 26, 2024
10593fc
clean up docstring and comments
yosukehigashi Feb 28, 2024
7c770b0
fix typo
yosukehigashi Feb 28, 2024
2ce534f
specify which fields are optional in config
yosukehigashi Feb 28, 2024
86f55fe
removed unnecessary noqa
yosukehigashi Feb 28, 2024
f540159
fix yapf format
yosukehigashi Feb 28, 2024
fdd353e
fix yapf format
yosukehigashi Feb 28, 2024
b1d5d76
maximize disk space
yosukehigashi Feb 28, 2024
3 changes: 3 additions & 0 deletions src/langcheck/metrics/__init__.py
@@ -1,4 +1,5 @@
from langcheck.metrics import en, ja, zh
from langcheck.metrics._model_management import ModelManager
from langcheck.metrics.en.reference_based_text_quality import (
rouge1, rouge2, rougeL, semantic_similarity)
from langcheck.metrics.en.reference_free_text_quality import (
@@ -13,6 +14,8 @@
is_json_array, is_json_object,
matches_regex, validation_fn)

_model_manager = ModelManager()

__all__ = [
'en',
'ja',
46 changes: 46 additions & 0 deletions src/langcheck/metrics/_model_loader.py
@@ -0,0 +1,46 @@
from typing import Optional, Tuple

from sentence_transformers import SentenceTransformer
from transformers.models.auto.modeling_auto import \
AutoModelForSequenceClassification
from transformers.models.auto.tokenization_auto import AutoTokenizer
from transformers.pipelines import pipeline


def load_sentence_transformers(model_name: str) -> SentenceTransformer:
"""
return a sentence-transformer model.

Args:
model_name: The model name of a sentence-transformers model
"""
return SentenceTransformer(model_name)


def load_auto_model_for_text_classification(model_name: str,
tokenizer_name: Optional[str],
revision: Optional[str])\
-> Tuple[AutoTokenizer,
AutoModelForSequenceClassification]:
"""
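Return a Hugging Face tokenizer and sequence-classification model.

Args:
model_name: The name of a sequence-classification model on the Hugging Face Hub. # NOQA:E501
tokenizer_name: The name of a tokenizer on the Hugging Face Hub. If None, the tokenizer is loaded from model_name. # NOQA:E501
revision: The revision of the model to load, e.g. a shortened SHA-1 commit hash. # NOQA:E501
"""
# Fall back to the model's own tokenizer when no tokenizer name is given.
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or model_name, revision=revision) # NOQA: E501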
model = AutoModelForSequenceClassification.from_pretrained(model_name, revision=revision) # NOQA: E501
return tokenizer, model


def load_pipeline_for_text_classification(model_name: str, **kwargs):
"""
return a Huggingface text-classification pipeline.

Args:
model_name: A Hugging Face model for text classification.
"""
top_k = kwargs.pop('top_k', None)
return pipeline('text-classification', model=model_name, top_k=top_k)
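A note on `load_auto_model_for_text_classification` above: `tokenizer_name` is typed as `Optional[str]`, so callers may pass `None`. The following is a minimal sketch of one way to decide where the tokenizer comes from, assuming the intended behaviour is to fall back to the model's own repository. The helper name `resolve_tokenizer_source` and the placeholder names `some-model` and `some-tokenizer` are hypothetical and not part of this PR.

```python
from typing import Optional


def resolve_tokenizer_source(model_name: str,
                             tokenizer_name: Optional[str] = None) -> str:
    """Return the repository name the tokenizer should be loaded from."""
    # Assumption: a missing or empty tokenizer name means "use the model's".
    return tokenizer_name if tokenizer_name else model_name


# No tokenizer given: the model name is reused.
print(resolve_tokenizer_source('IDEA-CCNL/Erlangshen-Roberta-110M-Sentiment'))
# Explicit tokenizer given: it takes precedence.
print(resolve_tokenizer_source('some-model', 'some-tokenizer'))
```

Keeping this decision in one small function makes it easy to unit-test without downloading any model.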
226 changes: 226 additions & 0 deletions src/langcheck/metrics/_model_management.py
@@ -0,0 +1,226 @@
import os
from copy import deepcopy
from functools import lru_cache
from pathlib import Path
from pprint import pprint
from typing import Optional, Tuple, Union

import pandas as pd
import requests
from configobj import ConfigObj
from sentence_transformers import SentenceTransformer
from transformers.models.auto.modeling_auto import \
AutoModelForSequenceClassification
from transformers.models.auto.tokenization_auto import AutoTokenizer

from ._model_loader import (load_auto_model_for_text_classification,
load_sentence_transformers)

# TODO: Use a ENUM class to parse these
VALID_METRIC_NAME = [
'factual_consistency', 'toxicity', 'sentiment', 'semantic_similarity'
]
VALID_LANGUAGE = ['zh']
VALID_LOADER = ['huggingface', 'sentence-transformers']


class ModelManager:
"""
A class to manage different models for multiple languages in
LangCheck.
This class allows setting and retrieving different model names
(like sentiment_model, semantic_similarity_model, etc.) for each language.
It also supports loading model configurations from a file.
"""

def __init__(self):
"""
Initializes the ModelManager by loading the default metric
configuration from file and validating it.
"""
self.config = None
self.__init__config()
self.validate_config()

def __init__config(self):
cwd = os.path.dirname(__file__)
self.config = ConfigObj(
os.path.join(Path(cwd), 'config', 'metric_config.ini')) # NOQA:E501

@lru_cache
def fetch_model(self, language: str, metric: str)\
-> Union[Tuple[AutoTokenizer, AutoModelForSequenceClassification],
SentenceTransformer]:
"""
Return the model used for the given metric and language.

Args:
language: The language for which to get the model.
metric: The metric name.
"""
if language in self.config: # type: ignore
if metric in self.config[language]: # type: ignore
# deep copy the configuration
# so that changes to the copy do not disturb self.config
config = deepcopy(
self.config[language][metric] # type: ignore[reportGeneralTypeIssues] # NOQA:E501
)
# get model name, model loader type
model_name, loader_type = config['model_name'], config[
'loader'] # type: ignore[reportGeneralTypeIssues] # NOQA:E501
# check if model version fixed
revision = config.pop("revision", None)
if loader_type == 'sentence-transformers':
if revision is not None:
print(
'Info: Sentence-Transformers does not support pinning a model revision yet' # NOQA: E501
)
model = load_sentence_transformers(model_name=model_name)
return model
elif loader_type == 'huggingface':
tokenizer_name = config.pop('tokenizer_name', None)
tokenizer, model = load_auto_model_for_text_classification(
model_name=model_name, # NOQA:E501
tokenizer_name=tokenizer_name, # NOQA:E501
revision=revision # NOQA:E501
)
return tokenizer, model
else:
raise KeyError(f'Loader {loader_type} not supported yet.')
else:
raise KeyError(f'Metric {metric} not supported yet.')
else:
raise KeyError(f'language {language} not supported yet')

def list_current_model_in_use(self, language='all', metric='all'):
""" list model in use.

Args:
language: The abbreviated name of the language.
metric: The evaluation metric name.
"""
df = pd.DataFrame.from_records(
[
(lang, metric_name, key, value)
for lang, lang_model_settings in
self.config.items() # type: ignore # NOQA:E501
for metric_name, model_settings in
lang_model_settings.items() # type: ignore # NOQA:E501
for key, value in model_settings.items()
],
columns=['language', 'metric_name', 'attribute', 'value'])

# the code below would generate a dataframe:
# |index| language | metric_name | loader | model_name | revision |
# |.....|..........|.............|........|............|..........|
df_pivot = df.pivot_table(
index=['language', 'metric_name'],
columns="attribute",
values="value",
aggfunc='first').reset_index()
# 'attribute' is only the name of the column axis after pivoting, not a
# column, so clear the axis name instead of dropping a column
df_pivot.columns.name = None

if language == 'all' and metric == 'all':
pprint(df_pivot)
else:
if language != "all":
df_pivot = df_pivot.loc[df_pivot.language == language]
if metric != 'all':
df_pivot = df_pivot.loc[df_pivot.metric_name == metric]
pprint(df_pivot)

def validate_config(self, language='all', metric='all'):
"""validate configuration.

Args:
language (str, optional):the name of the language. Defaults to 'all'. # NOQA:E501
metric (str, optional): the name of evaluation metric. Defaults to 'all'. # NOQA:E501
"""

def check_model_availability(model_name, revision):
if revision is None:
url = f"https://huggingface.co/api/models/{model_name}"
else:
url = f"https://huggingface.co/api/models/{model_name}/revision/{revision}" # NOQA:E501
response = requests.get(url)
return response.status_code == 200

config = deepcopy(self.config)
for lang, lang_setting in config.items(): # type: ignore # NOQA:E501
if language == 'all' or lang == language:
for metric_name, model_setting in lang_setting.items( # type: ignore # NOQA:E501
):
if metric == 'all' or metric_name == metric:
# if model name not set
if 'model_name' not in model_setting:
raise KeyError(
f'{lang} metrics {metric_name} need a model, but found None!' # NOQA:E501
)
if 'loader' not in model_setting:
raise KeyError(
f'Metrics {metric_name} need a loader, but found None!' # NOQA:E501
)
# check if the model and revision is available on huggingface Hub # NOQA:E501
loader_type = model_setting.pop('loader')
if loader_type == 'huggingface':
model_name = model_setting.pop('model_name')
revision = model_setting.pop('revision', None)
if not check_model_availability(
model_name, revision): # NOQA:E501
raise ValueError(
f'Cannot find {model_name} with revision {revision} on Huggingface Hub' # NOQA:E501
)
elif loader_type not in VALID_LOADER:
raise ValueError(
f'loader type should be one of {VALID_LOADER}'
) # NOQA: E501
# may also need other validate method for other loader
# not found yet
print('Configuration Validation Passed')

def set_model_for_metric(self, language: str, metric: str, model_name: str,
loader: Optional[str], **kwargs):
"""set model for specified metric in specified language

Args:
language (str): the name of the language,
metric (str): the name of the evaluation metric,
loader (str): the loader of the model, optional,
model_name (str): the name of the model,
tokenizer_name (str): optional, the name of the tokenizer,
revision (str): optional, a version string of the model.
"""
config_copy = deepcopy(self.config)
try:
if language not in VALID_LANGUAGE:
raise ValueError(f'Language {language} not supported yet')

if metric not in self.config[language]: # type: ignore # NOQA:E501
raise ValueError(
f'Language {language} does not support {metric} yet'
) # NOQA:E501

config = self.config[language][metric] # type: ignore # NOQA:E501
config['loader'] = loader
config['model_name'] = model_name
# if tokenizer_name is different with model
tokenizer_name = kwargs.pop('tokenizer_name', None)
if tokenizer_name:
config['tokenizer_name'] = tokenizer_name
# if model's revision is pinned
revision = kwargs.pop('revision', None)
if revision:
config['revision'] = revision
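# validate the change; validate_config raises on failure and
# returns None on success, so its result must not be used as a flag
self.validate_config(language=language, metric=metric)
# clear the LRU cache so that the config change is
# reflected immediately
self.fetch_model.cache_clear()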
except (ValueError, KeyError) as err:
# trace back the configuration
self.config = config_copy
raise err
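The update flow in `set_model_for_metric` (back up the config, apply the change, validate, clear the `lru_cache` of `fetch_model`, roll back on error) can be sketched in isolation. This is a simplified illustration under stated assumptions, not the PR's implementation: `TinyManager`, its plain-dict config, the `load_count` counter, and the string standing in for a loaded model are all hypothetical.

```python
from copy import deepcopy
from functools import lru_cache


class TinyManager:
    """Toy stand-in for a model manager; models are just strings."""

    def __init__(self):
        self.config = {'zh': {'sentiment': {'model_name': 'model-a'}}}
        self.load_count = 0  # counts how often a "model" is really loaded

    @lru_cache(maxsize=None)
    def fetch_model(self, language, metric):
        # Expensive in real life (downloads weights); cached per arguments.
        self.load_count += 1
        return 'loaded:' + self.config[language][metric]['model_name']

    def validate(self, language, metric):
        # Raises on an invalid entry and returns None otherwise.
        if not self.config[language][metric].get('model_name'):
            raise ValueError(f'{language}/{metric} needs a model name')

    def set_model_for_metric(self, language, metric, model_name):
        backup = deepcopy(self.config)
        try:
            if metric not in self.config.get(language, {}):
                raise KeyError(f'{language}/{metric} is not configured')
            self.config[language][metric]['model_name'] = model_name
            self.validate(language, metric)
            # Drop cached models so the next fetch sees the new config.
            self.fetch_model.cache_clear()
        except (KeyError, ValueError):
            # Restore the previous configuration before re-raising.
            self.config = backup
            raise


manager = TinyManager()
manager.fetch_model('zh', 'sentiment')
manager.fetch_model('zh', 'sentiment')  # served from the cache
manager.set_model_for_metric('zh', 'sentiment', 'model-b')
print(manager.fetch_model('zh', 'sentiment'), manager.load_count)
```

Two properties of `lru_cache` on a method are worth keeping in mind: the cache is keyed on `self` as well as the arguments, and `cache_clear()` empties the cache shared by all instances of the class.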
22 changes: 22 additions & 0 deletions src/langcheck/metrics/config/metric_config.ini
@@ -0,0 +1,22 @@
[zh]
[[semantic_similarity]]
# According to the C-MTEB Benchmark
# (https://github.com/FlagOpen/FlagEmbedding/tree/master/C_MTEB)
# the 3 models of different sizes provided by BAAI are the best on the
# embedding task
# Ref: https://huggingface.co/BAAI/bge-base-zh-v1.5
# Using this model, it is hard to find two sentences where cos_sim < 0.25.
model_name = BAAI/bge-base-zh-v1.5
revision = f03589c
loader = sentence-transformers
[[sentiment]]
model_name = IDEA-CCNL/Erlangshen-Roberta-110M-Sentiment
loader = huggingface
[[toxicity]]
model_name = alibaba-pai/pai-bert-base-zh-llm-risk-detection
loader = huggingface
revision = 0a61c79744cb0173216f015ffecc1ea81c4e0229
[[factual_consistency]]
model_name = Helsinki-NLP/opus-mt-zh-en
loader = huggingface
revision = cf109095479db38d6df799875e34039d4938aaa6
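The config file above uses nested sections: `[zh]` is a language, each `[[metric]]` beneath it is a metric, and the `key = value` lines describe that metric's model. The PR reads it with `ConfigObj`. The sketch below is only a simplified standard-library stand-in to show the resulting `language -> metric -> settings` shape; the function `parse_nested_ini` is hypothetical and handles just two nesting levels, `#` comments, and `key = value` pairs.

```python
def parse_nested_ini(text):
    """Parse a two-level sectioned config into nested dictionaries."""
    config = {}
    language = None  # name of the current [language] section
    settings = None  # dict of the current [[metric]] section
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line or line.startswith('#'):
            continue
        if line.startswith('[[') and line.endswith(']]'):
            if language is None:
                raise ValueError('metric section outside a language section')
            settings = config[language].setdefault(line[2:-2].strip(), {})
        elif line.startswith('[') and line.endswith(']'):
            language = line[1:-1].strip()
            config.setdefault(language, {})
            settings = None
        else:
            key, separator, value = line.partition('=')
            if not separator or settings is None:
                raise ValueError(f'unexpected line: {line!r}')
            settings[key.strip()] = value.strip()
    return config


SAMPLE = """
[zh]
[[sentiment]]
model_name = IDEA-CCNL/Erlangshen-Roberta-110M-Sentiment
loader = huggingface
[[toxicity]]
model_name = alibaba-pai/pai-bert-base-zh-llm-risk-detection
loader = huggingface
revision = 0a61c79744cb0173216f015ffecc1ea81c4e0229
"""

config = parse_nested_ini(SAMPLE)
print(config['zh']['toxicity']['revision'])
```

With this shape, a lookup such as `config[language][metric]['model_name']` mirrors how `fetch_model` indexes `self.config`.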
15 changes: 6 additions & 9 deletions src/langcheck/metrics/zh/reference_based_text_quality.py
@@ -90,16 +90,13 @@ def semantic_similarity(
openai_args)
metric_value.language = 'zh'
return metric_value
# lazy import
from langcheck.metrics import _model_manager
model = _model_manager.fetch_model(language='zh',
metric="semantic_similarity")

# According to the C-MTEB Benchmark
# (https://github.com/FlagOpen/FlagEmbedding/tree/master/C_MTEB)
# the 3 models of different sizes provided BAAI are the best on the
# embedding task
# Ref: https://huggingface.co/BAAI/bge-base-zh-v1.5
# Using this model, it is hard to find two sentence where cos_sim < 0.25.
model = SentenceTransformer('BAAI/bge-base-zh-v1.5')
generated_embeddings = model.encode(generated_outputs)
reference_embeddings = model.encode(reference_outputs)
generated_embeddings = model.encode(generated_outputs) # type: ignore[reportGeneralTypeIssues] # NOQA: E501
reference_embeddings = model.encode(reference_outputs) # type: ignore[reportGeneralTypeIssues] # NOQA: E501
cosine_scores = util.pairwise_cos_sim(
generated_embeddings, # type: ignore[reportGeneralTypeIssues]
reference_embeddings # type: ignore[reportGeneralTypeIssues]
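For readers unfamiliar with the scoring step in `semantic_similarity`: the generated and reference embeddings are compared pair by pair with cosine similarity. The following pure-Python sketch shows that computation on small hand-written vectors. It is an illustration only, not the sentence-transformers implementation; `pairwise_cosine` is a hypothetical name, and returning 0.0 for a zero vector is an assumption of this sketch.

```python
import math


def pairwise_cosine(left, right):
    """Cosine similarity between left[i] and right[i] for every i."""
    if len(left) != len(right):
        raise ValueError('both inputs need the same number of vectors')
    scores = []
    for u, v in zip(left, right):
        dot = sum(x * y for x, y in zip(u, v))
        norm = (math.sqrt(sum(x * x for x in u)) *
                math.sqrt(sum(y * y for y in v)))
        # Assumption: treat a zero vector as having similarity 0.0.
        scores.append(dot / norm if norm else 0.0)
    return scores


generated = [[1.0, 0.0], [1.0, 1.0], [0.0, 2.0]]
reference = [[1.0, 0.0], [-1.0, -1.0], [3.0, 0.0]]
print(pairwise_cosine(generated, reference))
```

The three pairs are identical, opposite, and orthogonal vectors, so the scores come out close to 1, -1, and 0 respectively.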
13 changes: 11 additions & 2 deletions src/langcheck/metrics/zh/reference_free_text_quality.py
@@ -92,7 +92,11 @@ def sentiment(
_sentiment_pipeline = pipeline(
'sentiment-analysis', model=_sentiment_model_path
) # type: ignore[reportGeneralTypeIssues] # NOQA: E501
# {0:"Negative", 1:'Positive'}
# # {0:"Negative", 1:'Positive'}
from langcheck.metrics import _model_manager
tokenizer, model = _model_manager.fetch_model(language='zh', metric='sentiment') # NOQA: E501
_sentiment_pipeline = pipeline(
'sentiment-analysis', model=model, tokenizer=tokenizer) # type: ignore[reportGeneralTypeIssues] # NOQA: E501
_model_id2label = _sentiment_pipeline.model.config.id2label
_predict_result = _sentiment_pipeline(
generated_outputs
@@ -210,8 +214,13 @@ def _toxicity_local(generated_outputs: List[str]) -> List[float]:
global _toxicity_model_path
# this pipeline output predict probability for each text on each label.
# the output format is List[List[Dict(str)]]
from langcheck.metrics import _model_manager
tokenizer, model = _model_manager.fetch_model(language='zh',
metric="toxicity")

_toxicity_pipeline = pipeline('text-classification',
model=_toxicity_model_path,
model=model,
tokenizer=tokenizer, # type: ignore[reportOptionalIterable] # NOQA: E501
top_k=5)

# {'Normal': 0, 'Pulp': 1, 'Sex': 2, 'Other Risk': 3, 'Adult': 4}
4 changes: 3 additions & 1 deletion src/langcheck/metrics/zh/source_based_text_quality.py
@@ -86,8 +86,10 @@ def factual_consistency(

global _factual_consistency_translation_pipeline
if _factual_consistency_translation_pipeline is None:
from langcheck.metrics import _model_manager
tokenizer, model = _model_manager.fetch_model(language='zh', metric='factual_consistency') # NOQA: E501
_factual_consistency_translation_pipeline = pipeline(
'translation', model=_factual_consistency_translation_model_path)
'translation', model=model, tokenizer=tokenizer) # type: ignore

# Translate the sources and generated outputs to English.
# Currently, the type checks are not working for the pipeline, since