diff --git a/.github/workflows/build_rocm.yaml b/.github/workflows/build_rocm.yaml new file mode 100644 index 00000000..8ecfa1e5 --- /dev/null +++ b/.github/workflows/build_rocm.yaml @@ -0,0 +1,144 @@ + name: Build and push AMD ROCm docker image to registry + + on: + workflow_dispatch: + push: + branches: + - 'main' + tags: + - 'v*' + pull_request: + paths: + - ".github/workflows/build.yaml" +# - "integration-tests/**" + - "backends/**" + - "core/**" + - "router/**" + - "Cargo.lock" + - "rust-toolchain.toml" + - "Dockerfile" + branches: + - 'main' + + jobs: + build-and-push-image: + concurrency: + group: ${{ github.workflow }}-${{ github.job }}-rocm-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + runs-on: [self-hosted, intel-cpu, 32-cpu, 256-ram, ci] + permissions: + contents: write + packages: write + # This is used to complete the identity challenge + # with sigstore/fulcio when running outside of PRs. + id-token: write + security-events: write + steps: + - name: Checkout repository + uses: actions/checkout@v3 + + - name: Tailscale + uses: huggingface/tailscale-action@v1 + with: + authkey: ${{ secrets.TAILSCALE_AUTHKEY }} + + - name: Initialize Docker Buildx + uses: docker/setup-buildx-action@v2.0.0 + with: + install: true + config-inline: | + [registry."docker.io"] + mirrors = ["registry.github-runners.huggingface.tech"] + + - name: Configure sccache + uses: actions/github-script@v6 + with: + script: | + core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); + core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); + + - name: Inject slug/short variables + uses: rlespinasse/github-slug-action@v4.4.1 + + - name: Login to GitHub Container Registry + if: github.event_name != 'pull_request' + uses: docker/login-action@v2 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Login to internal Container Registry + uses: docker/login-action@v2.1.0 + with: + username: ${{ secrets.TAILSCALE_DOCKER_USERNAME }} + password: ${{ secrets.TAILSCALE_DOCKER_PASSWORD }} + registry: registry.internal.huggingface.tech + + - name: Extract metadata (tags, labels) for Docker + id: meta-rocm + uses: docker/metadata-action@v4.3.0 + with: + images: | + registry.internal.huggingface.tech/api-inference/text-embeddings-inference + ghcr.io/huggingface/text-embeddings-inference + flavor: | + latest=false + tags: | + type=semver,pattern=rocm-{{version}} + type=semver,pattern=rocm-{{major}}.{{minor}} + type=raw,value=rocm-latest + type=raw,value=rocm-sha-${{ env.GITHUB_SHA_SHORT }} + + - name: Build and push Docker image + id: build-and-push-rocm + uses: docker/build-push-action@v4 + with: + context: . 
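+          # Builds the default (HTTP) image from Dockerfile-rocm for linux/amd64; the grpc variant below reuses the same registry cache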
+          file: Dockerfile-rocm
+          push: ${{ github.event_name != 'pull_request' }}
+          platforms: 'linux/amd64'
+          build-args: |
+            SCCACHE_GHA_ENABLED=on
+            ACTIONS_CACHE_URL=${{ env.ACTIONS_CACHE_URL }}
+            ACTIONS_RUNTIME_TOKEN=${{ env.ACTIONS_RUNTIME_TOKEN }}
+            GIT_SHA=${{ env.GITHUB_SHA }}
+            DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}
+          tags: ${{ steps.meta-rocm.outputs.tags }}
+          labels: ${{ steps.meta-rocm.outputs.labels }}
+          cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-rocm,mode=max
+          cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-rocm,mode=max
+
+      - name: Extract metadata (tags, labels) for Docker
+        id: meta-rocm-grpc
+        uses: docker/metadata-action@v4.3.0
+        with:
+          images: |
+            registry.internal.huggingface.tech/api-inference/text-embeddings-inference
+            ghcr.io/huggingface/text-embeddings-inference
+          flavor: |
+            latest=false
+          tags: |
+            type=semver,pattern=rocm-{{version}}-grpc
+            type=semver,pattern=rocm-{{major}}.{{minor}}-grpc
+            type=raw,value=rocm-latest-grpc
+            type=raw,value=rocm-sha-${{ env.GITHUB_SHA_SHORT }}-grpc
+
+      - name: Build and push Docker image
+        id: build-and-push-rocm-grpc
+        uses: docker/build-push-action@v4
+        with:
+          context: .
+          target: grpc
+          file: Dockerfile-rocm
+          push: ${{ github.event_name != 'pull_request' }}
+          platforms: 'linux/amd64'
+          build-args: |
+            SCCACHE_GHA_ENABLED=on
+            ACTIONS_CACHE_URL=${{ env.ACTIONS_CACHE_URL }}
+            ACTIONS_RUNTIME_TOKEN=${{ env.ACTIONS_RUNTIME_TOKEN }}
+            GIT_SHA=${{ env.GITHUB_SHA }}
+            DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}
+          tags: ${{ steps.meta-rocm-grpc.outputs.tags }}
+          labels: ${{ steps.meta-rocm-grpc.outputs.labels }}
+          cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-rocm,mode=max
diff --git a/.gitignore b/.gitignore
index ee44a963..6862c2f1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
 .idea
 target
+__pycache__/
diff --git a/Dockerfile b/Dockerfile
index 8c1368c0..6d7c35d4 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -4,7 +4,7 @@ WORKDIR /usr/src
 ENV SCCACHE=0.5.4
 ENV RUSTC_WRAPPER=/usr/local/bin/sccache
-# Donwload, configure sccache
+# Download, configure sccache
 RUN curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache && \
     chmod +x /usr/local/bin/sccache
diff --git a/Dockerfile-rocm b/Dockerfile-rocm
new file mode 100644
index 00000000..4ac343a4
--- /dev/null
+++ b/Dockerfile-rocm
@@ -0,0 +1,136 @@
+FROM rocm/dev-ubuntu-22.04:6.0.2 AS base-builder
+
+ENV SCCACHE=0.5.4
+ENV RUSTC_WRAPPER=/usr/local/bin/sccache
+ENV PATH="/root/.cargo/bin:${PATH}"
+
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+    curl \
+    libssl-dev \
+    pkg-config \
+    && rm -rf /var/lib/apt/lists/*
+
+# Download and configure sccache
+RUN curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache && \
+    chmod +x /usr/local/bin/sccache
+
+RUN curl https://sh.rustup.rs -sSf | bash -s -- -y
+RUN cargo install cargo-chef --locked
+
+FROM base-builder AS planner
+
+WORKDIR /usr/src
+
+COPY backends backends
+COPY core core
+COPY router router
+COPY Cargo.toml ./
+COPY Cargo.lock ./
+
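+# cargo-chef computes a dependency "recipe" here so the builder stage below can cache Rust dependency builds independently of source changes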
+RUN cargo chef prepare --recipe-path recipe.json + +FROM base-builder AS builder + +ARG CUDA_COMPUTE_CAP=80 +ARG GIT_SHA +ARG DOCKER_LABEL + +# sccache specific variables +ARG ACTIONS_CACHE_URL +ARG ACTIONS_RUNTIME_TOKEN +ARG SCCACHE_GHA_ENABLED + +WORKDIR /usr/src + +COPY --from=planner /usr/src/recipe.json recipe.json + +RUN cargo chef cook --release --features python --no-default-features --recipe-path recipe.json && sccache -s + +COPY backends backends +COPY core core +COPY router router +COPY Cargo.toml ./ +COPY Cargo.lock ./ + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + unzip \ + && rm -rf /var/lib/apt/lists/* + +RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ + curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ + unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ + unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \ + rm -f $PROTOC_ZIP + +COPY proto proto + +FROM builder as http-builder + +RUN cargo build --release --bin text-embeddings-router -F python -F http --no-default-features && sccache -s + +FROM builder as grpc-builder + +RUN cargo build --release --bin text-embeddings-router -F python -F grpc --no-default-features && sccache -s + +FROM rocm/dev-ubuntu-22.04:6.0.2 as base + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + git \ + python3-dev \ + rocthrust-dev \ + hipsparse-dev \ + hipblas-dev \ + hipblaslt-dev \ + hiprand-dev \ + hipsolver-dev \ + rocblas-dev \ + rocrand-dev \ + && rm -rf /var/lib/apt/lists/* + + +# Keep in sync with `server/pyproject.toml +ARG MAMBA_VERSION=23.1.0-1 +ARG PYTORCH_VERSION='2.3.0' +ARG ROCM_VERSION='6.0.2' +ARG PYTHON_VERSION='3.10.10' +# Automatically set by buildx +ARG TARGETPLATFORM +ENV PATH /opt/conda/bin:$PATH + +RUN curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-x86_64.sh" +RUN chmod +x ~/mambaforge.sh && \ + bash ~/mambaforge.sh -b -p /opt/conda && \ + mamba init && \ + rm ~/mambaforge.sh + +# Install flash-attention, torch dependencies +RUN pip install numpy einops ninja --no-cache-dir + +# Install python backend +COPY backends/python/server /tei_backends/python/server +COPY backends/proto tei_backends/proto +RUN make -C /tei_backends/python/server install + +RUN pip install --force-reinstall torch==$PYTORCH_VERSION --index-url https://download.pytorch.org/whl/rocm6.0 + +ARG DEFAULT_USE_FLASH_ATTENTION=True +COPY backends/python/Makefile-flash-att-v2 Makefile-flash-att-v2 +RUN make -f Makefile-flash-att-v2 install-flash-attention-v2-rocm + +ENV HUGGINGFACE_HUB_CACHE=/data \ + PORT=80 \ + USE_FLASH_ATTENTION=$DEFAULT_USE_FLASH_ATTENTION + +FROM base as grpc + +COPY --from=grpc-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router + +ENTRYPOINT ["text-embeddings-router"] +CMD ["--json-output"] + +FROM base + +COPY --from=http-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router + +ENTRYPOINT ["text-embeddings-router"] +CMD ["--json-output"] diff --git a/backends/python/Makefile-flash-att-v2 b/backends/python/Makefile-flash-att-v2 new file mode 100644 index 00000000..ba90a74d --- /dev/null +++ b/backends/python/Makefile-flash-att-v2 @@ -0,0 +1,21 @@ +flash_att_v2_commit_cuda := v2.5.9.post1 +flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6 + +build-flash-attention-v2-cuda: + pip install 
-U packaging wheel + pip install flash-attn==$(flash_att_v2_commit_cuda) + +install-flash-attention-v2-cuda: build-flash-attention-v2-cuda + echo "Flash v2 installed" + +build-flash-attention-v2-rocm: + if [ ! -d 'flash-attention-v2' ]; then \ + pip install -U packaging ninja --no-cache-dir && \ + git clone https://github.com/ROCm/flash-attention.git flash-attention-v2 && \ + cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm) && \ + git submodule update --init --recursive && GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build; \ + fi + +install-flash-attention-v2-rocm: build-flash-attention-v2-rocm + cd flash-attention-v2 && \ + GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py install diff --git a/backends/python/server/pyproject.toml b/backends/python/server/pyproject.toml index 96fcaf9e..8fbc0008 100644 --- a/backends/python/server/pyproject.toml +++ b/backends/python/server/pyproject.toml @@ -15,12 +15,13 @@ grpcio-status = "^1.51.1" grpcio-reflection = "^1.51.1" grpc-interceptor = "^0.15.0" typer = "^0.6.1" -safetensors = "^0.3.2" +safetensors = "^0.4.0" loguru = "^0.6.0" opentelemetry-api = "^1.15.0" opentelemetry-exporter-otlp = "^1.15.0" opentelemetry-instrumentation-grpc = "^0.36b0" -torch = { version = "^2.0.1" } +torch = { version = "==2.3.1" } +transformers = { version = "^4.39.0"} [tool.poetry.extras] @@ -33,6 +34,11 @@ name = "pytorch-gpu-src" url = "https://download.pytorch.org/whl/cu118" priority = "explicit" +[[tool.poetry.source]] +name = "pytorch-gpu-src-rocm" +url = "https://download.pytorch.org/whl/rocm6.0" +priority = "explicit" + [tool.pytest.ini_options] markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"] diff --git a/backends/python/server/requirements.txt b/backends/python/server/requirements.txt index 89ca314d..2d089e41 100644 --- a/backends/python/server/requirements.txt +++ b/backends/python/server/requirements.txt @@ -4,20 +4,13 @@ charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13" click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" -filelock==3.12.3 ; python_version >= "3.9" and python_version < "3.13" -fsspec==2023.9.0 ; python_version >= "3.9" and python_version < "3.13" googleapis-common-protos==1.60.0 ; python_version >= "3.9" and python_version < "3.13" grpc-interceptor==0.15.3 ; python_version >= "3.9" and python_version < "3.13" grpcio-reflection==1.58.0 ; python_version >= "3.9" and python_version < "3.13" grpcio-status==1.58.0 ; python_version >= "3.9" and python_version < "3.13" grpcio==1.58.0 ; python_version >= "3.9" and python_version < "3.13" -huggingface-hub==0.16.4 ; python_version >= "3.9" and python_version < "3.13" idna==3.4 ; python_version >= "3.9" and python_version < "3.13" -jinja2==3.1.2 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" -markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "3.13" -mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13" -networkx==3.1 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= 
"3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13" @@ -27,15 +20,10 @@ opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_versi opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13" -packaging==23.1 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.24.3 ; python_version >= "3.9" and python_version < "3.13" -pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13" safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13" setuptools==68.2.0 ; python_version >= "3.9" and python_version < "3.13" -sympy==1.12 ; python_version >= "3.9" and python_version < "3.13" -torch==2.0.1 ; python_version >= "3.9" and python_version < "3.13" -tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13" urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13" diff --git a/backends/python/server/text_embeddings_server/cli.py b/backends/python/server/text_embeddings_server/cli.py index 9497dc20..4f423afe 100644 --- a/backends/python/server/text_embeddings_server/cli.py +++ b/backends/python/server/text_embeddings_server/cli.py @@ -24,6 +24,7 @@ def serve( json_output: bool = False, otlp_endpoint: Optional[str] = None, otlp_service_name: str = "text-embeddings-inference.server", + pooling_mode: Optional[str] = None, ): # Remove default handler logger.remove() @@ -48,7 +49,7 @@ def serve( # Downgrade enum into str for easier management later on dtype = None if dtype is None else dtype.value - server.serve(model_path, dtype, uds_path) + server.serve(model_path, dtype, uds_path, pooling_mode) if __name__ == "__main__": diff --git a/backends/python/server/text_embeddings_server/layers/__init__.py b/backends/python/server/text_embeddings_server/layers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backends/python/server/text_embeddings_server/layers/attention/__init__.py b/backends/python/server/text_embeddings_server/layers/attention/__init__.py new file mode 100644 index 00000000..42aac2bd --- /dev/null +++ b/backends/python/server/text_embeddings_server/layers/attention/__init__.py @@ -0,0 +1,14 @@ +from text_embeddings_server.utils.import_utils import SYSTEM +import os + +if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": + class Attention: + def __getattr__(self, name): + raise RuntimeError(f"TEI is used with USE_FLASH_ATTENTION=false, accessing `attention` is prohibited") + attention = Attention() +if SYSTEM == "cuda": + from .cuda import attention +elif SYSTEM == "rocm": + from .rocm import attention +else: + raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention") diff --git a/backends/python/server/text_embeddings_server/utils/flash_attn.py b/backends/python/server/text_embeddings_server/layers/attention/cuda.py similarity index 100% rename from backends/python/server/text_embeddings_server/utils/flash_attn.py rename to backends/python/server/text_embeddings_server/layers/attention/cuda.py diff --git a/backends/python/server/text_embeddings_server/layers/attention/rocm.py 
b/backends/python/server/text_embeddings_server/layers/attention/rocm.py new file mode 100644 index 00000000..9ed9004c --- /dev/null +++ b/backends/python/server/text_embeddings_server/layers/attention/rocm.py @@ -0,0 +1,45 @@ +import os +import torch +from text_embeddings_server.utils.import_utils import SYSTEM +from loguru import logger + +major, minor = torch.cuda.get_device_capability() +is_sm75 = major == 7 and minor == 5 + +if SYSTEM == "rocm": + try: + import flash_attn_2_cuda + + logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.") + except ImportError as e: + if major >= 8 or is_sm75: + architecture_suffix = f"-{SYSTEM}" + raise ImportError(f"Flash Attention V2 is not installed. {e}") + else: + for idx in range(torch.cuda.device_count()): + name = torch.cuda.get_device_name(idx) + if "MI210" not in name and "MI250" not in name and "MI300" not in name: + raise ImportError( + f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention" + ) + raise ImportError( + f"AMD GPU with ROCm capability {major} {minor} is not supported" + ) from e + +def attention(q, k, v, out, cu_seqlens, max_s, softmax_scale, is_causal=False): + return flash_attn_2_cuda.varlen_fwd( + q, + k, + v, + out, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + 0.0, + softmax_scale, + False, + is_causal, + False, + None, + ) diff --git a/backends/python/server/text_embeddings_server/layers/layernorm.py b/backends/python/server/text_embeddings_server/layers/layernorm.py new file mode 100644 index 00000000..0834b734 --- /dev/null +++ b/backends/python/server/text_embeddings_server/layers/layernorm.py @@ -0,0 +1,54 @@ +import torch +from text_embeddings_server.utils.import_utils import SYSTEM + +from transformers.models.bert import BertConfig + +if SYSTEM == "cuda": + import dropout_layer_norm + + class FastLayerNorm: + def __init__(self, prefix, handle, device, dtype, config: BertConfig): + self.weight = handle.get_tensor(f"{prefix}.weight").to(dtype).to(device) + self.bias = handle.get_tensor(f"{prefix}.bias").to(dtype).to(device) + self.variance_epsilon = config.layer_norm_eps + + def forward(self, hidden_states, residual=None): + normed_hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + residual, + self.weight, + self.bias, + None, + None, + None, + None, + 0.0, + self.variance_epsilon, + 1.0, + 0, + None, + False, + False, + ) + if residual is None: + residual = hidden_states + + return normed_hidden_states, residual + +elif SYSTEM == "rocm": + class FastLayerNorm: + def __init__(self, prefix, handle, device, dtype, config: BertConfig): + self.weight = handle.get_tensor(f"{prefix}.weight").to(dtype).to(device) + self.bias = handle.get_tensor(f"{prefix}.bias").to(dtype).to(device) + self.variance_epsilon = config.layer_norm_eps + + def forward(self, hidden_states, residual=None): + if residual is not None: + hidden_states += residual + residual = hidden_states + + hidden_states = torch.nn.functional.layer_norm(hidden_states, self.weight.shape, self.weight, self.bias, eps=self.variance_epsilon) + + return hidden_states, residual +else: + raise ValueError("System not recognized") diff --git a/backends/python/server/text_embeddings_server/layers/pooling.py b/backends/python/server/text_embeddings_server/layers/pooling.py new file mode 100644 index 00000000..7eaddb6b --- /dev/null +++ b/backends/python/server/text_embeddings_server/layers/pooling.py @@ -0,0 +1,22 @@ +import torch +from flash_attn.bert_padding import pad_input + +from loguru 
import logger + +def mean_pooling(embedding, cu_seqlens, max_s): + # Ideally, rust would pass `indices` to the FlashBatch. + seqlens = cu_seqlens[1:].clone() + seqlens[0] = cu_seqlens[1] + seqlens[1:] -= cu_seqlens[1:-1] + batch_size = len(seqlens) + + # Example: indices = [0, 1, 2, 3, 7, 8, 9, 10, 11, 12, 13] + mask = torch.zeros(batch_size, max_s, dtype=torch.int32, device=cu_seqlens.device) + mask[torch.arange(max_s) < seqlens[:, None].cpu()] = 1 + indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten() + + embedding_padded = pad_input(embedding, indices, batch_size, max_s) + + sum_embeddings = torch.sum(embedding_padded, 1) + + return sum_embeddings / seqlens[:, None] diff --git a/backends/python/server/text_embeddings_server/models/__init__.py b/backends/python/server/text_embeddings_server/models/__init__.py index 47867187..fbd2348b 100644 --- a/backends/python/server/text_embeddings_server/models/__init__.py +++ b/backends/python/server/text_embeddings_server/models/__init__.py @@ -25,7 +25,7 @@ __all__.append(FlashBert) -def get_model(model_path: Path, dtype: Optional[str]): +def get_model(model_path: Path, dtype: Optional[str], pooling_mode: str): if dtype == "float32": dtype = torch.float32 elif dtype == "float16": @@ -52,8 +52,8 @@ def get_model(model_path: Path, dtype: Optional[str]): and dtype in [torch.float16, torch.bfloat16] and FLASH_ATTENTION ): - return FlashBert(model_path, device, dtype) + return FlashBert(model_path, device, dtype, pooling_mode) else: - return DefaultModel(model_path, device, dtype) + return DefaultModel(model_path, device, dtype, pooling_mode) raise NotImplementedError diff --git a/backends/python/server/text_embeddings_server/models/default_model.py b/backends/python/server/text_embeddings_server/models/default_model.py index dc39fdc8..3c41b6c3 100644 --- a/backends/python/server/text_embeddings_server/models/default_model.py +++ b/backends/python/server/text_embeddings_server/models/default_model.py @@ -8,14 +8,16 @@ from text_embeddings_server.models import Model from text_embeddings_server.models.types import PaddedBatch, Embedding +from typing import Optional tracer = trace.get_tracer(__name__) class DefaultModel(Model): - def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype): + def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype, pooling_mode: Optional[str]): model = AutoModel.from_pretrained(model_path).to(dtype).to(device) self.hidden_size = model.config.hidden_size + self.pooling_mode = pooling_mode self.has_position_ids = ( inspect.signature(model.forward).parameters.get("position_ids", None) @@ -41,7 +43,14 @@ def embed(self, batch: PaddedBatch) -> List[Embedding]: kwargs["position_ids"] = batch.position_ids output = self.model(**kwargs) - embedding = output[0][:, 0] + + if self.pooling_mode == "cls": + embedding = output[0][:, 0] + elif self.pooling_mode == "mean": + embedding = output[0].mean(dim=1) + else: + raise NotImplementedError(f"Pooling {self.pooling_mode} is not implemented in the python backend") + cpu_results = embedding.view(-1).tolist() return [ diff --git a/backends/python/server/text_embeddings_server/models/flash_bert.py b/backends/python/server/text_embeddings_server/models/flash_bert.py index 50b8d70d..40003013 100644 --- a/backends/python/server/text_embeddings_server/models/flash_bert.py +++ b/backends/python/server/text_embeddings_server/models/flash_bert.py @@ -8,46 +8,16 @@ from transformers.models.bert import BertConfig from opentelemetry import trace -# 
Flash attention imports -import dropout_layer_norm - from text_embeddings_server.models import Model from text_embeddings_server.models.types import FlashBatch, Embedding -from text_embeddings_server.utils.flash_attn import attention +from text_embeddings_server.layers.attention import attention +from text_embeddings_server.layers.layernorm import FastLayerNorm +from text_embeddings_server.layers.pooling import mean_pooling +from typing import Optional tracer = trace.get_tracer(__name__) -class FastLayerNorm: - def __init__(self, prefix, handle, device, dtype, config: BertConfig): - self.weight = handle.get_tensor(f"{prefix}.weight").to(dtype).to(device) - self.bias = handle.get_tensor(f"{prefix}.bias").to(dtype).to(device) - self.variance_epsilon = config.layer_norm_eps - - def forward(self, hidden_states, residual=None): - normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd( - hidden_states, - residual, - self.weight, - self.bias, - None, - None, - None, - None, - 0.0, - self.variance_epsilon, - 1.0, - 0, - None, - False, - False, - ) - if res is None: - res = hidden_states - - return normed_hidden_states, res - - class BertEmbeddings: def __init__(self, prefix, handle, device, dtype, config: BertConfig): self.word_embeddings_weight = ( @@ -217,16 +187,17 @@ def forward(self, input_ids, token_type_ids, position_ids, cu_seqlens, max_s): embeddings = self.embeddings.forward(input_ids, token_type_ids, position_ids) encoder_outputs = self.encoder.forward(embeddings, cu_seqlens, max_s) - return encoder_outputs[cu_seqlens[:-1]] + return encoder_outputs class FlashBert(Model): - def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype): + def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype, pooling_mode: Optional[str]): config = BertConfig.from_pretrained(model_path) with safe_open(model_path / "model.safetensors", framework="pt") as f: model = FlashBertModel(f, device, dtype, config) self.hidden_size = config.hidden_size + self.pooling_mode = pooling_mode super(FlashBert, self).__init__(model=model, dtype=dtype, device=device) @@ -243,11 +214,24 @@ def embed(self, batch: FlashBatch) -> List[Embedding]: cu_seqlens=batch.cu_seqlens, max_s=batch.max_s, ) - cpu_results = embedding.view(-1).tolist() - return [ - Embedding( - values=cpu_results[i * self.hidden_size : (i + 1) * self.hidden_size] - ) - for i in range(len(batch)) - ] + if self.pooling_mode == "cls": + embedding = embedding[batch.cu_seqlens[:-1]] + cpu_results = embedding.view(-1).tolist() + + return [ + Embedding( + values=cpu_results[i * self.hidden_size : (i + 1) * self.hidden_size] + ) + for i in range(len(batch)) + ] + elif self.pooling_mode == "mean": + res = mean_pooling(embedding, batch.cu_seqlens, batch.max_s) + return [ + Embedding( + values=res[i] + ) + for i in range(len(batch)) + ] + else: + raise NotImplementedError(f"Pooling {self.pooling_mode} is not implemented in the python backend") diff --git a/backends/python/server/text_embeddings_server/server.py b/backends/python/server/text_embeddings_server/server.py index d0a43ace..2c99cf79 100644 --- a/backends/python/server/text_embeddings_server/server.py +++ b/backends/python/server/text_embeddings_server/server.py @@ -37,6 +37,7 @@ def serve( model_path: Path, dtype: Optional[str], uds_path: Path, + pooling_mode: Optional[str], ): async def serve_inner( model_path: Path, @@ -45,7 +46,7 @@ async def serve_inner( unix_socket = f"unix://{uds_path}" try: - model = get_model(model_path, dtype) + model = 
get_model(model_path, dtype, pooling_mode) except Exception: logger.exception("Error when initializing model") raise diff --git a/backends/python/server/text_embeddings_server/utils/import_utils.py b/backends/python/server/text_embeddings_server/utils/import_utils.py new file mode 100644 index 00000000..83394eaa --- /dev/null +++ b/backends/python/server/text_embeddings_server/utils/import_utils.py @@ -0,0 +1,12 @@ +import torch +from loguru import logger + +SYSTEM = None +if torch.version.hip is not None: + SYSTEM = "rocm" +elif torch.version.cuda is not None and torch.cuda.is_available(): + SYSTEM = "cuda" +else: + SYSTEM = "cpu" + +logger.info(f"Python backend: detected system {SYSTEM}") diff --git a/backends/python/src/lib.rs b/backends/python/src/lib.rs index 195f1d37..142547bc 100644 --- a/backends/python/src/lib.rs +++ b/backends/python/src/lib.rs @@ -24,6 +24,8 @@ impl PythonBackend { otlp_endpoint: Option, otlp_service_name: String, ) -> Result { + let model_type_clone = model_type.clone(); + match model_type { ModelType::Classifier => { return Err(BackendError::Start( @@ -31,19 +33,26 @@ impl PythonBackend { )) } ModelType::Embedding(pool) => { - if pool != Pool::Cls { - return Err(BackendError::Start(format!("{pool:?} is not supported"))); + if pool != Pool::Cls && pool != Pool::Mean { + return Err(BackendError::Start(format!("{pool:?} is not supported in the TEI Python backend. Please open an issue."))); } pool } }; + let pool_string = match &model_type_clone { + ModelType::Classifier => &Pool::Cls, + ModelType::Embedding(pool) => pool, + } + .to_string(); + let backend_process = management::BackendProcess::new( model_path, dtype, &uds_path, otlp_endpoint, otlp_service_name, + pool_string, )?; let tokio_runtime = tokio::runtime::Builder::new_current_thread() .enable_all() diff --git a/backends/python/src/management.rs b/backends/python/src/management.rs index 911c6984..2044a3e0 100644 --- a/backends/python/src/management.rs +++ b/backends/python/src/management.rs @@ -22,6 +22,7 @@ impl BackendProcess { uds_path: &str, otlp_endpoint: Option, otlp_service_name: String, + pooling_mode: String, ) -> Result { // Get UDS path let uds = Path::new(uds_path); @@ -52,6 +53,9 @@ impl BackendProcess { python_server_args.push("--otlp-service-name".to_owned()); python_server_args.push(otlp_service_name); + python_server_args.push("--pooling-mode".to_owned()); + python_server_args.push(pooling_mode); + // Copy current process env let envs: Vec<(OsString, OsString)> = env::vars_os().collect(); diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 211d1ca5..b8fbb6f9 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -11,8 +11,10 @@ title: Using TEI locally with CPU - local: local_metal title: Using TEI locally with Metal - - local: local_gpu - title: Using TEI locally with GPU + - local: local_nvidia_gpu + title: Using TEI locally with Nvidia GPU + - local: local_amd_gpu + title: Using TEI locally with AMD GPU - local: private_models title: Serving private and gated models # - local: tei_cli diff --git a/docs/source/en/local_amd_gpu.md b/docs/source/en/local_amd_gpu.md new file mode 100644 index 00000000..7a2bfb73 --- /dev/null +++ b/docs/source/en/local_amd_gpu.md @@ -0,0 +1,41 @@ + + +# Using TEI locally with an AMD GPU + +Text-Embeddings-Inference supports the [AMD GPUs officially supporting ROCm](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/reference/system-requirements.html), including AMD Instinct MI210, MI250, MI300 and 
some of the AMD Radeon series GPUs.
+
+To leverage AMD GPUs, Text-Embeddings-Inference relies on its Python backend rather than on the [candle](https://github.com/huggingface/candle) backend used for CPU, Nvidia GPUs and Metal. Support in the Python backend is currently more limited (BERT embeddings), but it is easily extensible. We welcome contributions to extend the supported models.
+
+## Usage through Docker
+
+Using Docker is the recommended approach.
+
+```bash
+DOCKER_TAG=rocm-xxx # Specify the tag of the docker image to use
+docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --net host \
+    --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 32g \
+    ghcr.io/huggingface/text-embeddings-inference:$DOCKER_TAG \
+    --model-id sentence-transformers/all-MiniLM-L6-v2
+```
+
+and then, from another shell, query the server with
+
+```bash
+curl 127.0.0.1:80/embed \
+    -X POST -d '{"inputs":"What is Deep Learning?"}' \
+    -H 'Content-Type: application/json'
+```
diff --git a/docs/source/en/local_gpu.md b/docs/source/en/local_nvidia_gpu.md
similarity index 96%
rename from docs/source/en/local_gpu.md
rename to docs/source/en/local_nvidia_gpu.md
index 7b76300a..f2e71cfd 100644
--- a/docs/source/en/local_gpu.md
+++ b/docs/source/en/local_nvidia_gpu.md
@@ -14,9 +14,9 @@ rendered properly in your Markdown viewer.
 -->
 
-# Using TEI locally with GPU
+# Using TEI locally with Nvidia GPU
 
-You can install `text-embeddings-inference` locally to run it on your own machine with a GPU.
+You can install `text-embeddings-inference` locally to run it on your own machine with an Nvidia GPU.
 To make sure that your hardware is supported, check out the [Supported models and hardware](supported_models) page.
 
 ## Step 1: CUDA and NVIDIA drivers
diff --git a/tests/README.md b/tests/README.md
new file mode 100644
index 00000000..cfbf805c
--- /dev/null
+++ b/tests/README.md
@@ -0,0 +1,34 @@
+## Testing
+
+To run the tests, start the Docker image with `--entrypoint "/bin/bash"` and install the test requirements from within the container:
+```
+pip install -r requirements.txt
+```
+
+Then, with a volume mounted for the tests, run them from within the container with
+```
+pytest tests/ -s -vvvvv
+```
+
+## Reference outputs
+
+For example, to collect the reference outputs on an RTX 4090 with the Candle backend:
+```
+docker run --rm -it --gpus all --net host --entrypoint "/bin/bash" -v $(pwd):/tei ghcr.io/huggingface/text-embeddings-inference:89-1.2.3
+```
+and, inside the container,
+```
+text-embeddings-router --model-id sentence-transformers/all-MiniLM-L6-v2
+```
+
+and then
+```
+python collect.py --model-id sentence-transformers/all-MiniLM-L6-v2 --n_inp 1 --flash
+python collect.py --model-id sentence-transformers/all-MiniLM-L6-v2 --n_inp 3 --flash
+```
+
+Restart the server with `USE_FLASH_ATTENTION=0`, and run
+```
+python collect.py --model-id sentence-transformers/all-MiniLM-L6-v2 --n_inp 1
+python collect.py --model-id sentence-transformers/all-MiniLM-L6-v2 --n_inp 3
+```
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1.pt b/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1.pt
new file mode 100644
index 00000000..aaf95a92
Binary files /dev/null and b/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1.pt differ
diff --git a/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1_no_flash.pt b/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1_no_flash.pt
new file mode 100644
index 00000000..d986e332
Binary files /dev/null and 
b/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1_no_flash.pt differ diff --git a/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp3.pt b/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp3.pt new file mode 100644 index 00000000..bea6dca1 Binary files /dev/null and b/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp3.pt differ diff --git a/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp3_no_flash.pt b/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp3_no_flash.pt new file mode 100644 index 00000000..7bf51879 Binary files /dev/null and b/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp3_no_flash.pt differ diff --git a/tests/collect.py b/tests/collect.py new file mode 100644 index 00000000..640f854c --- /dev/null +++ b/tests/collect.py @@ -0,0 +1,37 @@ + +import requests +import torch +import argparse +import json +import os + +parser = argparse.ArgumentParser(description='Assets collection') +parser.add_argument('--model-id', help='Model id', required=True) +parser.add_argument('--n_inp', help='Number of inputs', required=True, type=int) +parser.add_argument('--flash', action='store_true') + +args = parser.parse_args() + +url = f"http://0.0.0.0:80/embed" + +INPUTS = [ + "What is Deep Learning?", + "Today I am in Paris and I would like to", + "Paris weather is", + "Great job" +] + +data = {"inputs": INPUTS[:args.n_inp]} +headers = {"Content-Type": "application/json"} + +response = requests.post(url, json=data, headers=headers) + +embedding = torch.Tensor(json.loads(response.text)) + +postfix = "" +if not args.flash: + postfix = "_no_flash" + +save_path = f"./assets/{args.model_id.replace('/', '-')}_inp{args.n_inp}{postfix}.pt" +print(f"Saving embedding of shape {embedding.shape} to {save_path}") +torch.save(embedding, save_path) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..efdd6fc2 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,113 @@ +import pytest +import asyncio +import contextlib +import random +import os +import tempfile +import subprocess +import shutil +import sys +from typing import Optional +from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError +import requests +import time +from requests.exceptions import ConnectionError as RequestsConnectionError + +@pytest.fixture(scope="module") +def event_loop(): + loop = asyncio.get_event_loop() + yield loop + loop.close() + +class ProcessLauncherHandle: + def __init__(self, process, port: int): + self.port = port + self.process = process + + def _inner_health(self) -> bool: + return self.process.poll() is None + + def health(self, timeout: int = 60): + assert timeout > 0 + for _ in range(timeout): + if not self._inner_health(): + raise RuntimeError("Launcher crashed") + + try: + url = f"http://0.0.0.0:{self.port}/health" + headers = {"Content-Type": "application/json"} + + response = requests.post(url, headers=headers) + return + except (ClientConnectorError, ClientOSError, ServerDisconnectedError, RequestsConnectionError) as e: + print("Connecting") + time.sleep(1) + raise RuntimeError("Health check failed") + +@pytest.fixture(scope="module") +def launcher(event_loop): + @contextlib.contextmanager + def local_launcher( + model_id: str, + trust_remote_code: bool = False, + use_flash_attention: bool = True, + dtype: Optional[str] = None, + revision: Optional[str] = None, + pooling: Optional[str] = None, + ): + port = random.randint(8000, 10_000) + shard_uds_path = ( + f"/tmp/tei-tests-{model_id.split('/')[-1]}-server" + ) + + 
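+        # Each launched router gets a random HTTP port and a per-model unix socket (passed via --uds-path) for the router <-> backend connection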
+        args = [
+            "text-embeddings-router",
+            "--model-id",
+            model_id,
+            "--port",
+            str(port),
+            "--uds-path",
+            shard_uds_path,
+        ]
+
+        env = os.environ
+
+        if dtype is not None:
+            args.append("--dtype")
+            args.append(dtype)
+        if revision is not None:
+            args.append("--revision")
+            args.append(revision)
+        if trust_remote_code:
+            args.append("--trust-remote-code")
+        if pooling:
+            args.append("--pooling")
+            args.append(str(pooling))
+
+        env["LOG_LEVEL"] = "debug"
+
+        if not use_flash_attention:
+            env["USE_FLASH_ATTENTION"] = "false"
+
+        with tempfile.TemporaryFile("w+") as tmp:
+            # We'll output stdout/stderr to a temporary file. Using a pipe
+            # causes the process to block until stdout is read.
+            print("call subprocess.Popen, with args", args)
+            with subprocess.Popen(
+                args,
+                stdout=tmp,
+                stderr=subprocess.STDOUT,
+                env=env,
+            ) as process:
+                yield ProcessLauncherHandle(process, port)
+
+                process.terminate()
+                process.wait(60)
+
+                tmp.seek(0)
+                shutil.copyfileobj(tmp, sys.stderr)
+
+        if not use_flash_attention:
+            del env["USE_FLASH_ATTENTION"]
+
+    return local_launcher
diff --git a/tests/pytest.ini b/tests/pytest.ini
new file mode 100644
index 00000000..2f4c80e3
--- /dev/null
+++ b/tests/pytest.ini
@@ -0,0 +1,2 @@
+[pytest]
+asyncio_mode = auto
diff --git a/tests/requirements.txt b/tests/requirements.txt
new file mode 100644
index 00000000..74d3b667
--- /dev/null
+++ b/tests/requirements.txt
@@ -0,0 +1,3 @@
+pytest
+pytest-asyncio
+aiohttp
diff --git a/tests/test_default_model.py b/tests/test_default_model.py
new file mode 100644
index 00000000..68499928
--- /dev/null
+++ b/tests/test_default_model.py
@@ -0,0 +1,28 @@
+import pytest
+import requests
+import json
+import torch
+
+@pytest.fixture(scope="module")
+def default_model_handle(launcher):
+    with launcher("sentence-transformers/all-MiniLM-L6-v2", use_flash_attention=False) as handle:
+        yield handle
+
+@pytest.fixture(scope="module")
+async def default_model(default_model_handle):
+    default_model_handle.health(300)
+    return default_model_handle
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_single_query(default_model):
+    url = f"http://0.0.0.0:{default_model.port}/embed"
+    data = {"inputs": "What is Deep Learning?"}
+    headers = {"Content-Type": "application/json"}
+
+    response = requests.post(url, json=data, headers=headers)
+
+    embedding = torch.Tensor(json.loads(response.text))
+    reference_embedding = torch.load("./tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1_no_flash.pt")
+
+    assert torch.allclose(embedding, reference_embedding, atol=1e-3, rtol=1e-3)
diff --git a/tests/test_flash_bert.py b/tests/test_flash_bert.py
new file mode 100644
index 00000000..04085522
--- /dev/null
+++ b/tests/test_flash_bert.py
@@ -0,0 +1,28 @@
+import pytest
+import requests
+import json
+import torch
+
+@pytest.fixture(scope="module")
+def default_model_handle(launcher):
+    with launcher("sentence-transformers/all-MiniLM-L6-v2", use_flash_attention=True) as handle:
+        yield handle
+
+@pytest.fixture(scope="module")
+async def default_model(default_model_handle):
+    default_model_handle.health(300)
+    return default_model_handle
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_single_query(default_model):
+    url = f"http://0.0.0.0:{default_model.port}/embed"
+    data = {"inputs": "What is Deep Learning?"}
+    headers = {"Content-Type": "application/json"}
+
+    response = requests.post(url, json=data, headers=headers)
+
+    embedding = torch.Tensor(json.loads(response.text))
+    reference_embedding = 
torch.load("./tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1.pt") + + assert torch.allclose(embedding, reference_embedding, atol=1e-3, rtol=1e-3)
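+# The reference embeddings under tests/assets are produced with tests/collect.py (see tests/README.md)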