Cache en #26

Closed · wants to merge 5 commits
3 changes: 2 additions & 1 deletion Makefile
@@ -71,7 +71,8 @@ test_installs:
python -m pip install .[tests]

tests: test_installs
python -m pytest -sv tests
python -m pytest -sv tests/single_thread
python -m pytest -sv tests/distributed

# Stand-alone TGI server for unit tests outside of TGI container
tgi_server:
8 changes: 5 additions & 3 deletions examples/text-generation/generation_gemma.py
@@ -1,15 +1,17 @@
#!/usr/bin/python

import torch
import time
import datetime
import os
import platform
import time
from typing import List

import torch
import torch_xla.core.xla_model as xm
from optimum.tpu.modeling import AutoModelForCausalLM
from transformers import AutoTokenizer, StaticCache

from optimum.tpu.modeling import AutoModelForCausalLM


os.environ["PJRT_DEVICE"] = "TPU"

6 changes: 4 additions & 2 deletions optimum/tpu/__init__.py
@@ -12,5 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .version import __version__, VERSION # noqa: F401
from .modeling import AutoModelForCausalLM # noqa: F401
from .cache import initialize_cache
from .model import fetch_model
from .modeling import AutoModelForCausalLM
from .version import VERSION, __version__
21 changes: 21 additions & 0 deletions optimum/tpu/cache.py
@@ -0,0 +1,21 @@
from pathlib import Path
from typing import Union

import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr


def initialize_cache(path: Union[str, Path] = "~/.cache/optimum-tpu/"):
"""Initialize the cache for the XLA runtime.

Note that only the master ordinal (0) writes to the cache; all other ordinals open it read-only.

Args:
path (`str` or `Path`, defaults to `~/.cache/optimum-tpu/`):
The path to the cache directory.
"""
# Resolve tilde in the path
path = Path(path).expanduser()
# Open the cache read-only on every ordinal except 0 (the master)
readonly = xm.get_ordinal() != 0
xr.initialize_cache(str(path), readonly=readonly)
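
For reference, a minimal usage sketch of the new cache helper (not part of this diff). It mirrors the new single-thread test further down; it assumes a TPU host with torch_xla available, and the directory path and tensor shapes are arbitrary:

import os

os.environ["PJRT_DEVICE"] = "TPU"

import torch

from optimum.tpu import initialize_cache

# Persist compiled XLA graphs under the given directory; only ordinal 0 writes to it.
initialize_cache("~/.cache/optimum-tpu/")

# The first execution of a graph is compiled and written to the cache;
# later runs with the same shapes reload it instead of recompiling.
x = torch.ones((128, 256), device="xla")
y = torch.ones((256, 128), device="xla")
print((x @ y).sum())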
13 changes: 5 additions & 8 deletions optimum/tpu/distributed_model.py
@@ -1,18 +1,20 @@
# ruff: noqa: E402
import torch
import os
from enum import Enum
from typing import Dict

import torch
from loguru import logger


os.environ["PJRT_DEVICE"] = "TPU"

import torch.multiprocessing as mp
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch.multiprocessing as mp
from transformers import PretrainedConfig

from optimum.tpu.modeling import AutoModelForCausalLM
from transformers import PretrainedConfig


class ModelCommand(Enum):
@@ -38,17 +40,14 @@ def config(self):

def send(self, command: ModelCommand, data: Dict = None):
# First wait until model is ready to receive commands
logger.debug(f" MM Command {command} waiting for model to be ready")
self.model_ready.wait()
self.model_ready.clear()

self.root_command[:] = [command, data]
self.root_bell.set()
logger.debug(f" MM Command {command} sent")
# wait again until model is ready, meaning command has been processed
self.model_ready.wait()
ret = self.output_data.get()
logger.debug(f" MM Command {command} output shape {ret.shape}")
return ret


@@ -66,10 +65,8 @@ def receive(self):
return self.root_command

def send(self, data: torch.Tensor):
logger.debug(f" MM Enqueueing data {data.shape}")
# Data needs to be moved to CPU before setting it
self.output_data.set(data.cpu())
logger.debug(" MM Enqueueing data done")

@property
def command_data(self):
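As an aside, the send()/receive() pair above implements a simple event-based handshake between the root process and the model workers. A self-contained sketch of that pattern, using hypothetical names (worker, ready, bell, command, output) rather than the actual mailbox classes in this file:

import torch.multiprocessing as mp


def worker(ready, bell, command, output):
    ready.set()                         # signal the root we can accept a command
    bell.wait()                         # block until the root rings the bell
    bell.clear()
    output.put(command["payload"] * 2)  # "process" the command
    ready.set()                         # signal the root that the output is available


if __name__ == "__main__":
    manager = mp.Manager()
    ready, bell = manager.Event(), manager.Event()
    command = manager.dict({"payload": 21})
    output = manager.Queue()
    p = mp.Process(target=worker, args=(ready, bell, command, output))
    p.start()

    ready.wait()         # wait until the worker is ready
    ready.clear()
    bell.set()           # send the command
    ready.wait()         # wait until the command has been processed
    print(output.get())  # -> 42
    p.join()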
3 changes: 2 additions & 1 deletion optimum/tpu/modeling.py
@@ -19,7 +19,8 @@
from typing import Any

from loguru import logger
from transformers import AutoModelForCausalLM as BaseAutoModelForCausalLM, AutoConfig
from transformers import AutoConfig
from transformers import AutoModelForCausalLM as BaseAutoModelForCausalLM

from optimum.tpu.modeling_gemma import TpuGemmaForCausalLM

8 changes: 4 additions & 4 deletions optimum/tpu/modeling_gemma.py
@@ -16,16 +16,15 @@
""" PyTorch Gemma model."""

import math
import re
import warnings
from typing import List, Optional, Tuple, Union
import re

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_attn_mask_utils import (
@@ -34,6 +33,7 @@
)
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.models.gemma.configuration_gemma import GemmaConfig
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
from transformers.utils import (
add_start_docstrings,
@@ -44,15 +44,15 @@
replace_return_docstrings,
)
from transformers.utils.import_utils import is_torch_fx_available
from transformers.models.gemma.configuration_gemma import GemmaConfig

from optimum.tpu.xla_model_parallel import (
RowParallelLinear,
ColumnParallelLinear,
RowParallelLinear,
get_model_parallel_rank,
get_model_parallel_world_size,
)


if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
3 changes: 2 additions & 1 deletion optimum/tpu/version.py
@@ -14,5 +14,6 @@

from pkg_resources import parse_version

__version__ = "0.1.0.dev2"

__version__ = "0.1.0a0"
VERSION = parse_version(__version__)
3 changes: 2 additions & 1 deletion optimum/tpu/xla_model_parallel.py
@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from copy import deepcopy
from dataclasses import dataclass
import os
from typing import Callable, List, Optional, Tuple

import torch
@@ -26,6 +26,7 @@
import torch.nn.init as init
from torch.nn.parameter import Parameter


EPS = torch.finfo(torch.float32).eps

USE_CUDA = os.environ.get("USE_CUDA", False)
10 changes: 5 additions & 5 deletions pyproject.toml
@@ -73,17 +73,17 @@ line-length = 119
target-version = ['py38']
extend-exclude = '.ipynb'

[lint.ruff]
[tool.ruff]
# Never enforce `E501` (line length violations).
ignore = ["C901", "E501", "E741", "W605"]
select = ["C", "E", "F", "I", "W"]
lint.ignore = ["C901", "E501", "E741", "W605"]
lint.select = ["C", "E", "F", "I", "W"]
line-length = 119

# Ignore import violations in all `__init__.py` files.
[lint.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["E402", "F401", "F403", "F811"]

[lint.ruff.isort]
[tool.ruff.lint.isort]
lines-after-imports = 2
known-first-party = ["optimum.tpu"]

3 changes: 2 additions & 1 deletion tests/conftest.py
@@ -1,5 +1,6 @@
import pytest


# See https://stackoverflow.com/a/61193490/217945 for run_slow
def pytest_addoption(parser):
parser.addoption(
@@ -18,4 +19,4 @@ def pytest_collection_modifyitems(config, items):
skip_slow = pytest.mark.skip(reason="need --runslow option to run")
for item in items:
if "slow" in item.keywords:
item.add_marker(skip_slow)
item.add_marker(skip_slow)
(file name not shown)
@@ -1,8 +1,10 @@
import os
from optimum.tpu.distributed_model import DistributedModel
from transformers import AutoTokenizer
import torch

import pytest
import torch
from transformers import AutoTokenizer

from optimum.tpu.distributed_model import DistributedModel


def sample_greedy(logits):
27 changes: 27 additions & 0 deletions tests/single_thread/test_cache.py
@@ -0,0 +1,27 @@
import os
from tempfile import TemporaryDirectory

import torch

from optimum.tpu import initialize_cache


def test_init_cache():
os.environ["PJRT_DEVICE"] = "TPU"
# Use a temporary directory so the cache starts out empty
with TemporaryDirectory() as tmp_dir:
cache_dir = os.path.join(tmp_dir, "cache")
initialize_cache(cache_dir)
assert not os.path.exists(cache_dir)

# Do some calculation that will trigger graph generation and caching
v1 = torch.ones((100, 200), device="xla")
v2 = torch.ones((200, 100), device="xla")
v3 = v1 @ v2
# Print the result so the computation is not optimized away
print(v3.max())

assert os.path.exists(cache_dir)
assert len(os.listdir(cache_dir)) > 0


(file name not shown)
@@ -1,16 +1,17 @@
import copy
import logging
import time
import os
import time
from abc import ABC
from enum import Enum
from typing import List, Optional, Tuple, Dict
from typing import Dict, List, Optional, Tuple

import torch
import torch_xla.core.xla_model as xm
from loguru import logger
from transformers import AutoTokenizer, PreTrainedTokenizerBase, StaticCache
from transformers.generation import GenerationConfig

from optimum.tpu import AutoModelForCausalLM
from optimum.tpu.generation import TokenSelector

(file name not shown)
@@ -1,4 +1,5 @@
from pkg_resources import parse_version

__version__ = "0.1.0.dev0"

__version__ = "0.1.0a0"
VERSION = parse_version(__version__)
8 changes: 5 additions & 3 deletions text-generation-inference/tests/test_gemma.py
@@ -1,14 +1,16 @@
import pytest
import os
from tqdm import tqdm

import pytest
from text_generation_server.generator import TpuGenerator
from optimum.tpu.model import fetch_model
from text_generation_server.pb.generate_pb2 import (
Batch,
NextTokenChooserParameters,
Request,
StoppingCriteriaParameters,
)
from tqdm import tqdm

from optimum.tpu.model import fetch_model


MODEL_ID = "google/gemma-2b"
8 changes: 5 additions & 3 deletions text-generation-inference/tests/test_gpt2.py
@@ -1,14 +1,16 @@
import pytest
import os
from tqdm import tqdm

import pytest
from text_generation_server.generator import TpuGenerator
from optimum.tpu.model import fetch_model
from text_generation_server.pb.generate_pb2 import (
Batch,
NextTokenChooserParameters,
Request,
StoppingCriteriaParameters,
)
from tqdm import tqdm

from optimum.tpu.model import fetch_model


MODEL_ID = "openai-community/gpt2"