Skip to content

Commit

Permalink
Merge branch 'zh_model_config' of https://github.com/vela-zz/langcheck
Browse files Browse the repository at this point in the history
…into pr/Vela-zz/73
  • Loading branch information
yosukehigashi committed Feb 7, 2024
2 parents ac59b6b + a83d6a3 commit 492396a
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/langcheck/metrics/model_manager/_model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ def load_sentence_transformers(
if tokenizer_name is not None:
print("Tokenizer customize not supported in Sentence-Transformers yet.")

return SentenceTransformer(model_name)
model = SentenceTransformer(model_name)
return model


def load_auto_model_for_text_classification(
Expand Down
3 changes: 2 additions & 1 deletion src/langcheck/metrics/model_manager/_model_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,4 +227,5 @@ def list_current_model_in_use(self, language='all', metric='all'):
if metric != 'all':
df_pivot = df_pivot.loc[df_pivot.metric_name == metric]
print(
tabulate(df_pivot, headers=df_pivot.columns, tablefmt="github")) # type: ignore # NOQA:E501
tabulate(df_pivot, headers=df_pivot.columns, # type: ignore # NOQA:E501
tablefmt="github"))
Empty file.
89 changes: 89 additions & 0 deletions tests/metrics/model_manager/test_model_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import pytest
from unittest.mock import patch, MagicMock
from sentence_transformers import SentenceTransformer
from transformers.models.auto.tokenization_auto import AutoTokenizer
from transformers.models.auto.modeling_auto \
import (AutoModelForSeq2SeqLM, AutoModelForSequenceClassification)
from langcheck.metrics.model_manager._model_loader \
import (load_auto_model_for_seq2seq,
load_auto_model_for_text_classification,
load_sentence_transformers)

# Module-level mock stand-ins returned by the patched `from_pretrained` /
# constructor calls in the tests below. `spec=` restricts each mock's
# attribute surface to the real class, so unexpected attribute access in
# the code under test fails loudly instead of silently succeeding.
MockTokenizer = MagicMock(spec=AutoTokenizer)
MockSeq2SeqModel = MagicMock(spec=AutoModelForSeq2SeqLM)
MockSentenceTransModel = MagicMock(spec=SentenceTransformer)
# NOTE(review): name has a typo ("Classifcation" -> "Classification");
# harmless, but worth fixing together with its usages in a follow-up.
MockSeqClassifcationModel = MagicMock(spec=AutoModelForSequenceClassification)


@pytest.mark.parametrize("model_name,tokenizer_name,revision", [
    ("t5-small", None, "main"),
    ("t5-small", "t5-base", "main")
])
def test_load_auto_model_for_seq2seq(model_name, tokenizer_name, revision):
    """The seq2seq loader should forward the names and revision to the
    transformers `from_pretrained` factories and hand back whatever pair
    of (tokenizer, model) objects those factories produce."""
    tokenizer_patch = patch('transformers.AutoTokenizer.from_pretrained',
                            return_value=MockTokenizer)
    model_patch = patch('transformers.AutoModelForSeq2SeqLM.from_pretrained',
                        return_value=MockSeq2SeqModel)
    with tokenizer_patch as patched_tokenizer, model_patch as patched_model:
        tokenizer, model = load_auto_model_for_seq2seq(model_name,
                                                       tokenizer_name,
                                                       revision)

        # When no explicit tokenizer is requested, the loader should fall
        # back to using the model's own tokenizer.
        expected_tokenizer_name = (model_name if tokenizer_name is None
                                   else tokenizer_name)
        patched_tokenizer.assert_called_once_with(expected_tokenizer_name,
                                                  revision=revision)

        # The model factory should only ever see the model name + revision.
        patched_model.assert_called_once_with(model_name, revision=revision)

        # The loader must return the exact objects produced by the patched
        # factories, not wrappers or copies.
        assert tokenizer == MockTokenizer, \
            "The returned tokenizer is not the expected mock object"
        assert model == MockSeq2SeqModel, \
            "The returned model is not the expected mock object"


@pytest.mark.parametrize("model_name,tokenizer_name,revision", [
    ("bert-base-uncased", None, "main"),
    ("bert-base-uncased", "bert-large-uncased", "main")
])
def test_load_auto_model_for_text_classification(model_name, tokenizer_name, revision):  # NOQA:E501
    """The text-classification loader should forward the names and revision
    to the transformers `from_pretrained` factories and hand back the pair
    of (tokenizer, model) objects those factories produce."""
    tokenizer_patch = patch('transformers.AutoTokenizer.from_pretrained',
                            return_value=MockTokenizer)
    model_patch = patch(
        'transformers.AutoModelForSequenceClassification.from_pretrained',
        return_value=MockSeqClassifcationModel)
    with tokenizer_patch as patched_tokenizer, model_patch as patched_model:
        tokenizer, model = load_auto_model_for_text_classification(
            model_name, tokenizer_name, revision)

        # When no explicit tokenizer is requested, the loader should fall
        # back to using the model's own tokenizer.
        expected_tokenizer_name = (model_name if tokenizer_name is None
                                   else tokenizer_name)
        patched_tokenizer.assert_called_once_with(expected_tokenizer_name,
                                                  revision=revision)

        # The model factory should only ever see the model name + revision.
        patched_model.assert_called_once_with(model_name, revision=revision)

        # The loader must return the exact objects produced by the patched
        # factories, not wrappers or copies.
        assert tokenizer == MockTokenizer, \
            "The returned tokenizer is not the expected mock object"
        assert model == MockSeqClassifcationModel, \
            "The returned model is not the expected mock object"


@pytest.mark.parametrize("model_name,tokenizer_name,revision", [
    ("all-MiniLM-L6-v2", None, "main"),
    ("all-MiniLM-L6-v2", "all-mpnet-base-v2", "main")
])
def test_load_sentence_transformers(model_name, tokenizer_name, revision):
    """The sentence-transformers loader should construct a
    SentenceTransformer from the model name alone — tokenizer and revision
    customization are not supported by that loader."""
    init_patch = patch.object(SentenceTransformer, '__init__',
                              return_value=None)
    with init_patch as patched_init:
        result = load_sentence_transformers(model_name, tokenizer_name,
                                            revision)

        # Only the model name should reach the constructor; tokenizer_name
        # and revision are ignored by this loader.
        patched_init.assert_called_once_with(model_name)

        # With __init__ patched out, the loader still returns a
        # SentenceTransformer instance.
        assert isinstance(result, SentenceTransformer), \
            "The returned model is not the expected mock object"
Empty file.

0 comments on commit 492396a

Please sign in to comment.