
Commit 95d90c8
Optimize api server
leng-yue committed Dec 20, 2023
1 parent 06a35ae
Showing 2 changed files with 38 additions and 21 deletions.
pyproject.toml: 3 changes (2 additions & 1 deletion)
@@ -34,7 +34,8 @@ dependencies = [
     "wandb",
     "tensorboard",
     "grpcio>=1.58.0",
-    "kui>=1.6.0"
+    "kui>=1.6.0",
+    "zibai-server>=0.9.0"
 ]
 
 [build-system]
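Note: `zibai-server` is the WSGI server that pairs with `kui` (both by the same author). A minimal launch sketch, assuming Zibai's `create_bind_socket`/`serve` entry points and an `app` object exported by `tools/api_server.py` — the hunk above only adds the dependency, so the serving code below is an assumption:

```python
# Sketch only: the commit adds the dependency; the exact serving code is assumed.
import threading

from zibai import create_bind_socket, serve

from tools.api_server import app  # kui WSGI application (assumed export name)

sock = create_bind_socket("127.0.0.1:8000")
sock.listen()

serve(
    app=app,
    bind_sockets=[sock],
    max_workers=10,                   # worker threads handling requests
    graceful_exit=threading.Event(),  # set this event to shut down gracefully
)
```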
tools/api_server.py: 56 changes (36 additions & 20 deletions)
@@ -3,7 +3,8 @@
 import time
 import traceback
 from http import HTTPStatus
-from typing import Annotated, Any, Literal, Optional
+from threading import Lock
+from typing import Annotated, Literal, Optional
 
 import numpy as np
 import soundfile as sf
@@ -82,9 +83,7 @@ def __init__(
 
         torch.cuda.synchronize()
         logger.info(f"Time to load model: {time.time() - self.t0:.02f} seconds")
-
-        if self.tokenizer is None:
-            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
 
         if self.compile:
             logger.info("Compiling model ...")
@@ -106,10 +105,9 @@ def __del__(self):
 
 
 class VQGANModel:
-    def __init__(self, config_name: str, checkpoint_path: str):
-        if self.cfg is None:
-            with initialize(version_base="1.3", config_path="../fish_speech/configs"):
-                self.cfg = compose(config_name=config_name)
+    def __init__(self, config_name: str, checkpoint_path: str, device: str):
+        with initialize(version_base="1.3", config_path="../fish_speech/configs"):
+            self.cfg = compose(config_name=config_name)
 
         self.model = instantiate(self.cfg.model)
         state_dict = torch.load(
@@ -120,8 +118,9 @@ def __init__(self, config_name: str, checkpoint_path: str):
             state_dict = state_dict["state_dict"]
         self.model.load_state_dict(state_dict, strict=True)
         self.model.eval()
-        self.model.cuda()
-        logger.info("Restored model from checkpoint")
+        self.model.to(device)
+
+        logger.info("Restored VQGAN model from checkpoint")
 
     def __del__(self):
         self.cfg = None
@@ -175,7 +174,6 @@ def sematic_to_wav(self, indices):
 class LoadLlamaModelRequest(BaseModel):
     config_name: str = "text2semantic_finetune"
     checkpoint_path: str = "checkpoints/text2semantic-400m-v0.2-4k.pth"
-    device: str = "cuda"
     precision: Literal["float16", "bfloat16"] = "bfloat16"
     tokenizer: str = "fishaudio/speech-lm-v1"
     compile: bool = True
@@ -186,15 +184,20 @@ class LoadVQGANModelRequest(BaseModel):
     checkpoint_path: str = "checkpoints/vqgan-v1.pth"
 
 
+class LoadModelRequest(BaseModel):
+    device: str = "cuda"
+    llama: LoadLlamaModelRequest
+    vqgan: LoadVQGANModelRequest
+
+
 class LoadModelResponse(BaseModel):
     name: str
 
 
 @routes.http.put("/models/{name}")
-def load_model(
+def api_load_model(
     name: Annotated[str, Path("default")],
-    llama: Annotated[LoadLlamaModelRequest, Body()],
-    vqgan: Annotated[LoadVQGANModelRequest, Body()],
+    req: Annotated[LoadModelRequest, Body(exclusive=True)],
 ) -> Annotated[LoadModelResponse, JSONResponse[200, {}, LoadModelResponse]]:
     """
     Load model
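Note: the two separate `llama`/`vqgan` body parameters are folded into a single exclusive body, and `device` moves from `LoadLlamaModelRequest` to the top level so both models land on the same device. A minimal client sketch for the new request shape (host/port and the `requests` client are assumptions; field values mirror the Pydantic defaults above):

```python
# Sketch: load a model pair under the name "default" via the new single body.
import requests

resp = requests.put(
    "http://127.0.0.1:8000/models/default",  # assumed host/port
    json={
        "device": "cuda",
        "llama": {
            "config_name": "text2semantic_finetune",
            "checkpoint_path": "checkpoints/text2semantic-400m-v0.2-4k.pth",
            "precision": "bfloat16",
            "tokenizer": "fishaudio/speech-lm-v1",
            "compile": True,
        },
        "vqgan": {
            "checkpoint_path": "checkpoints/vqgan-v1.pth",
        },
    },
)
print(resp.status_code, resp.json())  # expect 200 and {"name": "default"}
```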
@@ -203,20 +206,25 @@ def load_model(
     if name in MODELS:
         del MODELS[name]
 
+    llama = req.llama
+    vqgan = req.vqgan
+
     logger.info("Loading model ...")
     new_model = {
         "llama": LlamaModel(
             config_name=llama.config_name,
             checkpoint_path=llama.checkpoint_path,
-            device=llama.device,
+            device=req.device,
             precision=llama.precision,
             tokenizer_path=llama.tokenizer,
             compile=llama.compile,
         ),
         "vqgan": VQGANModel(
             config_name=vqgan.config_name,
             checkpoint_path=vqgan.checkpoint_path,
+            device=req.device,
         ),
+        "lock": Lock(),
     }
 
     MODELS[name] = new_model
@@ -225,7 +233,7 @@ def load_model(
 
 
 @routes.http.delete("/models/{name}")
-def delete_model(
+def api_delete_model(
     name: Annotated[str, Path("default")],
 ) -> JSONResponse[200, {}, dict]:
     """
@@ -238,14 +246,16 @@ def delete_model(
             content="Model not found.",
         )
 
+    del MODELS[name]
+
     return JSONResponse(
         dict(message="Model deleted."),
         200,
     )
 
 
 @routes.http.get("/models")
-def list_models() -> JSONResponse[200, {}, dict]:
+def api_list_models() -> JSONResponse[200, {}, dict]:
     """
     List models
     """
@@ -271,7 +281,7 @@ class InvokeRequest(BaseModel):
 
 
 @routes.http.post("/models/{name}/invoke")
-def invoke_model(
+def api_invoke_model(
     name: Annotated[str, Path("default")],
     req: Annotated[InvokeRequest, Body(exclusive=True)],
 ):
@@ -289,6 +299,9 @@ def invoke_model(
     llama_model_manager = model["llama"]
     vqgan_model_manager = model["vqgan"]
 
+    # Acquire the per-model lock so only one request runs llama inference at a time
+    model["lock"].acquire()
+
     device = llama_model_manager.device
     seed = req.seed
     prompt_tokens = req.prompt_tokens
@@ -348,6 +361,9 @@ def invoke_model(
     codes = codes - 2
     assert (codes >= 0).all(), "Codes should be >= 0"
 
+    # Release the lock as soon as llama inference is done
+    model["lock"].release()
+
     # --------------- llama end ------------
     audio, sr = vqgan_model_manager.sematic_to_wav(codes)
     # --------------- vqgan end ------------
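Note: the lock serializes the llama stage per model while the vqgan decode runs outside the critical section. As written, an exception between `acquire()` and `release()` would leave the lock held forever; a `with` block (or `try`/`finally`) guarantees release. A sketch of the guarded variant — not what this commit does, and `generate_semantic_tokens` is a hypothetical stand-in for the generation loop:

```python
# Sketch: same serialization, but the lock is released even if generation raises.
with model["lock"]:  # threading.Lock doubles as a context manager
    codes = generate_semantic_tokens(req)  # hypothetical stand-in for the llama stage

# vqgan decoding stays outside the critical section
audio, sr = vqgan_model_manager.sematic_to_wav(codes)
```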
@@ -358,8 +374,8 @@ def invoke_model(
 
     return StreamResponse(
         iterable=[buffer.getvalue()],
         headers={
-            "Content-Disposition": "attachment; filename=generated.wav",
-            "Content-Type": "audio/wav",
+            "Content-Disposition": "attachment; filename=audio.wav",
+            "Content-Type": "application/octet-stream",
         },
     )
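Note: the response is now a generic binary attachment (`application/octet-stream` with a filename hint of `audio.wav`) instead of `audio/wav`. A small download sketch against the invoke endpoint (host/port assumed; `seed` and `prompt_tokens` are the only `InvokeRequest` fields visible in this diff, so the payload is illustrative):

```python
# Sketch: fetch generated audio from the invoke endpoint and save it to disk.
import requests

resp = requests.post(
    "http://127.0.0.1:8000/models/default/invoke",  # assumed host/port
    json={"seed": 42},  # illustrative; the full InvokeRequest schema is not shown here
)
resp.raise_for_status()

with open("audio.wav", "wb") as f:
    f.write(resp.content)  # raw WAV bytes streamed by the server
```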


