
Commit

[FEATS][InputOpenAISpec, OutputOpenAISpec, OpenAIAPIWrapper]
Kye committed Feb 10, 2024
1 parent 4220ae4 commit d1c282d
Showing 6 changed files with 921 additions and 1 deletion.
7 changes: 6 additions & 1 deletion requirements.txt
@@ -3,4 +3,9 @@ skypilot
fastapi
supabase
pytest
pytest-benchmark
tensorrt
torch
einops
tiktoken
uvicorn
141 changes: 141 additions & 0 deletions servers/blip2.py
@@ -0,0 +1,141 @@
import argparse
import os

import torch
import tensorrt as trt

# isort: on
import tensorrt_llm


def get_engine_name(rank):
return "rank{}.engine".format(rank)


def trt_dtype_to_torch(dtype):
if dtype == trt.float16:
return torch.float16
elif dtype == trt.float32:
return torch.float32
elif dtype == trt.int32:
return torch.int32
else:
raise TypeError("%s is not supported" % dtype)


def TRTOPT(args, config):
    """Load the serialized OPT engine and build a TensorRT-LLM GenerationSession."""
    dtype = config["pretrained_config"]["dtype"]
world_size = config["pretrained_config"]["mapping"]["world_size"]
assert (
world_size == tensorrt_llm.mpi_world_size()
), f"Engine world size ({world_size}) != Runtime world size ({tensorrt_llm.mpi_world_size()})"

use_gpt_attention_plugin = bool(
config["build_config"]["plugin_config"]["gpt_attention_plugin"]
)

num_heads = config["pretrained_config"]["num_attention_heads"] // world_size
hidden_size = config["pretrained_config"]["hidden_size"] // world_size
vocab_size = config["pretrained_config"]["vocab_size"]
max_batch_size = config["build_config"]["max_batch_size"]
num_layers = config["pretrained_config"]["num_hidden_layers"]
remove_input_padding = config["build_config"]["plugin_config"][
"remove_input_padding"
]
max_prompt_embedding_table_size = config["build_config"].get(
"max_prompt_embedding_table_size", 0
)

model_config = tensorrt_llm.runtime.ModelConfig(
max_batch_size=max_batch_size,
vocab_size=vocab_size,
num_layers=num_layers,
num_heads=num_heads,
num_kv_heads=num_heads,
hidden_size=hidden_size,
gpt_attention_plugin=use_gpt_attention_plugin,
remove_input_padding=remove_input_padding,
max_prompt_embedding_table_size=max_prompt_embedding_table_size,
dtype=dtype,
)

runtime_rank = tensorrt_llm.mpi_rank()
runtime_mapping = tensorrt_llm.Mapping(world_size, runtime_rank)
torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)

engine_name = get_engine_name(runtime_rank)
serialize_path = os.path.join(args.opt_engine_dir, engine_name)

tensorrt_llm.logger.set_level(args.log_level)

with open(serialize_path, "rb") as f:
engine_buffer = f.read()
decoder = tensorrt_llm.runtime.GenerationSession(
model_config, engine_buffer, runtime_mapping
)

max_input_len = config["build_config"]["max_input_len"]
return decoder, model_config, world_size, dtype, max_input_len


def ptuning_setup(
prompt_table,
dtype,
hidden_size,
tasks,
input_ids,
input_lengths,
remove_input_padding,
):
    """Prepare the prompt-tuning table, per-sequence task ids, and task vocab size tensors."""
    if prompt_table is not None:
task_vocab_size = torch.tensor(
[prompt_table.shape[1]], dtype=torch.int32, device="cuda"
)
prompt_table = prompt_table.view(
(prompt_table.shape[0] * prompt_table.shape[1], prompt_table.shape[2])
)
prompt_table = prompt_table.cuda().to(
dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype)
)
else:
prompt_table = torch.empty([1, hidden_size]).cuda()
task_vocab_size = torch.zeros([1]).cuda()

num_sequences = input_lengths.size(0) if remove_input_padding else input_ids.size(0)

if tasks is not None:
tasks = torch.tensor(
[int(t) for t in tasks.split(",")], dtype=torch.int32, device="cuda"
)
assert (
tasks.shape[0] == num_sequences
), "Number of supplied tasks must match input batch size"
else:
tasks = torch.zeros([num_sequences], dtype=torch.int32).cuda()

return [prompt_table, tasks, task_vocab_size]


def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument("--max_output_len", type=int, default=30)
parser.add_argument("--log_level", type=str, default="info")
parser.add_argument("--engine_dir", type=str, default="./plan")
parser.add_argument("--input_dir", type=str, default="image.pt")
parser.add_argument("--query_tokens", type=str, default="query_tokens.pt")
parser.add_argument(
"--opt_engine_dir", type=str, default="trt_engine/blip-2-opt-2.7b/fp16/1-gpu/"
)
parser.add_argument("--hf_model_location", type=str, default="facebook/opt-2.7b")
parser.add_argument(
"--input_text", type=str, default="Question: which city is this? Answer:"
)
parser.add_argument(
"--num_beams", type=int, help="Use beam search if num_beams >1", default=1
)
parser.add_argument(
"--max_txt_len", type=int, help="Max text prompt length", default=32
)
parser.add_argument("--top_k", type=int, default=1)

return parser.parse_args()
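
The file defines get_engine_name, TRTOPT, ptuning_setup, and parse_arguments but no entry point; a minimal driver sketch follows. It assumes the engine directory holds a config.json with the "pretrained_config"/"build_config" layout that TRTOPT reads above; the main block itself is illustrative, not part of the commit.

# Hypothetical driver sketch (not part of the commit). Assumes
# <opt_engine_dir>/config.json follows the layout consumed by TRTOPT above.
if __name__ == "__main__":
    import json

    args = parse_arguments()
    config_path = os.path.join(args.opt_engine_dir, "config.json")
    with open(config_path, "r") as f:
        config = json.load(f)

    decoder, model_config, world_size, dtype, max_input_len = TRTOPT(args, config)
    print(
        f"Loaded {get_engine_name(tensorrt_llm.mpi_rank())}: "
        f"world_size={world_size}, dtype={dtype}, max_input_len={max_input_len}"
    )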
107 changes: 107 additions & 0 deletions servers/fuyu_api.py
@@ -0,0 +1,107 @@
import argparse
import asyncio
import json
from typing import AsyncGenerator

import uvicorn
from executor import GenerationExecutor
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
from swarms import Fuyu, Conversation

TIMEOUT_KEEP_ALIVE = 5 # seconds.
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds.
app = FastAPI()
executor: GenerationExecutor | None = None


@app.get("/stats")
async def stats() -> Response:
assert executor is not None
return JSONResponse(json.loads(await executor.aget_stats()))


@app.get("/health")
async def health() -> Response:
"""Health check."""
return Response(status_code=200)


@app.post("/generate")
async def generate(request: Request) -> Response:
    """Generate a completion for the request.

    The request should be a JSON object with the following fields:
    - prompt: the prompt to use for the generation.
    - streaming: whether to stream the results or not.
    - other fields: the sampling parameters (see `SamplingParams` for details).
    """
    assert executor is not None
request_dict = await request.json()

streaming = request_dict.pop("streaming", False)

    # Query parameters arrive as strings; cast max_new_tokens before passing it to the model.
    model_name = request.query_params.get("model_name")
    max_new_tokens = request.query_params.get("max_new_tokens")
    if max_new_tokens is not None:
        max_new_tokens = int(max_new_tokens)

model = Fuyu(
model_name=model_name,
max_new_tokens=max_new_tokens,
args=args # Injecting args into the Fuyu model
)
response = model.run(
request_dict.pop("prompt"),
request_dict.pop("max_num_tokens", 8),
streaming
)

async def stream_results() -> AsyncGenerator[bytes, None]:
async for output in response:
yield (json.dumps({"text": output.text}) + "\n").encode("utf-8")

if streaming:
return StreamingResponse(stream_results(), media_type="text/plain")

# Non-streaming case
await response.await_completion()

# Return model configurations as JSON
model_config = {
"model_name": model.model_name,
"max_new_tokens": model.max_new_tokens,
"args": {
"model_dir": args.model_dir,
"tokenizer_type": args.tokenizer_type,
"max_beam_width": args.max_beam_width
}
}

return JSONResponse({"model_config": model_config, "choices": [{"text": response.text}]})


async def main(args):
global executor

executor = GenerationExecutor(
args.model_dir, args.tokenizer_type, args.max_beam_width
)
config = uvicorn.Config(
app,
host=args.host,
port=args.port,
log_level="info",
timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
)
await uvicorn.Server(config).serve()


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("model_dir")
parser.add_argument("tokenizer_type")
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--max_beam_width", type=int, default=1)
args = parser.parse_args()

asyncio.run(main(args))
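
For a quick smoke test of the /generate route, a hypothetical client is sketched below; the host, port, model name, and token count are placeholders. It mirrors how the handler above reads model_name and max_new_tokens from the query string and prompt/streaming from the JSON body.

# Hypothetical client sketch (not part of the commit); endpoint values are placeholders.
import requests

resp = requests.post(
    "http://localhost:8000/generate",
    params={"model_name": "adept/fuyu-8b", "max_new_tokens": 64},
    json={"prompt": "Question: which city is this? Answer:", "streaming": False},
)
print(resp.json())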
