
model config manager class #73

Merged
merged 70 commits
Feb 28, 2024
Changes from 49 commits
Commits
87016a2
implement a model config manager class
Vela-zz Dec 10, 2023
2bc3b75
add test case for model management
Vela-zz Dec 10, 2023
34e49fc
apply format suggestion
Vela-zz Dec 10, 2023
083c612
apply format suggestion
Vela-zz Dec 10, 2023
2cdf43c
pydoc update & fix test case
Vela-zz Dec 10, 2023
99fe02e
implement a model config manager class
Vela-zz Dec 10, 2023
2843c88
add test case for model management
Vela-zz Dec 10, 2023
49983cd
apply format suggestion
Vela-zz Dec 10, 2023
63aa6e6
pydoc update & fix test case
Vela-zz Dec 10, 2023
66ef0c6
apply format suggestion
Vela-zz Dec 10, 2023
6cd3755
Merge branch 'zh_model_config' of https://github.com/Vela-zz/langchec…
Vela-zz Dec 10, 2023
b89ed30
pydoc update & fix test case
Vela-zz Dec 10, 2023
1a2b720
Merge branch 'zh_model_config' of https://github.com/Vela-zz/langchec…
Vela-zz Dec 23, 2023
2506ece
add model loader
Vela-zz Dec 23, 2023
57ea217
re-implement a model manager class
Vela-zz Dec 25, 2023
3bf196c
add update_metrics_for_model method
Vela-zz Dec 25, 2023
bb70f64
apply format suggestion
Vela-zz Dec 25, 2023
db15318
clean up model loader docstrings
yosukehigashi Dec 27, 2023
ddac3cf
fix format
yosukehigashi Dec 27, 2023
6b2a382
clean up docstrings in model management
yosukehigashi Dec 27, 2023
0e23c39
make self.config not None
yosukehigashi Dec 27, 2023
a1fd972
remove unnecessary noqa tags
yosukehigashi Dec 27, 2023
e80285d
clean up comments
yosukehigashi Dec 27, 2023
d841711
fix fetch_model format
yosukehigashi Dec 27, 2023
2aacbf2
fix ref based and source based format
yosukehigashi Dec 27, 2023
6d0a530
fix format in ref free
yosukehigashi Dec 27, 2023
b2805d6
add dependencies
Vela-zz Dec 27, 2023
bc443a6
fix model manager plot table problem
Vela-zz Dec 27, 2023
4433da1
fix typo mistakes
Vela-zz Dec 28, 2023
c540f12
move and update model manager
Vela-zz Feb 5, 2024
ea7564b
use yaml file for read ease
Vela-zz Feb 5, 2024
a1120db
apply manager changes to zh metric
Vela-zz Feb 5, 2024
b1718d6
update environment settings and delete test case
Vela-zz Feb 5, 2024
da11f46
apply format check suggestions
Vela-zz Feb 5, 2024
a83d6a3
add test case for model loader.
Vela-zz Feb 5, 2024
ac59b6b
add package data to pyproject.toml
yosukehigashi Feb 7, 2024
492396a
Merge branch 'zh_model_config' of https://github.com/vela-zz/langchec…
yosukehigashi Feb 7, 2024
e4091b3
package-data fix
yosukehigashi Feb 7, 2024
37c3884
try again
yosukehigashi Feb 7, 2024
bd3ab62
remove global value in metric
Vela-zz Feb 7, 2024
d8f20bc
clean load_sentence_transformers comments
yosukehigashi Feb 9, 2024
c46645b
clean load_auto_model_for_text_classification docstring
yosukehigashi Feb 9, 2024
ff90381
clean load_auto_model_for_seq2seq docstring
yosukehigashi Feb 9, 2024
e35b66e
clean up model loader imports
yosukehigashi Feb 9, 2024
a959633
clean up __load_config
yosukehigashi Feb 9, 2024
3240129
add comment for fetch_model
yosukehigashi Feb 9, 2024
78b296e
clean __set_model_for_metric
yosukehigashi Feb 9, 2024
8b247fd
minor cleanup
yosukehigashi Feb 9, 2024
42ccc66
clean validate_config
yosukehigashi Feb 9, 2024
f4bf665
remove unused import
yosukehigashi Feb 9, 2024
6f33c0a
Merge branch 'zh_model_config' of https://github.com/Vela-zz/langchec…
Vela-zz Feb 9, 2024
e7d42e9
add test case for model manager class
Vela-zz Feb 16, 2024
1086826
remove global value in metrics
Vela-zz Feb 16, 2024
89d9aaa
apply format suggestions
Vela-zz Feb 16, 2024
a09dbd3
make jp char and zh char show formally in test, not unicode
Vela-zz Feb 16, 2024
daeb706
fix import error in en detoxify raised by pyright
Vela-zz Feb 19, 2024
64f7e95
apply format check suggestion
Vela-zz Feb 19, 2024
6155623
apply format check suggestions and remove useless import
Vela-zz Feb 19, 2024
7f45a8c
remove unused imports
yosukehigashi Feb 20, 2024
0f15528
remove unused import
yosukehigashi Feb 20, 2024
5371f86
cleanup and docstrings
yosukehigashi Feb 20, 2024
57b864f
add tokenizer_revision for fine grained control
Vela-zz Feb 26, 2024
160a0d3
apply tokenizer_revision update to test case
Vela-zz Feb 26, 2024
10593fc
clean up docstring and comments
yosukehigashi Feb 28, 2024
7c770b0
fix typo
yosukehigashi Feb 28, 2024
2ce534f
specify which fields are optional in config
yosukehigashi Feb 28, 2024
86f55fe
removed unnecessary noqa
yosukehigashi Feb 28, 2024
f540159
fix yapf format
yosukehigashi Feb 28, 2024
fdd353e
fix yapf format
yosukehigashi Feb 28, 2024
b1d5d76
maximize disk space
yosukehigashi Feb 28, 2024
10 changes: 9 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -31,7 +31,9 @@ dependencies = [
'tokenizers >= 0.13.2; python_version >= "3.11"', # See https://github.com/citadel-ai/langcheck/pull/45
'torch >= 2',
'transformers >= 4.6',
"unidic-lite >= 1.0.1" # For tokenizer of metrics.ja.toxicity()
"unidic-lite >= 1.0.1", # For tokenizer of metrics.ja.toxicity()
"tabulate >= 0.9.0", # For model manager paint table
"omegaconf >= 2.3.0" # For model manager paint table
]
requires-python = ">=3.8"

@@ -80,3 +82,9 @@ ignore = [
markers = [
"optional: marks tests as optional",
]

[tool.setuptools.packages.find]
where = ["src"]

[tool.setuptools.package-data]
langcheck = ["metrics/model_manager/config/*.yaml"]
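For reference, the YAML files matched by `metrics/model_manager/config/*.yaml` hold the per-language, per-metric model settings that `ModelManager` loads at startup. A hypothetical entry might look like this (the model names are placeholders, not the actual defaults shipped in the package):

```yaml
# Hypothetical sketch of a metric_config.yaml entry; the keys mirror the
# arguments accepted by ModelManager.__set_model_for_metric.
zh:
  sentiment:
    loader_func: load_auto_model_for_text_classification
    model_name: some-org/zh-sentiment-model      # placeholder
    model_revision: main                         # optional: pin a revision
    tokenizer_name: some-org/zh-sentiment-model  # optional
  semantic_similarity:
    loader_func: load_sentence_transformers
    model_name: some-org/zh-sbert-model          # placeholder
```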
1 change: 1 addition & 0 deletions src/langcheck/metrics/__init__.py
@@ -39,4 +39,5 @@
'semantic_similarity',
'sentiment',
'toxicity',
'model_manager'
]
11 changes: 11 additions & 0 deletions src/langcheck/metrics/model_manager/__init__.py
@@ -0,0 +1,11 @@
from ._model_loader import (load_auto_model_for_seq2seq,
load_auto_model_for_text_classification,
load_sentence_transformers)
from ._model_management import ModelManager

manager = ModelManager()

__all__ = [
"manager", "load_sentence_transformers", "load_auto_model_for_seq2seq",
"load_auto_model_for_text_classification"
]
87 changes: 87 additions & 0 deletions src/langcheck/metrics/model_manager/_model_loader.py
@@ -0,0 +1,87 @@
from typing import Optional, Tuple

from sentence_transformers import SentenceTransformer
from transformers.models.auto.modeling_auto import (
AutoModelForSeq2SeqLM, AutoModelForSequenceClassification)
from transformers.models.auto.tokenization_auto import AutoTokenizer


def load_sentence_transformers(
model_name: str,
tokenizer_name: Optional[str] = None,
revision: Optional[str] = None) -> SentenceTransformer:
'''
Loads a SentenceTransformer model.

This function currently does not support specifying a tokenizer or a
revision. If these arguments are provided, a warning message will be
printed.

Args:
model_name: The name of the SentenceTransformer model to load.
tokenizer_name: The name of the tokenizer to use. Currently not
supported.
revision: The model revision to load. Currently not supported.

Returns:
model: The loaded SentenceTransformer model.
'''
if revision is not None:
print("Warning: Specifying a revision is not currently supported.")
if tokenizer_name is not None:
print("Warning: Customizing the tokenizer is not currently supported.")

model = SentenceTransformer(model_name)
return model


def load_auto_model_for_text_classification(
model_name: str,
tokenizer_name: Optional[str] = None,
revision: Optional[str] = None
) -> Tuple[AutoTokenizer, AutoModelForSequenceClassification]:
'''
Loads a sequence classification model and its tokenizer.

Args:
model_name: The name of the sequence-classification model to load.
tokenizer_name: The name of the tokenizer to load. If None, the
tokenizer associated with the model will be loaded.
revision: The model revision to load.

Returns:
tokenizer: The loaded tokenizer.
model: The loaded sequence classification model.
'''
if tokenizer_name is None:
tokenizer_name = model_name
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, revision=revision)
model = AutoModelForSequenceClassification.from_pretrained(
model_name, revision=revision)
return tokenizer, model # type: ignore


def load_auto_model_for_seq2seq(
model_name: str,
tokenizer_name: Optional[str] = None,
revision: Optional[str] = None
) -> Tuple[AutoTokenizer, AutoModelForSeq2SeqLM]:
'''
Loads a sequence-to-sequence model and its tokenizer.

Args:
model_name: The name of the sequence-to-sequence model to load.
tokenizer_name: The name of the tokenizer to load. If None, the
tokenizer associated with the model will be loaded.
revision: The model revision to load.

Returns:
tokenizer: The loaded tokenizer.
model: The loaded sequence-to-sequence model.
'''
if tokenizer_name is None:
tokenizer_name = model_name
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, revision=revision)
model = AutoModelForSeq2SeqLM.from_pretrained(
model_name, revision=revision)
return tokenizer, model # type: ignore
257 changes: 257 additions & 0 deletions src/langcheck/metrics/model_manager/_model_management.py
@@ -0,0 +1,257 @@
import os
from copy import deepcopy
from functools import lru_cache
from typing import Optional, Tuple, Union

import pandas as pd
import requests
from omegaconf import OmegaConf
from sentence_transformers import SentenceTransformer
from tabulate import tabulate
from transformers.models.auto.modeling_auto import (
AutoModelForSeq2SeqLM, AutoModelForSequenceClassification)
from transformers.models.auto.tokenization_auto import AutoTokenizer

from ._model_loader import (load_auto_model_for_seq2seq,
load_auto_model_for_text_classification,
load_sentence_transformers)

LOADER_MAP = {
"load_sentence_transformers":
load_sentence_transformers,
"load_auto_model_for_text_classification":
load_auto_model_for_text_classification,
"load_auto_model_for_seq2seq":
load_auto_model_for_seq2seq
}
VALID_LOADER_FUNCTION = LOADER_MAP.keys()
VALID_METRICS = [
'semantic_similarity', 'sentiment', 'toxicity', 'factual_consistency'
]

VALID_METRIC_ATTRIBUTE = [
'model_name', 'model_revision', 'loader', 'tokenizer_name'
]
VALID_LANGUAGE = ['zh']


def check_model_availability(model_name: str, revision: Optional[str]) -> bool:
# TODO: add local cached model availability check for offline environment
if revision is None:
url = f"https://huggingface.co/api/models/{model_name}"
else:
url = f"https://huggingface.co/api/models/{model_name}/revision/{revision}" # NOQA:E501
response = requests.get(url, timeout=(1.0, 1.0))
return response.status_code == 200
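`check_model_availability` treats an HTTP 200 from the Hub API as "available". The URL construction itself can be checked offline; `hf_model_api_url` below is an illustrative helper, not part of the PR:

```python
from typing import Optional


def hf_model_api_url(model_name: str, revision: Optional[str] = None) -> str:
    # Mirrors the URL logic in check_model_availability: a pinned revision
    # adds a /revision/<rev> suffix to the probe URL.
    if revision is None:
        return f"https://huggingface.co/api/models/{model_name}"
    return f"https://huggingface.co/api/models/{model_name}/revision/{revision}"


print(hf_model_api_url("bert-base-uncased"))
# https://huggingface.co/api/models/bert-base-uncased
print(hf_model_api_url("bert-base-uncased", "main"))
# https://huggingface.co/api/models/bert-base-uncased/revision/main
```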


class ModelManager:
'''
A class to manage different models for multiple languages in LangCheck.
This class allows setting and retrieving different model names (like
sentiment_model, semantic_similarity_model, etc.) for each language.
It also supports loading model configurations from a file.
'''

def __init__(self):
'''
Initializes the ModelManager with empty model dictionaries for each
language.
'''
self.config = OmegaConf.create()
cwd = os.path.dirname(__file__)
default_config_file_path = os.path.join(cwd, "config",
"metric_config.yaml")
self.__load_config(default_config_file_path)

def __load_config(self, path: str) -> None:
'''
Loads the model configuration from a file.

Args:
path: The path to the configuration file.
'''
conf = OmegaConf.load(path)

for lang, lang_conf in conf.items():
for metric_name, metric_conf in lang_conf.items():
# Check model availability. If a key is missing from the conf,
# OmegaConf returns None by default.
self.__set_model_for_metric(language=lang,
metric=metric_name,
**metric_conf)
print('Configuration Load Succeeded!')

@lru_cache
def fetch_model(
self, language: str, metric: str
) -> Union[Tuple[AutoTokenizer, AutoModelForSequenceClassification], Tuple[
AutoTokenizer, AutoModelForSeq2SeqLM], SentenceTransformer]:
'''
Return the model (and if applicable, the tokenizer) used for the given
metric and language.

Args:
language: The language for which to get the model
metric: The metric name

Returns:
A (tokenizer, model) tuple, or just the model depending on the
loader function.
'''
if language in self.config:
if metric in self.config[language]:
# Deep copy the configuration so that changes to `config` would
# not affect the original `self.config`.
config = deepcopy(self.config[language][metric])
# Get model loader function
loader_func = config.pop('loader_func')
loader = LOADER_MAP[loader_func]
# Call the loader function with the model_name, tokenizer_name
# (optional), and revision (optional) as arguments
return loader(**config)
else:
raise KeyError(f'Metric {metric} not supported yet.')
else:
raise KeyError(f'Language {language} not supported yet')
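Because `fetch_model` is decorated with `@lru_cache`, repeated calls with the same `(language, metric)` pair reuse the already-loaded model, and `cache_clear()` (called after a config change) forces a reload. A minimal standalone sketch of that behavior, with a counter standing in for the expensive model load:

```python
from functools import lru_cache

load_count = {"n": 0}


@lru_cache
def fetch(language: str, metric: str) -> str:
    # Stand-in for the expensive model-loading work in fetch_model.
    load_count["n"] += 1
    return f"model for {language}/{metric}"


fetch("zh", "sentiment")
fetch("zh", "sentiment")  # cache hit: the "model" is not loaded again
assert load_count["n"] == 1

fetch.cache_clear()       # what ModelManager does after a config change
fetch("zh", "sentiment")  # reloaded after the cache is cleared
assert load_count["n"] == 2
```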

@staticmethod
def validate_config(config, language='all', metric='all') -> None:
'''
Validate configuration.

Args:
config: The configuration dictionary to validate.
language: The name of the language. Defaults to 'all'.
metric: The name of the metric. Defaults to 'all'.
'''
config = deepcopy(config)
for lang, lang_setting in config.items():
if language != 'all' and lang != language:
continue
for metric_name, model_setting in lang_setting.items():
if metric != 'all' and metric_name != metric:
continue

# Check that the model name and loader function are set
if 'model_name' not in model_setting:
raise KeyError(
f'{lang} metric {metric_name} needs a model, but none was found!' # NOQA:E501
)
if 'loader_func' not in model_setting:
raise KeyError(
f'Metrics {metric_name} need a loader, but found None!' # NOQA:E501
)
loader_func = model_setting.pop('loader_func', None)
if loader_func not in VALID_LOADER_FUNCTION:
raise ValueError(
f'loader type should be in {VALID_LOADER_FUNCTION}')

# Check that the model and revision are available on the Hugging
# Face Hub
model_name = model_setting.pop('model_name')
revision = model_setting.pop('revision', None)
if not check_model_availability(model_name, revision):
raise ValueError(
f'Cannot find {model_name} with revision {revision} on the Hugging Face Hub' # NOQA:E501
)
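The structural checks in `validate_config` (a `model_name` is present and `loader_func` is one of the known loaders) can be exercised offline. The sketch below reproduces them with the Hugging Face availability probe omitted; `validate_entry` is an illustrative name, not part of the PR:

```python
VALID_LOADER_FUNCTION = {
    "load_sentence_transformers",
    "load_auto_model_for_text_classification",
    "load_auto_model_for_seq2seq",
}


def validate_entry(metric_name: str, model_setting: dict) -> None:
    # Mirror the structural checks in ModelManager.validate_config;
    # the check_model_availability network probe is omitted here.
    if "model_name" not in model_setting:
        raise KeyError(f"Metric {metric_name} needs a model, but none was found!")
    if model_setting.get("loader_func") not in VALID_LOADER_FUNCTION:
        raise ValueError(f"loader type should be in {VALID_LOADER_FUNCTION}")


# A well-formed entry passes silently...
validate_entry("sentiment", {
    "model_name": "some-org/model",  # placeholder name
    "loader_func": "load_sentence_transformers",
})

# ...while an unknown loader is rejected.
try:
    validate_entry("sentiment", {"model_name": "some-org/model",
                                 "loader_func": "bad_loader"})
except ValueError:
    print("rejected: unknown loader")
```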

def __set_model_for_metric(self, language: str, metric: str,
model_name: str, loader_func: str,
**kwargs) -> None:
'''
Set model for specified metric in specified language.

Args:
language: The name of the language
metric: The name of the evaluation metric
model_name: The name of the model
loader_func: The loader function of the model
tokenizer_name: (Optional) The name of the tokenizer
revision: (Optional) A version string of the model
'''
config_copy = deepcopy(self.config)
try:
if language not in VALID_LANGUAGE:
raise KeyError(f'Language {language} not supported yet')

if metric not in VALID_METRICS:
raise KeyError(
f'Metric {metric} not supported for language {language} yet'
)

# Initialize the configuration for the language and metric if it
# doesn't exist
if self.config.get(language) is None:
self.config[language] = {}
if self.config.get(language).get(metric) is None:
self.config[language][metric] = {}

detail_config = self.config[language][metric]
# Set the loader function and model name
detail_config['loader_func'] = loader_func
detail_config['model_name'] = model_name

# If tokenizer_name is different from model_name
tokenizer_name = kwargs.pop('tokenizer_name', None)
if tokenizer_name:
detail_config['tokenizer_name'] = tokenizer_name
# If model's revision is pinned
revision = kwargs.pop('model_revision', None)
if revision:
detail_config['revision'] = revision

# Validate the change. validate_config raises on an invalid
# config and returns None, so do not branch on its return value.
ModelManager.validate_config(self.config,
language=language,
metric=metric)
# Clear the LRU cache so that the config change is reflected
# immediately
self.fetch_model.cache_clear()
except (ValueError, KeyError) as err:
# If an error occurred, restore the original configuration
self.config = config_copy
raise err

def list_current_model_in_use(self, language='all', metric='all') -> None:
'''
List the models currently in use.

Args:
language: The abbreviated name of the language
metric: The evaluation metric name
'''
df = pd.DataFrame.from_records(
[(lang, metric_name, key, value)
for lang, lang_model_settings in self.config.items()
for metric_name, model_settings in lang_model_settings.items()
for key, value in model_settings.items()],
columns=['language', 'metric_name', 'attribute', 'value'])
# The code below would generate a dataframe:
# |index| language | metric_name | loader | model_name | revision |
# |.....|..........|.............|........|............|..........|
df_pivot = df.pivot_table(index=['language', 'metric_name'],
columns="attribute",
values="value",
aggfunc='first').reset_index().rename_axis(
None, axis=1)
df_pivot.columns = [
'language', 'metric_name', 'loader', 'model_name', 'revision'
]

if language != 'all':
df_pivot = df_pivot.loc[df_pivot.language == language]
if metric != 'all':
df_pivot = df_pivot.loc[df_pivot.metric_name == metric]
print(
tabulate(
df_pivot, # type: ignore
headers=df_pivot.columns, # type: ignore
tablefmt="github"))
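The records-then-pivot reshaping in `list_current_model_in_use` can be reproduced on a toy config to see the resulting table shape (the values below are placeholders):

```python
import pandas as pd

# Toy nested config shaped like ModelManager.config (placeholder values).
config = {
    "zh": {
        "sentiment": {
            "loader_func": "load_auto_model_for_text_classification",
            "model_name": "some-org/zh-sentiment",
            "revision": "main",
        },
    },
}

# Flatten to (language, metric, attribute, value) records...
df = pd.DataFrame.from_records(
    [(lang, metric, key, value)
     for lang, metrics in config.items()
     for metric, settings in metrics.items()
     for key, value in settings.items()],
    columns=["language", "metric_name", "attribute", "value"])

# ...then pivot to one row per (language, metric), attributes as columns.
df_pivot = df.pivot_table(index=["language", "metric_name"],
                          columns="attribute",
                          values="value",
                          aggfunc="first").reset_index().rename_axis(None,
                                                                     axis=1)
print(df_pivot)
```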