Commit 8d88e9d
feat: change deploy example to phi-3.5-vision
hommayushi3 committed Aug 25, 2024
1 parent fa84071 commit 8d88e9d
Showing 2 changed files with 52 additions and 35 deletions.
examples/deploy.py: 9 changes (7 additions, 2 deletions)

@@ -8,11 +8,16 @@

 if __name__ == "__main__":
     load_dotenv()
-    repo_id = "microsoft/Phi-3-vision-128k-instruct"
+    repo_id = "microsoft/Phi-3.5-vision-instruct"
     env_vars = {
-        "MAX_MODEL_LEN": "3072",
+        "DISABLE_SLIDING_WINDOW": "true",
+        "MAX_MODEL_LEN": "2048",
+        "MAX_NUM_BATCHED_TOKENS": "8192",
         "DTYPE": "bfloat16",
         "GPU_MEMORY_UTILIZATION": "0.98",
+        "QUANTIZATION": "fp8",
+        "USE_V2_BLOCK_MANAGER": "true",
+        "VLLM_ATTENTION_BACKEND": "FLASH_ATTN",
         "TRUST_REMOTE_CODE": "true",
     }

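The rest of deploy.py is collapsed in this view, so the call that consumes repo_id and env_vars does not appear in the diff. As a rough sketch only (not code from this commit), values like these would typically be forwarded to a vLLM-backed Inference Endpoint through huggingface_hub.create_inference_endpoint; the endpoint name, vendor, region, instance choices, and serving image below are placeholder assumptions:

# Hypothetical sketch; the actual deployment call is collapsed in this diff.
from huggingface_hub import create_inference_endpoint

repo_id = "microsoft/Phi-3.5-vision-instruct"
env_vars = {
    # abbreviated; use the full dict from the diff above
    "MAX_MODEL_LEN": "2048",
    "QUANTIZATION": "fp8",
    "TRUST_REMOTE_CODE": "true",
}

endpoint = create_inference_endpoint(
    "phi-3-5-vision-instruct",             # placeholder endpoint name
    repository=repo_id,
    framework="pytorch",
    task="custom",
    accelerator="gpu",
    vendor="aws",                          # placeholder vendor/region/instance
    region="us-east-1",
    instance_size="x1",
    instance_type="nvidia-l4",
    custom_image={
        "health_route": "/health",
        "url": "vllm/vllm-openai:latest",  # placeholder serving image
        "env": env_vars,                   # environment variables from the diff
    },
)
endpoint.wait()                            # block until the endpoint is running
print(endpoint.url)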
examples/inference.py: 78 changes (45 additions, 33 deletions)

@@ -1,41 +1,53 @@
 from openai import OpenAI
 import os
 from dotenv import load_dotenv
+from time import time
+from multiprocessing.pool import ThreadPool


-if __name__ == "__main__":
-    load_dotenv()
-    ENDPOINT_URL = os.getenv("HF_ENDPOINT_URL") + "/v1/" # if endpoint object is not available check the UI
-    API_KEY = os.getenv("HF_TOKEN")
-    STREAM = False
+load_dotenv()
+ENDPOINT_URL = os.getenv("HF_ENDPOINT_URL") + "/v1/" # if endpoint object is not available check the UI
+API_KEY = os.getenv("HF_TOKEN")
+STREAM = False

-    # initialize the client but point it to TGI
-    client = OpenAI(base_url=ENDPOINT_URL, api_key=API_KEY)
-
-    chat_completion = client.chat.completions.create(
+# initialize the client but point it to TGI
+client = OpenAI(base_url=ENDPOINT_URL, api_key=API_KEY)
+
+
+def predict(messages):
+    return client.chat.completions.create(
         model="/repository", # needs to be /repository since there are the model artifacts stored
-        messages=[
-            {"role": "user", "content": [
-                {
-                    "type": "image_url",
-                    "image_url": {
-                        "url": "https://upload.wikimedia.org/wikipedia/commons/0/05/Facebook_Logo_%282019%29.png"
-                    }
-                },
-                {
-                    "type": "text",
-                    "text": "What is in the above image?"
-                }
-            ]},
-        ],
-        max_tokens=500,
+        messages=messages,
+        max_tokens=30,
         temperature=0.0,
-        stream=STREAM,
-    )
-
-    if STREAM:
-        for message in chat_completion:
-            if message.choices[0].delta.content:
-                print(message.choices[0].delta.content, end="")
-    else:
-        print(chat_completion.choices[0].message.content)
+        stream=False,
+    ).choices[0].message.content
+
+if __name__ == "__main__":
+
+    batch_size = 64
+    pool = ThreadPool(batch_size)
+
+    messages = [
+        {"role": "user", "content": [
+            {
+                "type": "image_url",
+                "image_url": {
+                    "url": "https://unsplash.com/photos/ZVw3HmHRhv0/download?ixid=M3wxMjA3fDB8MXxhbGx8NHx8fHx8fDJ8fDE3MjQ1NjAzNjl8&force=true&w=1920"
+                }
+            },
+            {
+                "type": "text",
+                "text": "What is in the above image? Explain in detail."
+            }
+        ]},
+    ]
+
+    start = time()
+    responses = pool.map(predict, [messages] * batch_size)
+    print(responses[0])
+    print(f"Time taken: {time() - start:.2f}s")
+
+    start = time()
+    responses = list(map(predict, [messages] * batch_size))
+    print(f"Time taken: {time() - start:.2f}s")
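The rewritten example sends the same multimodal request 64 times: once through a ThreadPool to measure concurrent throughput against the endpoint, and once through a plain sequential map as a timing baseline. It also hardcodes stream=False inside predict, leaving the STREAM flag unused. If you still want the token-by-token output the previous version printed, a streaming variant along these lines (a sketch with a hypothetical predict_stream helper, reusing the same client and /repository model path) restores the behaviour this commit removes:

def predict_stream(messages):
    # Sketch: stream the completion and print tokens as they arrive,
    # mirroring the STREAM branch that this commit deletes.
    stream = client.chat.completions.create(
        model="/repository",
        messages=messages,
        max_tokens=30,
        temperature=0.0,
        stream=True,
    )
    chunks = []
    for chunk in stream:
        delta = chunk.choices[0].delta.content
        if delta:
            print(delta, end="", flush=True)
            chunks.append(delta)
    print()
    return "".join(chunks)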
