[python] Update lmi-dist (#975)
* [python] Upgrade the dependency for lmi-dist

* [python] Upgrade lmi-dist
xyang16 authored Jul 27, 2023
1 parent e4f56a2 commit 52f3eff
Showing 4 changed files with 43 additions and 3 deletions.
@@ -23,7 +23,7 @@

 import torch

-QUANTIZATION_SUPPORT_ALGO = ["bitsandbytes"]
+QUANTIZATION_SUPPORT_ALGO = ["bitsandbytes", "gptq"]


 class LmiDistRollingBatch(RollingBatch):
@@ -42,6 +42,7 @@ def __init__(self, model_id_or_path, device, properties, **kwargs):
         self.properties = properties
         self.batch_cls = None
         self._init_model(kwargs, model_id_or_path)
+        self._warmup(**kwargs)
         self.batch_id_counter = 0
         self.cache: Batch = None

@@ -65,9 +66,38 @@ def _init_model(self, kwargs, model_id_or_path):
                                revision=revision,
                                sharded=sharded,
                                quantize=quantize,
+                               dtype=dtype,
                                trust_remote_code=kwargs.get("trust_remote_code"))
         self.batch_cls = self.model.batch_type

+    def _warmup(self, **kwargs):
+        parameters = NextTokenChooserParameters(temperature=0.9,
+                                                repetition_penalty=1.2,
+                                                top_k=10,
+                                                top_p=0.9,
+                                                typical_p=0.9,
+                                                do_sample=False,
+                                                seed=0)
+        stop_parameters = StoppingCriteriaParameters(stop_sequences=[],
+                                                     max_new_tokens=2)
+
+        max_prefill_tokens = int(
+            self.properties.get(
+                "max_rolling_batch_prefill_tokens",
+                int(self.properties.get("max_rolling_batch_size", 4)) * 512))
+        requests = [
+            lmi_dist.utils.types.Request(id=0,
+                                         inputs='_test ' * max_prefill_tokens,
+                                         parameters=parameters,
+                                         stopping_parameters=stop_parameters,
+                                         truncate=max_prefill_tokens)
+        ]
+        batch = self.batch_cls.get_batch(
+            Batch(id=0, requests=requests,
+                  size=len(requests)), self.model.tokenizer,
+            kwargs.get("torch_dtype", torch.float16), self.device)
+        self.model.warmup(batch, max_prefill_tokens)
+
     @stop_on_any_exception
     def inference(self, input_data, parameters):
         """
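
In effect, the new _warmup runs a single synthetic prefill through the engine at load time: the token budget is max_rolling_batch_prefill_tokens when that property is set, otherwise max_rolling_batch_size * 512 (with the default batch size of 4, that is 4 * 512 = 2048 tokens). A minimal sketch of constructing the handler with these properties follows; the model id, device index, property values, and the "quantize" key are illustrative assumptions, while the constructor signature and property names come from the diff above.

import torch

# Sketch only: placeholder values. LmiDistRollingBatch is the class shown in the
# diff above; the property names mirror the lookups in _warmup, while the model
# id and the "quantize" key are hypothetical.
properties = {
    "max_rolling_batch_size": 8,
    "max_rolling_batch_prefill_tokens": 4096,  # optional; otherwise 8 * 512 would be used
    "quantize": "gptq",                        # hypothetical key, now that "gptq" is accepted
}

rolling_batch = LmiDistRollingBatch(
    "my-org/my-gptq-model",      # placeholder model id or local path
    device=0,                    # placeholder device index
    properties=properties,
    torch_dtype=torch.float16,   # read by _warmup via kwargs.get("torch_dtype", ...)
    trust_remote_code=False)
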
4 changes: 3 additions & 1 deletion serving/docker/deepspeed.Dockerfile
@@ -16,6 +16,7 @@ ARG python_version=3.9
 ARG torch_version=2.0.1
 ARG torch_vision_version=0.15.2
 ARG deepspeed_wheel="https://publish.djl.ai/deepspeed/deepspeed-nightly-py2.py3-none-any.whl"
+ARG vllm_wheel="https://publish.djl.ai/vllm/vllm-0.0.0-cp39-cp39-linux_x86_64.whl"
 ARG lmi_dist_wheel="https://publish.djl.ai/lmi_dist/lmi_dist-nightly-py3-none-any.whl"
 ARG protobuf_version=3.20.3
 ARG transformers_version=4.30.2
@@ -60,10 +61,11 @@ RUN apt-get update && \
     scripts/install_s5cmd.sh x64 && \
     DEBIAN_FRONTEND=noninteractive apt-get install -yq libaio-dev libopenmpi-dev && \
     pip3 install torch==${torch_version} torchvision==${torch_vision_version} --extra-index-url https://download.pytorch.org/whl/cu118 \
-    ${deepspeed_wheel} ${lmi_dist_wheel} protobuf==${protobuf_version} transformers==${transformers_version} \
+    ${deepspeed_wheel} ${vllm_wheel} ${lmi_dist_wheel} protobuf==${protobuf_version} transformers==${transformers_version} \
     mpi4py sentencepiece einops accelerate==${accelerate_version} bitsandbytes==${bitsandbytes_version}\
     diffusers[torch]==${diffusers_version} peft==${peft_version} opencv-contrib-python-headless safetensors scipy && \
     scripts/install_flash_attn.sh && \
+    scripts/install_flash_attn_v2.sh && \
     scripts/install_aitemplate.sh && \
     scripts/patch_oss_dlc.sh python && \
     scripts/security_patch.sh deepspeed && \
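
Because the wheel locations are ordinary build arguments, a local build can pin or swap any of them without editing the Dockerfile. A hedged example of such a build (the image tag is a placeholder, the build is assumed to run from serving/docker, and the wheel URLs are simply the defaults declared above):

# Assumes the working directory is serving/docker; the tag is a placeholder.
docker build -f deepspeed.Dockerfile \
  --build-arg vllm_wheel="https://publish.djl.ai/vllm/vllm-0.0.0-cp39-cp39-linux_x86_64.whl" \
  --build-arg lmi_dist_wheel="https://publish.djl.ai/lmi_dist/lmi_dist-nightly-py3-none-any.whl" \
  -t djlserving:deepspeed-dev .
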
2 changes: 1 addition & 1 deletion serving/docker/scripts/install_flash_attn.sh
@@ -1,7 +1,7 @@
 #!/bin/bash

 rm -rf flash-attention
-git clone https://github.com/HazyResearch/flash-attention.git -b v1.0.0
+git clone https://github.com/Dao-AILab/flash-attention.git -b v1.0.9
 pushd flash-attention || exit 1
 pip install -v .
 cd csrc/layer_norm && pip install -v .
8 changes: 8 additions & 0 deletions serving/docker/scripts/install_flash_attn_v2.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+
+rm -rf flash-attention-v2
+git clone https://github.com/Dao-AILab/flash-attention.git flash-attention-v2 -b v2.0.0
+pushd flash-attention-v2 || exit 1
+pip install -v .
+popd || exit 1
+rm -rf flash-attention-v2
