diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index f275206f..cc0060e0 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -31,6 +31,19 @@ jobs: pip install --upgrade pip pip install .[dev] + # Remove unneeded system libraries to maximize disk space + # https://github.com/easimon/maximize-build-space/blob/master/action.yml + # https://github.com/actions/virtual-environments/issues/2840#issuecomment-790492173 + - name: Maximize disk space + run: | + echo "Available disk space (before):" + df -h + sudo rm -rf /usr/share/dotnet + sudo rm -rf /opt/ghc + sudo rm -rf /usr/local/lib/android + echo "Available disk space (after):" + df -h + # Run integration tests - name: Test run: | diff --git a/pyproject.toml b/pyproject.toml index f7f2c1a0..84c08f39 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,9 @@ dependencies = [ 'tokenizers >= 0.13.2; python_version >= "3.11"', # See https://github.com/citadel-ai/langcheck/pull/45 'torch >= 2', 'transformers >= 4.6', - "unidic-lite >= 1.0.1" # For tokenizer of metrics.ja.toxicity() + "unidic-lite >= 1.0.1", # For tokenizer of metrics.ja.toxicity() + "tabulate >= 0.9.0", # For model manager paint table + "omegaconf >= 2.3.0" # For model manager paint table ] requires-python = ">=3.8" @@ -80,3 +82,10 @@ ignore = [ markers = [ "optional: marks tests as optional", ] +disable_test_id_escaping_and_forfeit_all_rights_to_community_support = true + +[tool.setuptools.packages.find] +where = ["src"] + +[tool.setuptools.package-data] +langcheck = ["metrics/model_manager/config/*.yaml"] \ No newline at end of file diff --git a/src/langcheck/metrics/en/_detoxify.py b/src/langcheck/metrics/en/_detoxify.py index ed4efdad..0699d870 100644 --- a/src/langcheck/metrics/en/_detoxify.py +++ b/src/langcheck/metrics/en/_detoxify.py @@ -1,7 +1,8 @@ from typing import List, Tuple import torch -from transformers import BertForSequenceClassification, BertTokenizer +from transformers.models.bert.modeling_bert import BertForSequenceClassification +from transformers.models.bert.tokenization_bert import BertTokenizer def load_checkpoint( diff --git a/src/langcheck/metrics/model_manager/__init__.py b/src/langcheck/metrics/model_manager/__init__.py new file mode 100644 index 00000000..24e9ef91 --- /dev/null +++ b/src/langcheck/metrics/model_manager/__init__.py @@ -0,0 +1,3 @@ +from ._model_management import ModelManager + +manager = ModelManager() diff --git a/src/langcheck/metrics/model_manager/_model_loader.py b/src/langcheck/metrics/model_manager/_model_loader.py new file mode 100644 index 00000000..f147de3b --- /dev/null +++ b/src/langcheck/metrics/model_manager/_model_loader.py @@ -0,0 +1,96 @@ +from typing import Optional, Tuple + +from sentence_transformers import SentenceTransformer +from transformers.models.auto.modeling_auto import ( + AutoModelForSeq2SeqLM, AutoModelForSequenceClassification) +from transformers.models.auto.tokenization_auto import AutoTokenizer + + +def load_sentence_transformers( + model_name: str, + model_revision: Optional[str] = None, + tokenizer_name: Optional[str] = None, + tokenizer_revision: Optional[str] = None) -> SentenceTransformer: + ''' + Loads a SentenceTransformer model. + + This function currently does not support specifying a tokenizer or a + revision. If these arguments are provided, a warning message will be + printed. + + Args: + model_name: The name of the SentenceTransformer model to load. + tokenizer_name: The name of the tokenizer to use. Currently not + supported. 
+        model_revision: The model revision to load. Currently not supported.
+        tokenizer_revision: The tokenizer revision to load. Currently not
+            supported.
+
+    Returns:
+        model: The loaded SentenceTransformer model.
+    '''
+    if model_revision is not None or tokenizer_revision is not None:
+        print("Warning: Specifying a revision is not currently supported.")
+    if tokenizer_name is not None:
+        print("Warning: Customizing the tokenizer is not currently supported.")
+
+    model = SentenceTransformer(model_name)
+    return model
+
+
+def load_auto_model_for_text_classification(
+    model_name: str,
+    model_revision: Optional[str] = None,
+    tokenizer_name: Optional[str] = None,
+    tokenizer_revision: Optional[str] = None
+) -> Tuple[AutoTokenizer, AutoModelForSequenceClassification]:
+    '''
+    Loads a sequence classification model and its tokenizer.
+
+    Args:
+        model_name: The name of the sequence classification model to load.
+        tokenizer_name: The name of the tokenizer to load. If None, the
+            tokenizer associated with the model will be loaded.
+        model_revision: The model revision to load.
+        tokenizer_revision: The tokenizer revision to load.
+
+    Returns:
+        tokenizer: The loaded tokenizer.
+        model: The loaded sequence classification model.
+    '''
+    if tokenizer_name is None:
+        tokenizer_name = model_name
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name,
+                                              revision=tokenizer_revision)
+    model = AutoModelForSequenceClassification.from_pretrained(
+        model_name, revision=model_revision)
+    return tokenizer, model  # type: ignore
+
+
+def load_auto_model_for_seq2seq(
+    model_name: str,
+    model_revision: Optional[str] = None,
+    tokenizer_name: Optional[str] = None,
+    tokenizer_revision: Optional[str] = None
+) -> Tuple[AutoTokenizer, AutoModelForSeq2SeqLM]:
+    '''
+    Loads a sequence-to-sequence model and its tokenizer.
+
+    Args:
+        model_name: The name of the sequence-to-sequence model to load.
+        tokenizer_name: The name of the tokenizer to load. If None, the
+            tokenizer associated with the model will be loaded.
+        model_revision: The model revision to load.
+        tokenizer_revision: The tokenizer revision to load.
+
+    Returns:
+        tokenizer: The loaded tokenizer.
+        model: The loaded sequence-to-sequence model.
+    '''
+    if tokenizer_name is None:
+        tokenizer_name = model_name
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name,
+                                              revision=tokenizer_revision)
+    model = AutoModelForSeq2SeqLM.from_pretrained(model_name,
+                                                  revision=model_revision)
+    return tokenizer, model  # type: ignore
diff --git a/src/langcheck/metrics/model_manager/_model_management.py b/src/langcheck/metrics/model_manager/_model_management.py
new file mode 100644
index 00000000..fd686bf0
--- /dev/null
+++ b/src/langcheck/metrics/model_manager/_model_management.py
@@ -0,0 +1,268 @@
+import os
+from copy import deepcopy
+from functools import lru_cache
+from typing import Optional, Tuple, Union
+
+import pandas as pd
+import requests
+from omegaconf import OmegaConf
+from sentence_transformers import SentenceTransformer
+from tabulate import tabulate
+from transformers.models.auto.modeling_auto import (
+    AutoModelForSeq2SeqLM, AutoModelForSequenceClassification)
+from transformers.models.auto.tokenization_auto import AutoTokenizer
+
+from ._model_loader import (load_auto_model_for_seq2seq,
+                            load_auto_model_for_text_classification,
+                            load_sentence_transformers)
+
+LOADER_MAP = {
+    "load_sentence_transformers":
+        load_sentence_transformers,
+    "load_auto_model_for_text_classification":
+        load_auto_model_for_text_classification,
+    "load_auto_model_for_seq2seq":
+        load_auto_model_for_seq2seq
+}
+VALID_LOADER_FUNCTION = LOADER_MAP.keys()
+VALID_METRICS = [
+    'semantic_similarity', 'sentiment', 'toxicity', 'factual_consistency'
+]
+VALID_LANGUAGE = ['zh']
+
+
+def check_model_availability(model_name: str, revision: Optional[str]) -> bool:
+    # TODO: add local cached model availability check for offline environment
+    if revision is None or revision == "":
+        url = f"https://huggingface.co/api/models/{model_name}"
+    else:
+        url = f"https://huggingface.co/api/models/{model_name}/revision/{revision}"  # NOQA:E501
+    response = requests.get(url, timeout=(1.0, 1.0))
+    return response.status_code == 200
+
+
+class ModelManager:
+    '''
+    A class to manage different models for multiple languages in LangCheck.
+    This class allows setting and retrieving different model names (like
+    sentiment_model, semantic_similarity_model, etc.) for each language.
+    It also supports loading model configurations from a file.
+    '''
+
+    def __init__(self):
+        '''
+        Initializes the ModelManager and loads the default model configuration
+        for each supported language.
+        '''
+        self.config = OmegaConf.create()
+        cwd = os.path.dirname(__file__)
+        default_config_file_path = os.path.join(cwd, "config",
+                                                "metric_config.yaml")
+        self.__load_config(default_config_file_path)
+
+    def __load_config(self, path: str) -> None:
+        '''
+        Loads the model configuration from a file.
+
+        Args:
+            path: The path to the configuration file.
+        '''
+        conf = OmegaConf.load(path)
+
+        for lang, lang_conf in conf.items():
+            for metric_name, metric_conf in lang_conf.items():
+                # Check model availability. For keys missing from the config,
+                # OmegaConf returns None by default.
+                assert isinstance(lang, str)
+                self.__set_model_for_metric(language=lang,
+                                            metric=metric_name,
+                                            **metric_conf)
+        print('Configuration loaded successfully.')
+
+    @lru_cache
+    def fetch_model(
+        self, language: str, metric: str
+    ) -> Union[Tuple[AutoTokenizer, AutoModelForSequenceClassification], Tuple[
+            AutoTokenizer, AutoModelForSeq2SeqLM], SentenceTransformer]:
+        '''
+        Return the model (and if applicable, the tokenizer) used for the given
+        metric and language.
+
+        Args:
+            language: The language for which to get the model
+            metric: The metric name
+
+        Returns:
+            A (tokenizer, model) tuple, or just the model, depending on the
+            loader function.
+        '''
+        if language in self.config:
+            if metric in self.config[language]:
+                # Deep copy the configuration so that changes to `config` do
+                # not affect the original `self.config`.
+                config = deepcopy(self.config[language][metric])
+                # Get the model loader function
+                loader_func = config.pop('loader_func')
+                loader = LOADER_MAP[loader_func]
+                # Call the loader function with the model_name, tokenizer_name
+                # (optional), and revision (optional) as arguments
+                return loader(**config)
+            else:
+                raise KeyError(f'Metric {metric} not supported yet.')
+        else:
+            raise KeyError(f'Language {language} not supported yet.')
+
+    @staticmethod
+    def validate_config(config, language='all', metric='all') -> None:
+        '''
+        Validate the configuration.
+
+        Args:
+            config: The configuration dictionary to validate.
+            language: The name of the language. Defaults to 'all'.
+            metric: The name of the metric. Defaults to 'all'.
+        '''
+        config = deepcopy(config)
+        for lang, lang_setting in config.items():
+            if language != 'all' and lang != language:
+                continue
+            for metric_name, model_setting in lang_setting.items():
+                if metric != 'all' and metric_name != metric:
+                    continue
+
+                # Check that the model name and loader function are set
+                if 'model_name' not in model_setting:
+                    raise KeyError(
+                        f'{lang} metric {metric_name} needs a model, but found None!'  # NOQA:E501
+                    )
+                if 'loader_func' not in model_setting:
+                    raise KeyError(
+                        f'Metric {metric_name} needs a loader, but found None!'  # NOQA:E501
+                    )
+                loader_func = model_setting.get('loader_func')
+                if loader_func not in VALID_LOADER_FUNCTION:
+                    raise ValueError(
+                        f'loader type should be in {VALID_LOADER_FUNCTION}')
+
+                # Check model availability with revision if specified.
+                model_name = model_setting.get('model_name')
+                model_revision = model_setting.get('model_revision')
+                if not check_model_availability(model_name, model_revision):
+                    raise ValueError(
+                        f'Cannot find {model_name} with revision {model_revision} on the Hugging Face Hub'  # NOQA:E501
+                    )
+                # Check tokenizer availability with revision if specified.
+                tokenizer_name = model_setting.get('tokenizer_name')
+                if tokenizer_name is not None and tokenizer_name != model_name:
+                    tokenizer_revision = model_setting.get('tokenizer_revision')
+                    if not check_model_availability(tokenizer_name,
+                                                    tokenizer_revision):
+                        raise ValueError(
+                            f'Cannot find {tokenizer_name} with revision {tokenizer_revision} on the Hugging Face Hub'  # NOQA:E501
+                        )
+
+    def __set_model_for_metric(self, language: str, metric: str,
+                               model_name: str, loader_func: str,
+                               **kwargs) -> None:
+        '''
+        Set the model for a specified metric in a specified language.
+
+        Args:
+            language: The name of the language
+            metric: The name of the evaluation metric
+            model_name: The name of the model
+            loader_func: The loader function of the model
+            tokenizer_name: (Optional) The name of the tokenizer
+            model_revision: (Optional) A version string of the model. If not
+                specified, load the latest model by default.
+            tokenizer_revision: (Optional) A version string of the tokenizer.
+                If not specified, load the latest tokenizer by default.
+        '''
+        config_copy = deepcopy(self.config)
+        try:
+            if language not in VALID_LANGUAGE:
+                raise KeyError(f'Language {language} not supported yet')
+
+            if metric not in VALID_METRICS:
+                raise KeyError(
+                    f'Metric {metric} not supported for language {language} yet'
+                )
+
+            # Initialize the configuration for the language and metric if it
+            # doesn't exist
+            if self.config.get(language) is None:
+                self.config[language] = {}
+            if self.config.get(language).get(metric) is None:
+                self.config[language][metric] = {}
+
+            detail_config = self.config[language][metric]
+            # Set the loader function and model name
+            detail_config['loader_func'] = loader_func
+            detail_config['model_name'] = model_name
+
+            # If tokenizer_name is different from model_name
+            tokenizer_name = kwargs.get('tokenizer_name')
+            if tokenizer_name:
+                detail_config['tokenizer_name'] = tokenizer_name
+            # If the model's revision is pinned
+            model_revision = kwargs.get('model_revision')
+            if model_revision:
+                detail_config['model_revision'] = model_revision
+            # If the tokenizer's revision is pinned
+            tokenizer_revision = kwargs.get('tokenizer_revision')
+            if tokenizer_revision:
+                detail_config['tokenizer_revision'] = tokenizer_revision
+            # Validate the change
+            ModelManager.validate_config(self.config,
+                                         language=language,
+                                         metric=metric)
+            # Clear the LRU cache so that the config change takes effect
+            # immediately
+            self.fetch_model.cache_clear()
+        except (ValueError, KeyError) as err:
+            # If an error occurred, restore the original configuration
+            self.config = config_copy
+            raise err
+
+    def list_current_model_in_use(self, language='all', metric='all') -> None:
+        '''
+        List the models currently in use.
+
+        Args:
+            language: The abbreviated language name
+            metric: The evaluation metric name
+        '''
+        df = pd.DataFrame.from_records(
+            [(lang, metric_name, key, value)
+             for lang, lang_model_settings in self.config.items()
+             for metric_name, model_settings in lang_model_settings.items()
+             for key, value in model_settings.items()],
+            columns=['language', 'metric_name', 'attribute', 'value'])
+        # The code below generates a dataframe like:
+        # |index| language | metric_name | loader | model_name | revision |
+        # |.....|..........|.............|........|............|..........|
+        df_pivot = df.pivot_table(index=['language', 'metric_name'],
+                                  columns="attribute",
+                                  values="value",
+                                  aggfunc='first').reset_index().rename_axis(
+                                      None, axis=1)
+        df_pivot.columns = [
+            'language', 'metric_name', 'loader', 'model_name', 'revision'
+        ]
+
+        if language == 'all' and metric == 'all':
+            print(
+                tabulate(
+                    df_pivot,  # type: ignore
+                    headers=df_pivot.columns,  # type: ignore
+                    tablefmt="github"))
+        else:
+            if language != "all":
+                df_pivot = df_pivot.loc[df_pivot.language == language]
+            if metric != 'all':
+                df_pivot = df_pivot.loc[df_pivot.metric_name == metric]
+            print(
+                tabulate(
+                    df_pivot,  # type: ignore
+                    headers=df_pivot.columns,  # type: ignore
+                    tablefmt="github"))
diff --git a/src/langcheck/metrics/model_manager/config/metric_config.yaml b/src/langcheck/metrics/model_manager/config/metric_config.yaml
new file mode 100644
index 00000000..470b1843
--- /dev/null
+++ b/src/langcheck/metrics/model_manager/config/metric_config.yaml
@@ -0,0 +1,26 @@
+# LANG:
+#   METRIC_NAME:
+#     model_name: str
+#     model_revision: str (optional)
+#     tokenizer_name: str (optional)
+#     tokenizer_revision: str (optional)
+#     loader_func: str
+zh:
+  semantic_similarity:
+    model_name: BAAI/bge-base-zh-v1.5
+    model_revision: f03589c
+    loader_func: load_sentence_transformers
+
+  sentiment:
+ model_name: IDEA-CCNL/Erlangshen-Roberta-110M-Sentiment + loader_func: load_auto_model_for_text_classification + + toxicity: + model_name: alibaba-pai/pai-bert-base-zh-llm-risk-detection + model_revision: 0a61c79744cb0173216f015ffecc1ea81c4e0229 + loader_func: load_auto_model_for_text_classification + + factual_consistency: + model_name: Helsinki-NLP/opus-mt-zh-en + model_revision: cf109095479db38d6df799875e34039d4938aaa6 + loader_func: load_auto_model_for_seq2seq \ No newline at end of file diff --git a/src/langcheck/metrics/zh/reference_based_text_quality.py b/src/langcheck/metrics/zh/reference_based_text_quality.py index 7821ed7c..01ee7e8d 100644 --- a/src/langcheck/metrics/zh/reference_based_text_quality.py +++ b/src/langcheck/metrics/zh/reference_based_text_quality.py @@ -90,14 +90,12 @@ def semantic_similarity( openai_args) metric_value.language = 'zh' return metric_value + # lazy import + from langcheck.metrics.model_manager import manager + model = manager.fetch_model(language='zh', metric="semantic_similarity") - # According to the C-MTEB Benchmark - # (https://github.com/FlagOpen/FlagEmbedding/tree/master/C_MTEB) - # the 3 models of different sizes provided BAAI are the best on the - # embedding task - # Ref: https://huggingface.co/BAAI/bge-base-zh-v1.5 - # Using this model, it is hard to find two sentence where cos_sim < 0.25. - model = SentenceTransformer('BAAI/bge-base-zh-v1.5') + # For type checking + assert isinstance(model, SentenceTransformer) generated_embeddings = model.encode(generated_outputs) reference_embeddings = model.encode(reference_outputs) cosine_scores = util.pairwise_cos_sim( diff --git a/src/langcheck/metrics/zh/reference_free_text_quality.py b/src/langcheck/metrics/zh/reference_free_text_quality.py index bab490a6..ada82bcc 100644 --- a/src/langcheck/metrics/zh/reference_free_text_quality.py +++ b/src/langcheck/metrics/zh/reference_free_text_quality.py @@ -1,28 +1,17 @@ from __future__ import annotations -import pickle -from math import e -from pathlib import PosixPath from typing import Dict, List, Optional import hanlp -import regex as re -import torch from openai import OpenAI from transformers.pipelines import pipeline -from langcheck._handle_logs import _handle_logging_level from langcheck.metrics._validation import validate_parameters_reference_free -from langcheck.metrics.en.reference_free_text_quality import (_fluency_openai, - _toxicity_openai) +from langcheck.metrics.en.reference_free_text_quality import _toxicity_openai from langcheck.metrics.en.reference_free_text_quality import \ sentiment as en_sentiment from langcheck.metrics.metric_value import MetricValue -_sentiment_model_path = 'IDEA-CCNL/Erlangshen-Roberta-110M-Sentiment' # NOQA: E501 - -_toxicity_model_path = "alibaba-pai/pai-bert-base-zh-llm-risk-detection" - def sentiment( generated_outputs: List[str] | str, @@ -87,12 +76,14 @@ def sentiment( metric_value.language = 'zh' return metric_value - global _sentiment_model_path - - _sentiment_pipeline = pipeline( - 'sentiment-analysis', model=_sentiment_model_path - ) # type: ignore[reportGeneralTypeIssues] # NOQA: E501 # {0:"Negative", 1:'Positive'} + from langcheck.metrics.model_manager import manager + tokenizer, model = manager.fetch_model(language='zh', metric='sentiment') + _sentiment_pipeline = pipeline( + 'sentiment-analysis', + model=model, # type: ignore[reportGeneralTypeIssues] + tokenizer=tokenizer # type: ignore[reportGeneralTypeIssues] + ) _model_id2label = _sentiment_pipeline.model.config.id2label _predict_result = 
_sentiment_pipeline( generated_outputs @@ -207,13 +198,15 @@ def _toxicity_local(generated_outputs: List[str]) -> List[float]: Returns: A list of scores ''' - global _toxicity_model_path # this pipeline output predict probability for each text on each label. # the output format is List[List[Dict(str)]] - _toxicity_pipeline = pipeline('text-classification', - model=_toxicity_model_path, - top_k=5) - + from langcheck.metrics.model_manager import manager + tokenizer, model = manager.fetch_model(language='zh', metric="toxicity") + _toxicity_pipeline = pipeline( + 'text-classification', + model=model, # type: ignore[reportOptionalIterable] + tokenizer=tokenizer, # type: ignore[reportOptionalIterable] + top_k=5) # {'Normal': 0, 'Pulp': 1, 'Sex': 2, 'Other Risk': 3, 'Adult': 4} _model_id2label = _toxicity_pipeline.model.config.id2label _predict_results = _toxicity_pipeline( diff --git a/src/langcheck/metrics/zh/source_based_text_quality.py b/src/langcheck/metrics/zh/source_based_text_quality.py index 239f583b..d2cdfd43 100644 --- a/src/langcheck/metrics/zh/source_based_text_quality.py +++ b/src/langcheck/metrics/zh/source_based_text_quality.py @@ -4,16 +4,12 @@ from openai import OpenAI from transformers.pipelines import pipeline -from transformers.pipelines.base import Pipeline from langcheck.metrics._validation import validate_parameters_source_based from langcheck.metrics.en.source_based_text_quality import \ factual_consistency as en_factual_consistency from langcheck.metrics.metric_value import MetricValue -_factual_consistency_translation_model_path = 'Helsinki-NLP/opus-mt-zh-en' -_factual_consistency_translation_pipeline: Pipeline | None = None - def factual_consistency( generated_outputs: List[str] | str, @@ -84,10 +80,11 @@ def factual_consistency( metric_value.language = 'zh' return metric_value - global _factual_consistency_translation_pipeline - if _factual_consistency_translation_pipeline is None: - _factual_consistency_translation_pipeline = pipeline( - 'translation', model=_factual_consistency_translation_model_path) + from langcheck.metrics.model_manager import manager + tokenizer, model = manager.fetch_model(language='zh', + metric='factual_consistency') + _factual_consistency_translation_pipeline = pipeline( + 'translation', model=model, tokenizer=tokenizer) # type: ignore # Translate the sources and generated outputs to English. # Currently, the type checks are not working for the pipeline, since @@ -96,14 +93,13 @@ def factual_consistency( cast(str, d['translation_text']) # type: ignore[reportGeneralTypeIssues] for d in _factual_consistency_translation_pipeline( - sources) # type: ignore[reportOptionalIterable] # NOQA: E501 + sources) # type: ignore[reportOptionalIterable] ] en_generated_outputs = [ cast(str, d['translation_text']) # type: ignore[reportGeneralTypeIssues] for d in _factual_consistency_translation_pipeline( - generated_outputs - ) # type: ignore[reportOptionalIterable] # NOQA: E501 + generated_outputs) # type: ignore[reportOptionalIterable] ] # Compute the factual consistency scores in English. 
factual_consistency_scores = en_factual_consistency( diff --git a/tests/metrics/model_manager/__init__.py b/tests/metrics/model_manager/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/metrics/model_manager/test_model_loader.py b/tests/metrics/model_manager/test_model_loader.py new file mode 100644 index 00000000..467f7f51 --- /dev/null +++ b/tests/metrics/model_manager/test_model_loader.py @@ -0,0 +1,86 @@ +from unittest.mock import MagicMock, patch + +import pytest +from sentence_transformers import SentenceTransformer +from transformers.models.auto.modeling_auto import ( + AutoModelForSeq2SeqLM, AutoModelForSequenceClassification) +from transformers.models.auto.tokenization_auto import AutoTokenizer + +from langcheck.metrics.model_manager._model_loader import ( + load_auto_model_for_seq2seq, load_auto_model_for_text_classification, + load_sentence_transformers) + +# Mock objects for AutoTokenizer and AutoModelForSeq2SeqLM +MockTokenizer = MagicMock(spec=AutoTokenizer) +MockSeq2SeqModel = MagicMock(spec=AutoModelForSeq2SeqLM) +MockSentenceTransModel = MagicMock(spec=SentenceTransformer) +MockSeqClassifcationModel = MagicMock(spec=AutoModelForSequenceClassification) + + +@pytest.mark.parametrize("model_name,tokenizer_name,revision", + [("t5-small", None, "main"), + ("t5-small", "t5-base", "main")]) +def test_load_auto_model_for_seq2seq(model_name, tokenizer_name, revision): + with patch('transformers.AutoTokenizer.from_pretrained', + return_value=MockTokenizer) as mock_tokenizer, \ + patch('transformers.AutoModelForSeq2SeqLM.from_pretrained', + return_value=MockSeq2SeqModel) as mock_model: + tokenizer, model = load_auto_model_for_seq2seq( + model_name=model_name, + tokenizer_name=tokenizer_name, + model_revision=revision, + tokenizer_revision=revision) + if tokenizer_name is None: + tokenizer_name = model_name + + mock_model.assert_called_once() + mock_tokenizer.assert_called_once() + # Assert that the returned objects are instances of the mocked objects + assert tokenizer == MockTokenizer, \ + "The returned tokenizer is not the expected mock object" + assert model == MockSeq2SeqModel, \ + "The returned model is not the expected mock object" + + +@pytest.mark.parametrize("model_name,tokenizer_name,revision", + [("bert-base-uncased", None, "main"), + ("bert-base-uncased", "bert-large-uncased", "main")]) +def test_load_auto_model_for_text_classification(model_name, tokenizer_name, + revision): + with patch('transformers.AutoTokenizer.from_pretrained', + return_value=MockTokenizer) as mock_tokenizer, \ + patch('transformers.AutoModelForSequenceClassification.from_pretrained', # NOQA:E501 + return_value=MockSeqClassifcationModel) as mock_model: + tokenizer, model = load_auto_model_for_text_classification( + model_name=model_name, + tokenizer_name=tokenizer_name, + model_revision=revision, + tokenizer_revision=revision) + if tokenizer_name is None: + tokenizer_name = model_name + + mock_model.assert_called_once() + mock_tokenizer.assert_called_once() + # Assert that the returned objects are instances of the mocked objects + assert tokenizer == MockTokenizer, \ + "The returned tokenizer is not the expected mock object" + assert model == MockSeqClassifcationModel, \ + "The returned model is not the expected mock object" + + +@pytest.mark.parametrize("model_name,tokenizer_name,revision", + [("all-MiniLM-L6-v2", None, "main"), + ("all-MiniLM-L6-v2", "all-mpnet-base-v2", "main")]) +def test_load_sentence_transformers(model_name, tokenizer_name, revision): + with 
patch.object(SentenceTransformer, '__init__', + return_value=None) as mock_init: + model = load_sentence_transformers(model_name=model_name, + tokenizer_name=tokenizer_name, + model_revision=revision, + tokenizer_revision=revision) + # Check if the model was loaded correctly + mock_init.assert_called_once_with(model_name) + + # Assert that the returned objects are instances of the mocked objects + assert isinstance(model, SentenceTransformer), \ + "The returned model is not the expected mock object" diff --git a/tests/metrics/model_manager/test_model_manager.py b/tests/metrics/model_manager/test_model_manager.py new file mode 100644 index 00000000..719278f5 --- /dev/null +++ b/tests/metrics/model_manager/test_model_manager.py @@ -0,0 +1,111 @@ +from unittest.mock import MagicMock, patch + +import pytest +import requests +from omegaconf import OmegaConf + +from langcheck.metrics.model_manager import _model_management +from langcheck.metrics.model_manager._model_management import ( + ModelManager, check_model_availability) + + +@pytest.fixture +def temp_config_path(tmp_path) -> str: + ''' + Fixture that creates a temporary configuration file for testing. + + Args: + tmp_path: A unique temporary directory path provided by pytest. + + Returns: + The path to the temporary configuration file. + ''' + config = ''' + zh: + toxicity: + model_name: alibaba-pai/pai-bert-base-zh-llm-risk-detection + loader_func: load_auto_model_for_text_classification + ja: + toxicity: + model_name: Alnusjaponica/toxicity-score-multi-classification + model_revision: bc7a465029744889c8252ee858ab04ab9efdb0e7 + tokenizer_name: line-corporation/line-distilbert-base-japanese + tokenizer_revision: 93bd4811608eecb95ffaaba957646efd9a909cc8 + loader_func: load_auto_model_for_text_classification + ''' + config_path = tmp_path / "metric_config.yaml" + config_path.write_text(config) + return str(config_path) + + +@pytest.fixture +def mock_model_manager(temp_config_path): + ''' + Fixture that creates a mock ModelManager for testing. + + The ModelManager is patched to use the temporary configuration file + created by the temp_config_path fixture, and to always return True + when checking model availability. + + Args: + temp_config_path: The path to the temporary configuration file. + + Returns: + The mock ModelManager. 
+    '''
+    with patch("os.path.join", return_value=temp_config_path), \
+         patch('langcheck.metrics.model_manager._model_management.check_model_availability',  # NOQA:E501
+               return_value=True), \
+         patch.object(_model_management, 'VALID_LANGUAGE', ['ja', 'zh']):
+        model_manager = ModelManager()
+        return model_manager
+
+
+@pytest.mark.parametrize(
+    "model_name,revision,status_code",
+    [("bert-base-uncased", "", 200), ("bert-base-uncased", None, 200),
+     ("bert-base-uncased", "main", 200),
+     ("bert-base-uncased", "a265f77", 200),
+     ("bert-base-uncased", "a265f773a47193eed794233aa2a0f0bb6d3eaa63", 200),
+     pytest.param(
+         "bert-base-uncased", "a265f78", 404, marks=pytest.mark.xfail),
+     pytest.param("", "0e9f4", 404, marks=pytest.mark.xfail),
+     pytest.param("terb-base-uncased", "", 404, marks=pytest.mark.xfail)],
+)
+@patch("requests.get")
+def test_check_model_availability(mock_get, model_name, revision, status_code):
+    mock_get.return_value.status_code = status_code
+    available = check_model_availability(model_name, revision)
+    assert available is (status_code == requests.codes.OK)
+
+
+def test_model_manager_initialization(mock_model_manager):
+    mock_config = mock_model_manager.config
+    assert "toxicity" in mock_config["zh"]
+    assert mock_config["zh"]["toxicity"]["model_name"] == \
+        "alibaba-pai/pai-bert-base-zh-llm-risk-detection"
+    assert mock_config["zh"]["toxicity"]["loader_func"] == \
+        "load_auto_model_for_text_classification"
+
+    assert "toxicity" in mock_config["ja"]
+    assert mock_config["ja"]["toxicity"]["model_name"] == \
+        "Alnusjaponica/toxicity-score-multi-classification"
+    assert mock_config["ja"]["toxicity"]["model_revision"] == \
+        "bc7a465029744889c8252ee858ab04ab9efdb0e7"
+    assert mock_config["ja"]["toxicity"]["tokenizer_name"] == \
+        "line-corporation/line-distilbert-base-japanese"
+    assert mock_config["ja"]["toxicity"]["tokenizer_revision"] == \
+        "93bd4811608eecb95ffaaba957646efd9a909cc8"
+    assert mock_config["ja"]["toxicity"]["loader_func"] == \
+        "load_auto_model_for_text_classification"
+
+
+def test_model_manager_fetch_model(mock_model_manager):
+    with \
+        patch.dict(
+            'langcheck.metrics.model_manager._model_management.LOADER_MAP',
+            {'load_auto_model_for_text_classification': MagicMock()}):
+        model = mock_model_manager.fetch_model(language='zh', metric='toxicity')
+        assert model is not None
+        model = mock_model_manager.fetch_model(language='ja', metric='toxicity')
+        assert model is not None
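As a quick illustration of the API these changes introduce, here is a minimal usage sketch (not part of the patch) based on the calls added in the zh metrics above; the return type of `fetch_model` depends on the loader function configured in `metric_config.yaml`:

```python
from langcheck.metrics.model_manager import manager

# SentenceTransformer-based metrics return just the model
# (load_sentence_transformers in metric_config.yaml).
similarity_model = manager.fetch_model(language='zh',
                                       metric='semantic_similarity')

# Classification and seq2seq metrics return a (tokenizer, model) tuple.
tokenizer, toxicity_model = manager.fetch_model(language='zh',
                                                metric='toxicity')

# Print a GitHub-style table of the models currently configured
# for each language/metric pair.
manager.list_current_model_in_use(language='zh')
```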