Skip to content

Commit

Permalink
Adopt naming convention of transformers API (#14)
Browse files Browse the repository at this point in the history
* Let's replace TPUModelForX with AutoModelForX to better match transformers

* Make AutoModelForX top-level import

* match new names

* Remove tpu from make pip install

* import version before anything else

* Let's use setuptools_scm backend to install deps

* Fix F401

* Let's just use ruff not black and ruff

* Again
  • Loading branch information
mfuntowicz committed Apr 8, 2024
1 parent f92066c commit cef8e0d
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 18 deletions.
4 changes: 0 additions & 4 deletions .github/workflows/check_code_quality.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,6 @@ jobs:
source venv/bin/activate
pip install --upgrade pip
pip install .[quality]
- name: Check style with black
run: |
source venv/bin/activate
black --check .
- name: Check style with ruff
run: |
source venv/bin/activate
Expand Down
6 changes: 2 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,10 @@ tpu-tgi:

# Run code quality checks
style_check:
black --check .
ruff .

style:
black .
ruff . --fix
ruff check . --fix

# Utilities to release to PyPi
build_dist_install_tools:
Expand All @@ -70,7 +68,7 @@ pypi_upload: ${PACKAGE_DIST} ${PACKAGE_WHEEL}

# Tests
test_installs:
python -m pip install .[tpu,tests]
python -m pip install .[tests]

tests: test_installs
python -m pytest -sv tests
Expand Down
3 changes: 2 additions & 1 deletion optimum/tpu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .version import __version__, VERSION # noqa: F401
from .version import __version__, VERSION # noqa: F401
from .modeling import AutoModelForCausalLM # noqa: F401
8 changes: 4 additions & 4 deletions optimum/tpu/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@
from typing import Any

from loguru import logger
from transformers import AutoModelForCausalLM
from transformers import AutoModelForCausalLM as BaseAutoModelForCausalLM
from transformers.utils import is_accelerate_available


# TODO: For now TpuModelForCausalLM is just a shallow wrapper of
# AutoModelForCausalLM, later this could be replaced by a custom class.
class TpuModelForCausalLM(AutoModelForCausalLM):
class AutoModelForCausalLM(BaseAutoModelForCausalLM):

@classmethod
def from_pretrained(
Expand All @@ -46,11 +46,11 @@ def from_pretrained(
else:
device = "xla"
if is_accelerate_available():
model = AutoModelForCausalLM.from_pretrained(
model = BaseAutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path, device_map=device, *model_args, **kwargs
)
else:
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
model = BaseAutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
model.to(device)
# Update config with specific data
if task is not None or getattr(model.config, "task", None) is None:
Expand Down
8 changes: 6 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ dependencies = [
"loguru == 0.6.0"
]

[build-system]
requires = ["setuptools>=64", "setuptools_scm>=8"]
build-backend = "setuptools.build_meta"

[project.optional-dependencies]
tests = ["pytest", "safetensors"]
quality = ["black", "ruff", "isort", "hf_doc_builder @ git+https://github.com/huggingface/doc-builder.git"]
Expand All @@ -58,8 +62,8 @@ Documentation = "https://hf.co/docs/optimum/tpu"
Repository = "https://github.com/huggingface/optimum-tpu"
Issues = "https://github.com/huggingface/optimum-tpu/issues"

[tool.setuptools.dynamic]
version = {attr = "optimum.tpu.__version__"}
[tool.setuptools_scm]


[tool.setuptools.packages.find]
include = ["optimum.tpu"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from loguru import logger
from transformers import AutoTokenizer, PreTrainedTokenizerBase, StaticCache
from transformers.generation import GenerationConfig
from optimum.tpu.modeling import TpuModelForCausalLM
from optimum.tpu import AutoModelForCausalLM
from optimum.tpu.generation import TokenSelector

from .pb.generate_pb2 import (
Expand Down Expand Up @@ -301,7 +301,7 @@ class TpuGenerator(Generator):

def __init__(
self,
model: TpuModelForCausalLM,
model,
tokenizer: PreTrainedTokenizerBase,
):
self.model = model
Expand Down Expand Up @@ -633,7 +633,7 @@ def from_pretrained(
"""
logger.info("Loading model (this can take a few minutes).")
start = time.time()
model = TpuModelForCausalLM.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)
end = time.time()
logger.info(f"Model successfully loaded in {end - start:.2f} s.")
tokenizer = AutoTokenizer.from_pretrained(model_path)
Expand Down

0 comments on commit cef8e0d

Please sign in to comment.