Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adopt naming convention of transformers API #14

Merged
merged 9 commits into from
Apr 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions .github/workflows/check_code_quality.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,6 @@ jobs:
source venv/bin/activate
pip install --upgrade pip
pip install .[quality]
- name: Check style with black
run: |
source venv/bin/activate
black --check .
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok you did it 😄

- name: Check style with ruff
run: |
source venv/bin/activate
Expand Down
6 changes: 2 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,10 @@ tpu-tgi:

# Run code quality checks
style_check:
black --check .
ruff .

style:
black .
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should remove black from check_code_quality.yml workflow too

ruff . --fix
ruff check . --fix

# Utilities to release to PyPi
build_dist_install_tools:
Expand All @@ -70,7 +68,7 @@ pypi_upload: ${PACKAGE_DIST} ${PACKAGE_WHEEL}

# Tests
test_installs:
python -m pip install .[tpu,tests]
python -m pip install .[tests]

tests: test_installs
python -m pytest -sv tests
Expand Down
3 changes: 2 additions & 1 deletion optimum/tpu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .version import __version__, VERSION # noqa: F401
from .version import __version__, VERSION # noqa: F401
from .modeling import AutoModelForCausalLM # noqa: F401
8 changes: 4 additions & 4 deletions optimum/tpu/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@
from typing import Any

from loguru import logger
from transformers import AutoModelForCausalLM
from transformers import AutoModelForCausalLM as BaseAutoModelForCausalLM
from transformers.utils import is_accelerate_available


# TODO: For now TpuModelForCausalLM is just a shallow wrapper of
# AutoModelForCausalLM, later this could be replaced by a custom class.
class TpuModelForCausalLM(AutoModelForCausalLM):
class AutoModelForCausalLM(BaseAutoModelForCausalLM):

@classmethod
def from_pretrained(
Expand All @@ -46,11 +46,11 @@ def from_pretrained(
else:
device = "xla"
if is_accelerate_available():
model = AutoModelForCausalLM.from_pretrained(
model = BaseAutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path, device_map=device, *model_args, **kwargs
)
else:
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
model = BaseAutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
model.to(device)
# Update config with specific data)
if task is not None or getattr(model.config, "task", None) is None:
Expand Down
8 changes: 6 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ dependencies = [
"loguru == 0.6.0"
]

[build-system]
requires = ["setuptools>=64", "setuptools_scm>=8"]
build-backend = "setuptools.build_meta"

[project.optional-dependencies]
tests = ["pytest", "safetensors"]
quality = ["black", "ruff", "isort",]
Expand All @@ -58,8 +62,8 @@ Documentation = "https://hf.co/docs/optimum/tpu"
Repository = "https://github.com/huggingface/optimum-tpu"
Issues = "https://github.com/huggingface/optimum-tpu/issues"

[tool.setuptools.dynamic]
version = {attr = "optimum.tpu.__version__"}
[tool.setuptools_scm]


[tool.setuptools.packages.find]
include = ["optimum.tpu"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from loguru import logger
from transformers import AutoTokenizer, PreTrainedTokenizerBase, StaticCache
from transformers.generation import GenerationConfig
from optimum.tpu.modeling import TpuModelForCausalLM
from optimum.tpu import AutoModelForCausalLM
from optimum.tpu.generation import TokenSelector

from .pb.generate_pb2 import (
Expand Down Expand Up @@ -301,7 +301,7 @@ class TpuGenerator(Generator):

def __init__(
self,
model: TpuModelForCausalLM,
model,
tokenizer: PreTrainedTokenizerBase,
):
self.model = model
Expand Down Expand Up @@ -633,7 +633,7 @@ def from_pretrained(
"""
logger.info("Loading model (this can take a few minutes).")
start = time.time()
model = TpuModelForCausalLM.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)
end = time.time()
logger.info(f"Model successfully loaded in {end - start:.2f} s.")
tokenizer = AutoTokenizer.from_pretrained(model_path)
Expand Down
Loading