model config manager class #73

Merged (70 commits) on Feb 28, 2024

Commits
87016a2
implement a model config manager class
Vela-zz Dec 10, 2023
2bc3b75
add test case for model management
Vela-zz Dec 10, 2023
34e49fc
apply format suggestion
Vela-zz Dec 10, 2023
083c612
apply format suggestion
Vela-zz Dec 10, 2023
2cdf43c
pydoc update & fix test case
Vela-zz Dec 10, 2023
99fe02e
implement a model config manager class
Vela-zz Dec 10, 2023
2843c88
add test case for model management
Vela-zz Dec 10, 2023
49983cd
apply format suggestion
Vela-zz Dec 10, 2023
63aa6e6
pydoc update & fix test case
Vela-zz Dec 10, 2023
66ef0c6
apply format suggestion
Vela-zz Dec 10, 2023
6cd3755
Merge branch 'zh_model_config' of https://github.com/Vela-zz/langchec…
Vela-zz Dec 10, 2023
b89ed30
pydoc update & fix test case
Vela-zz Dec 10, 2023
1a2b720
Merge branch 'zh_model_config' of https://github.com/Vela-zz/langchec…
Vela-zz Dec 23, 2023
2506ece
add model loader
Vela-zz Dec 23, 2023
57ea217
re-implement a model manager class
Vela-zz Dec 25, 2023
3bf196c
add update_metrics_for_model method
Vela-zz Dec 25, 2023
bb70f64
apply format suggestion
Vela-zz Dec 25, 2023
db15318
clean up model loader docstrings
yosukehigashi Dec 27, 2023
ddac3cf
fix format
yosukehigashi Dec 27, 2023
6b2a382
clean up docstrings in model management
yosukehigashi Dec 27, 2023
0e23c39
make self.config not None
yosukehigashi Dec 27, 2023
a1fd972
remove unnecessary noqa tags
yosukehigashi Dec 27, 2023
e80285d
clean up comments
yosukehigashi Dec 27, 2023
d841711
fix fetch_model format
yosukehigashi Dec 27, 2023
2aacbf2
fix ref based and source based format
yosukehigashi Dec 27, 2023
6d0a530
fix format in ref free
yosukehigashi Dec 27, 2023
b2805d6
add dependencies
Vela-zz Dec 27, 2023
bc443a6
fix model manager plot table problem
Vela-zz Dec 27, 2023
4433da1
fix typo mistakes
Vela-zz Dec 28, 2023
c540f12
move and update model manager
Vela-zz Feb 5, 2024
ea7564b
use yaml file for read ease
Vela-zz Feb 5, 2024
a1120db
apply manager changes to zh metric
Vela-zz Feb 5, 2024
b1718d6
update environment settings and delete test case
Vela-zz Feb 5, 2024
da11f46
apply format check suggestions
Vela-zz Feb 5, 2024
a83d6a3
add test case for model loader.
Vela-zz Feb 5, 2024
ac59b6b
add package data to pyproject.toml
yosukehigashi Feb 7, 2024
492396a
Merge branch 'zh_model_config' of https://github.com/vela-zz/langchec…
yosukehigashi Feb 7, 2024
e4091b3
package-data fix
yosukehigashi Feb 7, 2024
37c3884
try again
yosukehigashi Feb 7, 2024
bd3ab62
remove global value in metric
Vela-zz Feb 7, 2024
d8f20bc
clean load_sentence_transformers comments
yosukehigashi Feb 9, 2024
c46645b
clean load_auto_model_for_text_classification docstring
yosukehigashi Feb 9, 2024
ff90381
clean load_auto_model_for_seq2seq docstring
yosukehigashi Feb 9, 2024
e35b66e
clean up model loader imports
yosukehigashi Feb 9, 2024
a959633
clean up __load_config
yosukehigashi Feb 9, 2024
3240129
add comment for fetch_model
yosukehigashi Feb 9, 2024
78b296e
clean __set_model_for_metric
yosukehigashi Feb 9, 2024
8b247fd
minor cleanup
yosukehigashi Feb 9, 2024
42ccc66
clean validate_config
yosukehigashi Feb 9, 2024
f4bf665
remove unused import
yosukehigashi Feb 9, 2024
6f33c0a
Merge branch 'zh_model_config' of https://github.com/Vela-zz/langchec…
Vela-zz Feb 9, 2024
e7d42e9
add test case for model manager class
Vela-zz Feb 16, 2024
1086826
remove global value in metrics
Vela-zz Feb 16, 2024
89d9aaa
apply format suggestions
Vela-zz Feb 16, 2024
a09dbd3
make jp char and zh char show formally in test, not unicode
Vela-zz Feb 16, 2024
daeb706
fix import error in en detoxify raised by pyright
Vela-zz Feb 19, 2024
64f7e95
apply format check suggestion
Vela-zz Feb 19, 2024
6155623
apply format check suggestions and remove useless import
Vela-zz Feb 19, 2024
7f45a8c
remove unused imports
yosukehigashi Feb 20, 2024
0f15528
remove unused import
yosukehigashi Feb 20, 2024
5371f86
cleanup and docstrings
yosukehigashi Feb 20, 2024
57b864f
add tokenizer_revision for fine grained control
Vela-zz Feb 26, 2024
160a0d3
apply tokenizer_revision update to test case
Vela-zz Feb 26, 2024
10593fc
clean up docstring and comments
yosukehigashi Feb 28, 2024
7c770b0
fix typo
yosukehigashi Feb 28, 2024
2ce534f
specify which fields are optional in config
yosukehigashi Feb 28, 2024
86f55fe
removed unnecessary noqa
yosukehigashi Feb 28, 2024
f540159
fix yapf format
yosukehigashi Feb 28, 2024
fdd353e
fix yapf format
yosukehigashi Feb 28, 2024
b1d5d76
maximize disk space
yosukehigashi Feb 28, 2024
13 changes: 13 additions & 0 deletions .github/workflows/pytest.yml
@@ -31,6 +31,19 @@ jobs:
pip install --upgrade pip
pip install .[dev]
# Remove unneeded system libraries to maximize disk space
# https://github.com/easimon/maximize-build-space/blob/master/action.yml
# https://github.com/actions/virtual-environments/issues/2840#issuecomment-790492173
- name: Maximize disk space
run: |
echo "Available disk space (before):"
df -h
sudo rm -rf /usr/share/dotnet
sudo rm -rf /opt/ghc
sudo rm -rf /usr/local/lib/android
echo "Available disk space (after):"
df -h
# Run integration tests
- name: Test
run: |
11 changes: 10 additions & 1 deletion pyproject.toml
@@ -31,7 +31,9 @@ dependencies = [
'tokenizers >= 0.13.2; python_version >= "3.11"', # See https://github.com/citadel-ai/langcheck/pull/45
'torch >= 2',
'transformers >= 4.6',
"unidic-lite >= 1.0.1" # For tokenizer of metrics.ja.toxicity()
"unidic-lite >= 1.0.1", # For tokenizer of metrics.ja.toxicity()
"tabulate >= 0.9.0", # For model manager paint table
"omegaconf >= 2.3.0" # For model manager paint table
]
requires-python = ">=3.8"

@@ -80,3 +82,10 @@ ignore = [
markers = [
"optional: marks tests as optional",
]
disable_test_id_escaping_and_forfeit_all_rights_to_community_support = true

[tool.setuptools.packages.find]
where = ["src"]

[tool.setuptools.package-data]
langcheck = ["metrics/model_manager/config/*.yaml"]
3 changes: 2 additions & 1 deletion src/langcheck/metrics/en/_detoxify.py
@@ -1,7 +1,8 @@
from typing import List, Tuple

import torch
from transformers import BertForSequenceClassification, BertTokenizer
from transformers.models.bert.modeling_bert import BertForSequenceClassification
from transformers.models.bert.tokenization_bert import BertTokenizer


def load_checkpoint(
3 changes: 3 additions & 0 deletions src/langcheck/metrics/model_manager/__init__.py
@@ -0,0 +1,3 @@
from ._model_management import ModelManager

manager = ModelManager()
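The `__init__.py` above instantiates a single shared `ModelManager` at import time, so every metric module reuses one cache of loaded models instead of reloading them. A minimal sketch of that module-level singleton pattern (the `_cache` dict and `fetch_model` internals here are assumptions for illustration, not the library's actual implementation):

```python
class ModelManager:
    """Hypothetical minimal sketch of a caching model manager."""

    def __init__(self):
        self._cache = {}

    def fetch_model(self, language, metric):
        key = (language, metric)
        if key not in self._cache:
            # Placeholder for the real (expensive) model loading step
            self._cache[key] = f"loaded:{language}/{metric}"
        return self._cache[key]

# The package exposes one shared instance, so all callers hit the same cache
manager = ModelManager()

a = manager.fetch_model("zh", "toxicity")
b = manager.fetch_model("zh", "toxicity")
print(a is b)  # True: the second call returns the cached object
```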
96 changes: 96 additions & 0 deletions src/langcheck/metrics/model_manager/_model_loader.py
@@ -0,0 +1,96 @@
from typing import Optional, Tuple

from sentence_transformers import SentenceTransformer
from transformers.models.auto.modeling_auto import (
AutoModelForSeq2SeqLM, AutoModelForSequenceClassification)
from transformers.models.auto.tokenization_auto import AutoTokenizer


def load_sentence_transformers(
model_name: str,
model_revision: Optional[str] = None,
tokenizer_name: Optional[str] = None,
tokenizer_revision: Optional[str] = None) -> SentenceTransformer:
'''
Loads a SentenceTransformer model.

This function currently does not support specifying a tokenizer or a
revision. If these arguments are provided, a warning message will be
printed.

Args:
model_name: The name of the SentenceTransformer model to load.
tokenizer_name: The name of the tokenizer to use. Currently not
supported.
model_revision: The model revision to load. Currently not supported.
tokenizer_revision: The tokenizer revision to load. Currently not
supported.

Returns:
model: The loaded SentenceTransformer model.
'''
if model_revision is not None or tokenizer_revision is not None:
print("Warning: Specifying a revision is not currently supported.")
if tokenizer_name is not None:
print("Warning: Customizing the tokenizer is not currently supported.")

model = SentenceTransformer(model_name)
return model
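The guard above reports unsupported arguments with `print`. A hypothetical variant (`load_sentence_transformers_sketch` is an illustrative name, not part of the library) showing the same guard with the stdlib `warnings` module, which callers can filter, capture, or turn into errors:

```python
import warnings

def load_sentence_transformers_sketch(model_name, model_revision=None,
                                      tokenizer_name=None,
                                      tokenizer_revision=None):
    # Hypothetical variant: emit real warnings instead of print()
    if model_revision is not None or tokenizer_revision is not None:
        warnings.warn("Specifying a revision is not currently supported.")
    if tokenizer_name is not None:
        warnings.warn("Customizing the tokenizer is not currently supported.")
    # Placeholder return; the real function returns a SentenceTransformer
    return model_name

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    load_sentence_transformers_sketch("all-MiniLM-L6-v2", model_revision="main")
print(len(caught))  # 1: exactly one warning was recorded
```

Unlike `print`, `warnings.warn` integrates with `-W` flags and `pytest.warns`, so test suites can assert on the message.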


def load_auto_model_for_text_classification(
model_name: str,
model_revision: Optional[str] = None,
tokenizer_name: Optional[str] = None,
tokenizer_revision: Optional[str] = None
) -> Tuple[AutoTokenizer, AutoModelForSequenceClassification]:
'''
Loads a sequence classification model and its tokenizer.

Args:
model_name: The name of the sequence-classification model to load.
tokenizer_name: The name of the tokenizer to load. If None, the
tokenizer associated with the model will be loaded.
model_revision: The model revision to load.
tokenizer_revision: The tokenizer revision to load.

Returns:
tokenizer: The loaded tokenizer.
model: The loaded sequence classification model.
'''
if tokenizer_name is None:
tokenizer_name = model_name
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name,
revision=tokenizer_revision)
model = AutoModelForSequenceClassification.from_pretrained(
model_name, revision=model_revision)
return tokenizer, model # type: ignore


def load_auto_model_for_seq2seq(
model_name: str,
model_revision: Optional[str] = None,
tokenizer_name: Optional[str] = None,
tokenizer_revision: Optional[str] = None
) -> Tuple[AutoTokenizer, AutoModelForSeq2SeqLM]:
'''
Loads a sequence-to-sequence model and its tokenizer.

Args:
model_name: The name of the sequence-to-sequence model to load.
tokenizer_name: The name of the tokenizer to load. If None, the
tokenizer associated with the model will be loaded.
model_revision: The model revision to load.
tokenizer_revision: The tokenizer revision to load.

Returns:
tokenizer: The loaded tokenizer.
model: The loaded sequence-to-sequence model.
'''
if tokenizer_name is None:
tokenizer_name = model_name
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name,
revision=tokenizer_revision)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name,
revision=model_revision)
return tokenizer, model # type: ignore
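All three loaders share one signature, which lets the YAML config select a loader by name at runtime. A sketch of such a dispatch table (the `loader_func` config key and the stub loader bodies are assumptions for illustration, not the library's actual config schema):

```python
# Stub loaders standing in for the real functions above
def load_sentence_transformers(name, **kwargs): return ("st", name)
def load_auto_model_for_text_classification(name, **kwargs): return ("cls", name)
def load_auto_model_for_seq2seq(name, **kwargs): return ("s2s", name)

# Map a config string to the matching loader function
LOADER_MAP = {
    "load_sentence_transformers": load_sentence_transformers,
    "load_auto_model_for_text_classification":
        load_auto_model_for_text_classification,
    "load_auto_model_for_seq2seq": load_auto_model_for_seq2seq,
}

# Assumed config shape: one entry per (language, metric) pair
config = {"model_name": "some/model",
          "loader_func": "load_auto_model_for_seq2seq"}
kind, name = LOADER_MAP[config["loader_func"]](config["model_name"])
print(kind)  # s2s
```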