Skip to content

Commit

Permalink
Merge pull request #73 from Vela-zz/zh_model_config
Browse files Browse the repository at this point in the history
model config manager class
  • Loading branch information
yosukehigashi authored Feb 28, 2024
2 parents 78abbd7 + b1d5d76 commit 14fa90e
Show file tree
Hide file tree
Showing 13 changed files with 642 additions and 42 deletions.
13 changes: 13 additions & 0 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,19 @@ jobs:
pip install --upgrade pip
pip install .[dev]
# Remove unneeded system libraries to maximize disk space
# https://github.com/easimon/maximize-build-space/blob/master/action.yml
# https://github.com/actions/virtual-environments/issues/2840#issuecomment-790492173
- name: Maximize disk space
run: |
echo "Available disk space (before):"
df -h
sudo rm -rf /usr/share/dotnet
sudo rm -rf /opt/ghc
sudo rm -rf /usr/local/lib/android
echo "Available disk space (after):"
df -h
# Run integration tests
- name: Test
run: |
Expand Down
11 changes: 10 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ dependencies = [
'tokenizers >= 0.13.2; python_version >= "3.11"', # See https://github.com/citadel-ai/langcheck/pull/45
'torch >= 2',
'transformers >= 4.6',
"unidic-lite >= 1.0.1" # For tokenizer of metrics.ja.toxicity()
"unidic-lite >= 1.0.1", # For tokenizer of metrics.ja.toxicity()
"tabulate >= 0.9.0", # For model manager paint table
"omegaconf >= 2.3.0" # For model manager paint table
]
requires-python = ">=3.8"

Expand Down Expand Up @@ -80,3 +82,10 @@ ignore = [
markers = [
"optional: marks tests as optional",
]
disable_test_id_escaping_and_forfeit_all_rights_to_community_support = true

[tool.setuptools.packages.find]
where = ["src"]

[tool.setuptools.package-data]
langcheck = ["metrics/model_manager/config/*.yaml"]
3 changes: 2 additions & 1 deletion src/langcheck/metrics/en/_detoxify.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import List, Tuple

import torch
from transformers import BertForSequenceClassification, BertTokenizer
from transformers.models.bert.modeling_bert import BertForSequenceClassification
from transformers.models.bert.tokenization_bert import BertTokenizer


def load_checkpoint(
Expand Down
3 changes: 3 additions & 0 deletions src/langcheck/metrics/model_manager/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from ._model_management import ModelManager

# Module-level singleton: importing this package constructs one shared
# ModelManager so all metrics resolve models through the same instance.
manager = ModelManager()
96 changes: 96 additions & 0 deletions src/langcheck/metrics/model_manager/_model_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from typing import Optional, Tuple

from sentence_transformers import SentenceTransformer
from transformers.models.auto.modeling_auto import (
AutoModelForSeq2SeqLM, AutoModelForSequenceClassification)
from transformers.models.auto.tokenization_auto import AutoTokenizer


def load_sentence_transformers(
        model_name: str,
        model_revision: Optional[str] = None,
        tokenizer_name: Optional[str] = None,
        tokenizer_revision: Optional[str] = None) -> SentenceTransformer:
    '''
    Loads a SentenceTransformer model.

    This function currently does not support specifying a tokenizer or a
    revision. If these arguments are provided, a warning message will be
    printed.

    Args:
        model_name: The name of the SentenceTransformer model to load.
        model_revision: The model revision to load. Currently not supported.
        tokenizer_name: The name of the tokenizer to use. Currently not
            supported.
        tokenizer_revision: The tokenizer revision to load. Currently not
            supported.

    Returns:
        model: The loaded SentenceTransformer model.
    '''
    # SentenceTransformer() is only given the model name below, so revision
    # and tokenizer overrides would be silently ignored — warn the caller.
    if model_revision is not None or tokenizer_revision is not None:
        print("Warning: Specifying a revision is not currently supported.")
    if tokenizer_name is not None:
        print("Warning: Customizing the tokenizer is not currently supported.")

    model = SentenceTransformer(model_name)
    return model


def load_auto_model_for_text_classification(
    model_name: str,
    model_revision: Optional[str] = None,
    tokenizer_name: Optional[str] = None,
    tokenizer_revision: Optional[str] = None
) -> Tuple[AutoTokenizer, AutoModelForSequenceClassification]:
    '''
    Loads a sequence classification model and its tokenizer.

    Args:
        model_name: The name of the sequence-classification model to load.
        model_revision: The model revision to load.
        tokenizer_name: The name of the tokenizer to load. If None, the
            tokenizer associated with the model will be loaded.
        tokenizer_revision: the tokenizer revision to load.

    Returns:
        tokenizer: The loaded tokenizer.
        model: The loaded sequence classification model.
    '''
    # Fall back to the model's own tokenizer when none is given explicitly.
    resolved_tokenizer_name = (model_name
                               if tokenizer_name is None else tokenizer_name)
    tokenizer = AutoTokenizer.from_pretrained(resolved_tokenizer_name,
                                              revision=tokenizer_revision)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, revision=model_revision)
    return tokenizer, model  # type: ignore


def load_auto_model_for_seq2seq(
    model_name: str,
    model_revision: Optional[str] = None,
    tokenizer_name: Optional[str] = None,
    tokenizer_revision: Optional[str] = None
) -> Tuple[AutoTokenizer, AutoModelForSeq2SeqLM]:
    '''
    Loads a sequence-to-sequence model and its tokenizer.

    Args:
        model_name: The name of the sequence-to-sequence model to load.
        model_revision: The model revision to load.
        tokenizer_name: The name of the tokenizer to load. If None, the
            tokenizer associated with the model will be loaded.
        tokenizer_revision: the tokenizer revision to load.

    Returns:
        tokenizer: The loaded tokenizer.
        model: The loaded sequence-to-sequence model.
    '''
    # Default to the model's bundled tokenizer when no override is provided.
    if tokenizer_name is None:
        tokenizer_name = model_name
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name,
                                              revision=tokenizer_revision)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name,
                                                  revision=model_revision)
    return tokenizer, model  # type: ignore
Loading

0 comments on commit 14fa90e

Please sign in to comment.