diff --git a/benchmark/profile_pipeline_api.py b/benchmark/profile_pipeline_api.py
index 764f78399..917867832 100644
--- a/benchmark/profile_pipeline_api.py
+++ b/benchmark/profile_pipeline_api.py
@@ -69,14 +69,14 @@ def __init__(self, model_path: str, engine_config, csv: str):
 
     def process_request(self, requests, concurrency, temperature, top_p, top_k,
                         stream_output):
-        stats = OrderedDict(
-            (session_id, None) for session_id in range(len(requests)))
+        stats = OrderedDict((index, None) for index in range(len(requests)))
         prompts = [prompt for prompt, _, _ in requests]
         gen_configs = [
             GenerationConfig(temperature=temperature,
                              top_p=top_p,
                              top_k=top_k,
                              ignore_eos=True,
+                             do_sample=True,
                              max_new_tokens=output_len)
             for _, _, output_len in requests
         ]
@@ -87,10 +87,10 @@ def process_request(self, requests, concurrency, temperature, top_p, top_k,
             for output in self.pipe.stream_infer(prompts,
                                                  gen_configs,
                                                  do_preprocess=False):
-                session_id = output.session_id
+                index = output.index
                 n_token = output.generate_token_len
                 finish_reason = output.finish_reason
-                stats[session_id] = (n_token, finish_reason)
+                stats[index] = (n_token, finish_reason)
                 if finish_reason is not None:
                     pbar.update(1)
         else:
@@ -98,20 +98,20 @@ def process_request(self, requests, concurrency, temperature, top_p, top_k,
                                     gen_configs,
                                     do_preprocess=False,
                                     use_tqdm=True):
-                session_id = output.session_id
+                index = output.index
                 n_token = output.generate_token_len
                 finish_reason = output.finish_reason
-                stats[session_id] = (n_token, finish_reason)
+                stats[index] = (n_token, finish_reason)
 
         elapsed_time = time.perf_counter() - start
 
         completion_tokens = 0
-        for session_id, (n_token, finish_reason) in stats.items():
+        for index, (n_token, finish_reason) in stats.items():
             assert finish_reason == 'length', \
-                f'unexpected finish_reason of session_id={session_id}, ' \
-                f'prompt={requests[session_id][0]}'
-            assert n_token - 1 <= requests[session_id][-1] <= n_token, \
-                f'request to generate {requests[session_id][-1]} tokens, ' \
+                f'unexpected finish_reason of index={index}, ' \
+                f'prompt={requests[index][0]}'
+            assert n_token - 1 <= requests[index][-1] <= n_token, \
+                f'request to generate {requests[index][-1]} tokens, ' \
                 f'but got {n_token} tokens'
             completion_tokens += n_token
 
diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py
index 813a6acd2..4f04906f1 100644
--- a/lmdeploy/messages.py
+++ b/lmdeploy/messages.py
@@ -335,7 +335,6 @@ class Response:
     text: str
     generate_token_len: int
    input_token_len: int
-    session_id: int
     finish_reason: Optional[Literal['stop', 'length']] = None
     token_ids: List[int] = field(default_factory=list)
     logprobs: List[Dict[int, float]] = None
diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py
index cb775ab2d..e9083fc11 100644
--- a/lmdeploy/serve/async_engine.py
+++ b/lmdeploy/serve/async_engine.py
@@ -1,22 +1,27 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import asyncio
+import concurrent.futures
 import dataclasses
 import json
 import os
 import random
 import re
-from contextlib import asynccontextmanager
+from contextlib import asynccontextmanager, closing
 from copy import deepcopy
+from functools import partial
 from itertools import count
-from queue import Empty, Queue
+from queue import Queue
 from threading import Thread
-from typing import Any, Dict, List, Literal, Optional, Tuple, Union
+from typing import (Any, AsyncIterator, Dict, Iterator, List, Literal,
+                    Optional, Tuple, Union)
+
+import tqdm
 
 from lmdeploy.logger import RequestLogger
 from lmdeploy.messages import (GenerationConfig, PytorchEngineConfig, Response,
                                ResponseType, TurbomindEngineConfig)
 from lmdeploy.model import MODELS, ChatTemplateConfig, best_match_model
-from lmdeploy.serve.utils import LogitsMixin, _get_event_loop
+from lmdeploy.serve.utils import LogitsMixin
 from lmdeploy.tokenizer import DetokenizeState
 from lmdeploy.utils import _get_and_verify_max_len, _stop_words, get_logger
 
@@ -52,6 +57,33 @@ class GenOut:
     logprobs: List[Dict[int, float]] = None
 
 
+def _gen_out_to_response(out: GenOut, index) -> Response:
+    return Response(text=out.response,
+                    generate_token_len=out.generate_token_len,
+                    input_token_len=out.input_token_len,
+                    finish_reason=out.finish_reason,
+                    token_ids=out.token_ids,
+                    logprobs=out.logprobs,
+                    index=index)
+
+
+def _append_response(dst: Response, src: Response):
+    """dst += src."""
+    if not dst:
+        return src
+    dst.text += src.text
+    dst.generate_token_len = src.generate_token_len
+    dst.input_token_len = src.input_token_len
+    dst.finish_reason = src.finish_reason
+    dst.index = src.index
+    if src.token_ids:
+        dst.token_ids += src.token_ids
+    if src.logprobs:
+        dst.logprobs = dst.logprobs or []
+        dst.logprobs += src.logprobs
+    return dst
+
+
 class Session:
     """Session for AsyncEngine.chat.
 
@@ -63,14 +95,17 @@ class Session:
         _engine (Any): engine for internal use.
         history (List[Any, str]): chat history.
     """
-    _ids = count(0)
 
-    def __init__(self):
-        self._id: int = next(self._ids)
+    def __init__(self,
+                 session_id: int,
+                 engine: Any,
+                 gen_config: GenerationConfig = None):
+        self._id: int = session_id
+        self._engine = engine
         self._step: int = 0
         self._prompt: Any = None
         self._response: Response = None
-        self._engine: Any = None
+        self._gen_config = gen_config
         self.history: List[Tuple[Any, str]] = []
 
     def _merge_response(self, resp: Response, step: Union[Response, GenOut]):
@@ -89,8 +124,8 @@ def response(self) -> Response:
     def close(self):
         """release engine storage for this session."""
         if self._engine:
-            inst = self._engine.create_instance()
-            inst.end(self._id)
+            self._engine._run(coro=self._engine.end_session(self._id)).result()
+            self._engine = None
 
     def __repr__(self) -> str:
         res = ''
@@ -100,6 +135,60 @@ def __repr__(self) -> str:
             res += f'USER:\n{user}\nASSISTANT:\n{assistant}\n'
         return res
 
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.close()
+
+    def __call__(
+            self,
+            prompt: str,
+            gen_config: Optional[GenerationConfig] = None,
+            stream_response: bool = True,
+            do_preprocess: bool = True) -> Union[Response, Iterator[Response]]:
+        self._engine.chat(prompt=prompt,
+                          gen_config=gen_config or self._gen_config,
+                          stream_response=stream_response,
+                          do_preprocess=do_preprocess,
+                          session=self)
+        if stream_response:
+            return self.generator
+        else:
+            return self.response
+
+
+class _EventLoopThread:
+
+    def __init__(self):
+        fut = concurrent.futures.Future()
+        self.thread = Thread(
+            target=partial(_EventLoopThread._thread_entry, fut))
+        self.thread.start()
+        self.loop: asyncio.AbstractEventLoop = fut.result()
+        self.closed = False
+
+    @staticmethod
+    def _thread_entry(fut):
+        loop = asyncio.new_event_loop()
+        fut.set_result(loop)
+        try:
+            loop.run_forever()
+        except BaseException as e:
+            logger.error(f'[internal_thread] {type(e).__name__} {e}')
+        finally:
+            loop.close()
+
+    def close(self):
+        if self.closed:
+            return
+        self.closed = True
+        self.loop.call_soon_threadsafe(self.loop.stop)
+        self.thread.join()
+
+    def __del__(self):
+        self.close()
+
 
 class AsyncEngine(LogitsMixin):
     """Async inference engine. Maintaining a bunch of tm_model instances.
@@ -179,13 +268,26 @@ def __init__(self,
         self.instance_num = self.backend_config.max_batch_size
         self.tokenizer = self.engine.tokenizer
         self.id2step = {}
-        self.id2generator = {}
-        self.free_gens: asyncio.Queue = None
+        self.id2inst = {}
+        self.free_insts: asyncio.Queue = None
         self.instances = [
             self.engine.create_instance() for _ in range(self.instance_num)
         ]
         self._session_id = count(0)
         self.request_logger = RequestLogger(max_log_len)
+        self.internal_thread = _EventLoopThread()
+        self.limiter: asyncio.Semaphore = None
+
+    def close(self):
+        self.internal_thread.close()
+
+    def _get_free_insts(self):
+        if self.free_insts is None:
+            # `asyncio.Queue` must be created in an async context
+            self.free_insts = asyncio.Queue()
+            for inst in self.instances:
+                self.free_insts.put_nowait(inst)
+        return self.free_insts
 
     def _build_turbomind(
             self,
@@ -245,27 +347,117 @@ def __call__(self,
 
     async def stop_session(self, session_id: int):
         """Stop a session by a session_id."""
-        generator = self.id2generator.get(session_id)
+        generator = self.id2inst.get(session_id)
         if generator:
             await generator.async_cancel(session_id)
         # else it's not running at all
 
     async def end_session(self, session_id: int):
         """For ending a session that is not running."""
-        generator = self.id2generator.get(session_id)
-        if generator:
-            fut = generator._fut
-            await fut
-            assert session_id not in self.id2generator
-        else:
-            generator = await self.free_gens.get()
+        inst = self.id2inst.get(session_id)
+        if inst:
+            await inst._active.wait()
+            assert session_id not in self.id2inst
+        inst = await self._get_free_insts().get()
         try:
-            await generator.async_end(session_id)
+            await inst.async_end(session_id)
             self.id2step[session_id] = 0
         except (Exception, asyncio.CancelledError, GeneratorExit) as e:  # noqa
             logger.error(f'[end_session] exception caught: {e}')
         finally:
-            self.free_gens.put_nowait(generator)
+            self._get_free_insts().put_nowait(inst)
+
+    def _get_limiter(self):
+        if not self.limiter:
+            self.limiter = asyncio.Semaphore(self.instance_num)
+        return self.limiter
+
+    async def _async_infer(self, requests: AsyncIterator[Dict],
+                           **kwargs) -> AsyncIterator[AsyncIterator[Response]]:
+        async for req in requests:
+            gen = self.generate(**req, **kwargs)
+            yield gen
+
+    def _infer(self,
+               requests: Iterator[Dict],
+               multiplex: bool,
+               pbar=None,
+               loop=None) -> Iterator[Iterator[Response]]:
+
+        async def _sync_resp(g, que: Queue, idx: int, sem: asyncio.Semaphore):
+            async for out in g:
+                que.put(_gen_out_to_response(out, idx))
+            sem.release()
+            if not multiplex:
+                que.put(None)  # sentinel of inner generator
+            if pbar:
+                pbar.update(1)
+
+        que = Queue()
+
+        async def _infer():
+            sem = self._get_limiter()
+            tasks = []
+            for idx, req in enumerate(requests):
+                await sem.acquire()
+                gen = self.generate(**req)
+                dst = que if multiplex else Queue()
+                if not multiplex:
+                    que.put(iter(dst.get, None))
+                # create a task to send the responses
+                task = asyncio.create_task(_sync_resp(gen, dst, idx, sem))
+                tasks.append(task)
+            if not multiplex:  # sentinel of outer generator
+                que.put(None)
+            await asyncio.gather(*tasks)
+            if multiplex:
+                que.put(None)  # sentinel of inner generator
+
+        loop = loop or self.internal_thread.loop
+        # submit the coroutine to async world
+        asyncio.run_coroutine_threadsafe(
+            _infer(), loop).add_done_callback(lambda x: x.result())
+
+        return iter(que.get, None)
+
+    @staticmethod
+    def _is_single(prompts):
+        return isinstance(prompts, str) or isinstance(prompts[0], Dict)
+
+    def infer(self,
+              prompts: Union[List[str], str, List[Dict], List[List[Dict]]],
+              gen_config: Optional[Union[GenerationConfig,
+                                         List[GenerationConfig]]] = None,
+              do_preprocess: bool = True,
+              adapter_name: Optional[str] = None,
+              stream_response: bool = False,
+              multiplex: bool = False,
+              pbar: Optional[tqdm.tqdm] = None,
+              **kwargs):
+
+        prompts = [prompts] if AsyncEngine._is_single(prompts) else prompts
+        assert isinstance(prompts, List), 'prompts should be a list'
+        gen_config = gen_config or GenerationConfig()
+        if not isinstance(gen_config, List):
+            gen_config = [gen_config] * len(prompts)
+        assert len(prompts) == len(gen_config), \
+            'input gen_confg length differs from the length of prompts'  # noqa
+
+        def requests():
+            for prompt, gen_cfg in zip(prompts, gen_config):
+                r = dict(messages=prompt,
+                         gen_config=gen_cfg,
+                         do_preprocess=do_preprocess,
+                         adapter_name=adapter_name,
+                         stream_response=stream_response,
+                         **kwargs)
+                r.setdefault('sequence_start', True)
+                r.setdefault('sequence_end', True)
+                if 'session_id' not in r:
+                    r['session_id'] = next(self._session_id)
+                yield r
+
+        return self._infer(requests(), multiplex, pbar)
 
     def batch_infer(self,
                     prompts: Union[List[str], str, List[Dict],
@@ -290,59 +482,26 @@ def batch_infer(self,
                 Pick one from adapters. Default to None, using the base model.
             use_tqdm (bool): Whether use the progress bar. Default to False
         """
-        need_list_wrap = isinstance(prompts, str) or isinstance(
-            prompts[0], Dict)
-        prompts = [prompts] if need_list_wrap else prompts
-        assert isinstance(prompts, List), 'prompts should be a list'
-        if gen_config is None:
-            gen_config = GenerationConfig()
-        if not isinstance(gen_config, List):
-            gen_config = [gen_config] * len(prompts)
-        assert len(prompts) == len(gen_config), \
-            'input gen_confg length differs from the length of prompts'  # noqa
-        prompt_num = len(prompts)
-        session_ids = [next(self._session_id) for _ in range(prompt_num)]
-        outputs = [
-            Response('', 0, 0, session_ids[i], index=i)
-            for i in range(prompt_num)
-        ]
-        generators = []
-        if use_tqdm:
-            import tqdm
-            pbar = tqdm.tqdm(total=len(prompts))
-        for i, prompt in enumerate(prompts):
-            generators.append(
-                self.generate(prompt,
-                              session_ids[i],
-                              gen_config=gen_config[i],
-                              stream_response=True,
-                              sequence_start=True,
-                              sequence_end=True,
-                              do_preprocess=do_preprocess,
-                              adapter_name=adapter_name,
-                              **kwargs))
-
-        async def _inner_call(i, generator):
-            async for out in generator:
-                outputs[i].text += out.response
-                outputs[i].generate_token_len = out.generate_token_len
-                outputs[i].input_token_len = out.input_token_len
-                outputs[i].finish_reason = out.finish_reason
-                if out.token_ids:
-                    outputs[i].token_ids.extend(out.token_ids)
-                if out.logprobs:
-                    if outputs[i].logprobs is None:
-                        outputs[i].logprobs = []
-                    outputs[i].logprobs.extend(out.logprobs)
-                if use_tqdm and out.finish_reason is not None:
-                    pbar.update(1)
-
-        async def gather():
-            await asyncio.gather(
-                *[_inner_call(i, generators[i]) for i in range(len(prompts))])
-
-        _get_event_loop().run_until_complete(gather())
-        outputs = outputs[0] if need_list_wrap else outputs
+        is_single = AsyncEngine._is_single(prompts)
+        outputs = []
+        pbar = tqdm.tqdm(
+            total=1 if is_single else len(prompts)) if use_tqdm else None
+        try:
+            for g in self.infer(prompts,
+                                gen_config,
+                                do_preprocess,
+                                adapter_name,
+                                stream_response=False,
+                                pbar=pbar,
+                                **kwargs):
+                res = None
+                for out in g:
+                    res = _append_response(res, out)
+                outputs.append(res)
+        finally:
+            if pbar: pbar.close()  # noqa
+        if is_single:
+            return outputs[0]
         return outputs
 
     def stream_infer(
@@ -352,6 +511,7 @@ def stream_infer(
                                        List[GenerationConfig]]] = None,
             do_preprocess: bool = True,
             adapter_name: Optional[str] = None,
+            stream_response: bool = True,
             **kwargs):
         """Inference a batch of prompts with stream mode.
 
@@ -366,62 +526,13 @@ def stream_infer(
             adapter_name (str): the adapter name of slora for pytorch backend.
                 Pick one from adapters. Default to None, using the base model.
         """
-        need_list_wrap = isinstance(prompts, str) or isinstance(
-            prompts[0], Dict)
-        prompts = [prompts] if need_list_wrap else prompts
-        assert isinstance(prompts, List), 'prompts should be a list'
-        if gen_config is None:
-            gen_config = GenerationConfig()
-        if not isinstance(gen_config, List):
-            gen_config = [gen_config] * len(prompts)
-        assert len(prompts) == len(gen_config), \
-            'input gen_confg length differs from the length of prompts'  # noqa
-        session_ids = [next(self._session_id) for _ in range(len(prompts))]
-        outputs = Queue()
-        generators = []
-        for i, prompt in enumerate(prompts):
-            generators.append(
-                self.generate(prompt,
-                              session_ids[i],
-                              gen_config=gen_config[i],
-                              stream_response=True,
-                              sequence_start=True,
-                              sequence_end=True,
-                              do_preprocess=do_preprocess,
-                              adapter_name=adapter_name,
-                              **kwargs))
-
-        async def _inner_call(i, generator):
-            async for out in generator:
-                outputs.put(
-                    Response(out.response,
-                             out.generate_token_len,
-                             out.input_token_len,
-                             session_ids[i],
-                             out.finish_reason,
-                             out.token_ids,
-                             out.logprobs,
-                             index=i))
-
-        async def gather():
-            await asyncio.gather(
-                *[_inner_call(i, generators[i]) for i in range(len(prompts))])
-            outputs.put(None)
-
-        loop = _get_event_loop()
-        proc = Thread(target=lambda: loop.run_until_complete(gather()))
-        proc.start()
-
-        while True:
-            try:
-                out = outputs.get(timeout=0.001)
-                if out is None:
-                    break
-                yield out
-            except Empty:
-                pass
-
-        proc.join()
+        return self.infer(prompts,
+                          gen_config,
+                          do_preprocess,
+                          adapter_name,
+                          stream_response,
+                          multiplex=True,
+                          **kwargs)
 
     async def _get_prompt_input(self,
                                 prompt: str,
@@ -448,17 +559,17 @@ async def _get_prompt_input(self,
     @asynccontextmanager
     async def model_inst(self, session_id: int):
         """A context manager to make sure server's safe running."""
-        assert session_id not in self.id2generator
-        inst = await self.free_gens.get()
-        inst._fut = asyncio.get_running_loop().create_future()
-        self.id2generator[session_id] = inst
+        assert session_id not in self.id2inst
+        free_insts = self._get_free_insts()
+        inst = await free_insts.get()
+        inst._active = asyncio.Event()
+        self.id2inst[session_id] = inst
         try:
             yield inst
         finally:
-            self.id2generator.pop(session_id)
-            inst._fut.set_result(None)
-            inst._fut = None
-            self.free_gens.put_nowait(inst)
+            self.id2inst.pop(session_id)
+            inst._active.set()
+            free_insts.put_nowait(inst)
 
     @asynccontextmanager
     async def safe_run(self, inst, session_id, **kwargs):
@@ -575,12 +686,6 @@ async def generate(
         def is_error(status):
             return status not in [ResponseType.SUCCESS, ResponseType.FINISH]
 
-        if self.free_gens is None:
-            # `asyncio.Queue` must be created in an async context
-            self.free_gens = asyncio.Queue()
-            for inst in self.instances:
-                self.free_gens.put_nowait(inst)
-
         async with self.model_inst(session_id) as inst:
             state = DetokenizeState(len(input_ids))
             token_ids = input_ids.copy()
@@ -699,12 +804,28 @@ def parse_tool_response(self, text, tools, **kwargs):
                           for call_info in call_info_list]
         return text, call_info_list
 
+    def _run(self, fn=None, coro=None, loop=None):
+        assert (fn or coro) and not (fn and coro)
+        loop = loop or self.internal_thread.loop
+        if fn:
+
+            async def _coro():
+                return fn()
+
+            coro = _coro()
+        return asyncio.run_coroutine_threadsafe(coro, loop)
+
+    def session(self, gen_config: GenerationConfig = None):
+        return Session(self._run(fn=lambda: next(self._session_id)).result(),
+                       engine=self,
+                       gen_config=gen_config)
+
     def chat(self,
              prompt: str,
             session=None,
             gen_config: Optional[GenerationConfig] = None,
-             do_preprocess: bool = True,
-             **kwargs) -> Session:
+             stream_response=False,
+             **kwargs) -> Union[Session, Iterator]:
         """Chat.
 
         Args:
@@ -717,8 +838,7 @@ def chat(self,
             **kwargs (dict): ad hoc parametrization of `gen_config
         """
         if session is None:
-            session = Session()
-            session._engine = self.engine
+            session = self.session()
 
         # sync & init
         session._prompt = prompt
@@ -726,25 +846,35 @@ def chat(self,
 
         sequence_start = session._step == 0
 
-        async def _work():
-            resp = Response('', -1, -1, session._id)
-            async for output in self.generate(prompt,
-                                              session_id=session._id,
-                                              gen_config=gen_config,
-                                              stream_response=False,
-                                              sequence_start=sequence_start,
-                                              sequence_end=False,
-                                              step=session._step,
-                                              do_preprocess=do_preprocess,
-                                              **kwargs):
-                resp = session._merge_response(resp, output)
-            return resp
-
-        from lmdeploy.pytorch.engine.request import _run_until_complete
-        resp = _run_until_complete(_work())
-
-        session._response = resp
-        session._step += resp.generate_token_len + resp.input_token_len
-        session.history.append((session._prompt, resp.text))
+        generator = self.infer(prompt,
+                               gen_config,
+                               sequence_start=sequence_start,
+                               sequence_end=False,
+                               session_id=session._id,
+                               stream_response=stream_response,
+                               multiplex=True)
+
+        def _gen():
+            resp = None
+            try:
+                for out in generator:
+                    resp = _append_response(resp, out)
+                    yield out
+            except:  # noqa
+                self._run(coro=self.stop_session(session._id)).result()
+                raise
+            else:
+                session._response = resp
+                session._step += resp.generate_token_len + resp.input_token_len
+                session.history.append((session._prompt, resp.text))
+
+        if stream_response:
+            session.generator = _gen()
+        else:
+            # run the generator until finish
+            with closing(_gen()) as gen:
+                for _ in gen:
+                    pass
+            session.generator = None
 
         return session
diff --git a/src/turbomind/kernels/gpt_kernels.cu b/src/turbomind/kernels/gpt_kernels.cu
index a0c47fff0..d611cfab4 100644
--- a/src/turbomind/kernels/gpt_kernels.cu
+++ b/src/turbomind/kernels/gpt_kernels.cu
@@ -315,7 +315,7 @@ void invokeTranspose2D_(T* dst, const T* src, int rows, int cols, cudaStream_t s
     dim3 grid((cols + TILE_DIM - 1) / TILE_DIM,  //
               (rows + TILE_DIM - 1) / TILE_DIM);
     bool swap_xy = false;
-
+
     if (grid.y > 65535) {  // max dim for grid.y
         std::swap(grid.x, grid.y);
        swap_xy = true;
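Below is a minimal, non-authoritative usage sketch of the API surface this patch reworks (`close()`, `session()`, streaming chat, and `Response.index` replacing `session_id`). The model path is a placeholder, and the snippet assumes the usual `lmdeploy.pipeline()` entry point that builds the refactored `AsyncEngine`:

```python
# Sketch only: 'internlm/internlm2_5-7b-chat' is a placeholder model path.
from lmdeploy import GenerationConfig, pipeline

pipe = pipeline('internlm/internlm2_5-7b-chat')

# Outputs are matched to prompts via Response.index (session_id was removed).
for out in pipe.stream_infer(['Hi, pls intro yourself', 'Shanghai is']):
    print(out.index, out.text)

# Sessions are now created by the engine and usable as context managers;
# a non-streaming call returns the accumulated Response.
with pipe.session(GenerationConfig(max_new_tokens=128)) as sess:
    print(sess('Tell me a joke.', stream_response=False).text)

pipe.close()  # stops the engine's internal event-loop thread
```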