[Draft] Async pipeline #1763

Open
wants to merge 1 commit into base: main
369 changes: 369 additions & 0 deletions opencompass/models/async_openai_api.py
@@ -0,0 +1,369 @@
import asyncio
import contextlib
import multiprocessing
import os
import re
import threading
import weakref
from contextlib import contextmanager
from typing import Dict, Iterable, List, Literal, Optional, Tuple, Union, cast

import jieba

from opencompass.models.base_api import AsyncTokenBucket, BaseAPIModel
from opencompass.utils.prompt import PromptList


PromptType = Union[PromptList, str]
OPENAI_API_BASE = os.path.join(
    os.environ.get('OPENAI_BASE_URL', 'https://api.openai.com/v1/'),
    'chat/completions')


class _APIModelState:
    """Process-wide state shared by all models that talk to the same endpoint.

    Instances are keyed by ``name``, so every model pointing at the same API
    base and model path shares one token bucket and splits the configured
    concurrency among themselves.
    """

    _instance: Dict[str, weakref.ReferenceType['_APIModelState']] = {}
    _count: int
    _concurrency: int
    _locks = [threading.Lock(), multiprocessing.Lock()]

    def __init__(self, *, name: str, concurrency: int,
                 query_per_second: int = 1) -> None:
        # `__new__` may return an already-initialized instance, so only set
        # up the shared fields on the first construction.
        if not hasattr(self, '_count'):
            self._name = name
            self._count = 0
            self._token_bucket = AsyncTokenBucket(rate=query_per_second)

        # Every additional user of this state gets a smaller share of the
        # total concurrency budget.
        self._count += 1
        self._concurrency = max(1, concurrency // self._count)

    @property
    def concurrency(self) -> int:
        # If update and concurrency are called simultaneously, the value
        # returned here may be inaccurate, but the impact is likely minimal.
        return self._concurrency

    async def acquire(self):
        return await self._token_bucket.acquire()

    @property
    def rpm(self):
        return self._token_bucket.rpm

    @property
    def name(self) -> str:
        return self._name

    @property
    def count(self):
        return self._count

    @classmethod
    def _cleanup(cls, ref: weakref.ReferenceType['_APIModelState']):
        # Called when an instance is garbage collected: drop its dead weak
        # reference from the registry.
        with cls._lock():
            for key, value in list(cls._instance.items()):
                if value is ref:
                    cls._instance.pop(key)
                    break

    def __new__(cls, name: str, *args, **kwargs) -> '_APIModelState':
        with cls._lock():
            if name not in cls._instance:
                self = super().__new__(cls)
                cls._instance[name] = weakref.ref(self, cls._cleanup)
            return cls._instance[name]()  # type: ignore

    @classmethod
    @contextmanager
    def _lock(cls):
        # Hold both the thread lock and the process lock while touching the
        # shared registry.
        with contextlib.ExitStack() as stack:
            for lock in cls._locks:
                stack.enter_context(lock)
            yield



class AsyncOpenAISDK(BaseAPIModel):
    """Asynchronous OpenAI-compatible API wrapper built on the `openai` SDK.

    Models constructed with the same API base and model path share a single
    `_APIModelState`, i.e. one rate limit and concurrency budget.
    """

    states: Dict[str, _APIModelState] = {}

    def __init__(
        self,
        path: str = 'gpt-3.5-turbo',
        max_seq_len: int | None = None,  # type: ignore
        query_per_second: int = 1,
        retry: int = 2,
        key: str = 'ENV',
        org: str | List[str] | None = None,
        meta_template: Dict | None = None,
        openai_api_base: str = OPENAI_API_BASE,
        openai_proxy_url: Optional[str] = None,
        mode: Literal['none', 'front', 'mid', 'rear'] = 'none',
        logprobs: bool | None = False,
        top_logprobs: int | None = None,
        temperature: float | None = None,
        tokenizer_path: str | None = None,
        extra_body: Dict | None = None,
        max_completion_tokens: int = 16384,
        verbose: bool = False,
        concurrency: int = 64,
        status_code_mappings: dict = {},
    ):
        from openai import AsyncOpenAI

        assert mode in ['none', 'front', 'mid', 'rear']
        self.mode = mode

        # All instances hitting the same endpoint and model share one state.
        state_key = self._get_state_key(api_base=openai_api_base,
                                        model_name=path)
        if state_key not in AsyncOpenAISDK.states:
            AsyncOpenAISDK.states[state_key] = _APIModelState(
                name=state_key,
                concurrency=concurrency,
                query_per_second=query_per_second,
            )
        self.state = AsyncOpenAISDK.states[state_key]

        # 'ENV' follows the OpenCompass convention of reading the key from
        # the environment.
        if key == 'ENV':
            if 'OPENAI_API_KEY' not in os.environ:
                raise ValueError('OPENAI_API_KEY is not set.')
            key = os.environ['OPENAI_API_KEY']
        self.openai_client = AsyncOpenAI(base_url=openai_api_base,
                                         api_key=key)

        if max_seq_len is None:
            if '16k' in path:
                max_seq_len = 16384
            elif 'gpt-4' in path:
                max_seq_len = 8192
            elif 'gpt-3.5' in path:
                max_seq_len = 4097
            else:
                max_seq_len = 32768

        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         meta_template=meta_template,
                         retry=retry)

        self.verbose = verbose
        self.logprobs = logprobs
        self.top_logprobs = top_logprobs
        self.tokenizer_path = tokenizer_path
        self.hf_tokenizer = None
        self.extra_body = extra_body
        self.max_completion_tokens = max_completion_tokens
        self.temperature = temperature
        self.openai_api_base = openai_api_base
        self.concurrency = concurrency
        self.status_code_mappings = status_code_mappings

        if openai_proxy_url == 'ENV':
            if 'OPENAI_PROXY_URL' not in os.environ:
                raise ValueError('OPENAI_PROXY_URL is not set.')
            self.proxy_url = os.getenv('OPENAI_PROXY_URL')
        else:
            self.proxy_url = openai_proxy_url

    async def generate(self,  # type: ignore
                       inputs: Iterable[PromptType],
                       max_out_len: int = 512,
                       temperature: float = 0.7,
                       **kwargs) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[PromptType]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
            temperature (float): What sampling temperature to use,
                between 0 and 2. Higher values like 0.8 will make the output
                more random, while lower values like 0.2 will make it more
                focused and deterministic. Defaults to 0.7.

        Returns:
            List[str]: A list of generated strings.
        """
        if self.temperature is not None:
            temperature = self.temperature

        # TODO: This should be an AsyncGenerator once a real `AsyncInference`
        # has been implemented.
        tasks_queue: List[asyncio.Future] = []
        results_queue: List[Tuple[int, str]] = []
        inputs_iter = enumerate(inputs)

        # Keep at most `concurrency` requests in flight: wait for any task to
        # finish, collect its result, then top the queue back up from inputs.
        data_stop = False
        while not (data_stop and not tasks_queue):
            concurrency = self.state.concurrency

            if tasks_queue:
                done, pending = await asyncio.wait(
                    tasks_queue, return_when=asyncio.FIRST_COMPLETED)
                tasks_queue = list(pending)
                for task in done:
                    result: Tuple[int, str] = task.result()
                    results_queue.append(result)

            while not data_stop and len(tasks_queue) < concurrency:
                try:
                    index, _input = next(inputs_iter)
                except StopIteration:
                    data_stop = True
                    break
                tasks_queue.append(
                    asyncio.create_task(
                        self._generate(
                            input=_input,
                            max_out_len=self.max_completion_tokens
                            or max_out_len,
                            temperature=temperature,
                            index=index,
                        )))

        # Results complete out of order; restore the original input order.
        results_queue.sort()
        return [item[1] for item in results_queue]

    async def generate_from_template(self, templates: List[PromptType],  # type: ignore
                                     max_out_len: int, **kwargs):
        """Generate completion from a list of templates.

        Args:
            templates (List[PromptType]): A list of templates.
            max_out_len (int): The maximum length of the output.
        """
        inputs = self.parse_template(templates, mode='gen')  # type: ignore
        return await self.generate(inputs, max_out_len=max_out_len, **kwargs)

    async def _generate(self, input: PromptList | str, max_out_len: int,
                        temperature: float, index: int) -> Tuple[int, str]:
        from openai import APIStatusError, BadRequestError
        assert isinstance(input, (str, PromptList))

        # max num token for gpt-3.5-turbo is 4097
        # Most models' token limits are above 32k

        # Leave 100 tokens as a prompt buffer; only triggered if input is str
        if isinstance(input, str) and self.mode != 'none':
            context_window = self.max_seq_len
            input = self.bin_trim(
                input,
                context_window - 100 - max_out_len,
                cast(Literal['front', 'mid', 'rear'], self.mode),
            )

        if isinstance(input, str):
            messages = [{'role': 'user', 'content': input}]
        else:
            # Map OpenCompass roles onto OpenAI chat roles.
            messages = []
            for item in input:
                msg = {'content': item['prompt']}
                if item['role'] == 'HUMAN':
                    msg['role'] = 'user'
                elif item['role'] == 'BOT':
                    msg['role'] = 'assistant'
                elif item['role'] == 'SYSTEM':
                    msg['role'] = 'system'
                messages.append(msg)

        # Hold out 100 tokens due to potential errors in tiktoken calculation
        # try:
        #     max_out_len = min(
        #         max_out_len,
        #         context_window - self.get_token_len(str(input)) - 100)
        # except KeyError:
        #     max_out_len = max_out_len
        # if max_out_len <= 0:
        #     return ''

        num_retries = 0
        while num_retries < self.retry:
            # Respect the shared QPS budget before issuing the request.
            await self.state.acquire()

            query_data = dict(
                model=self.path,
                max_tokens=max_out_len,
                n=1,
                temperature=temperature,
                messages=messages,
                extra_body=self.extra_body,
                timeout=600,
            )

            try:
                if self.verbose:
                    self.logger.info('Start calling OpenAI API')
                responses = await self.openai_client.chat.completions.create(
                    **query_data)

                if self.verbose:
                    self.logger.info(
                        'Successfully get response from OpenAI API')
                    try:
                        self.logger.info(responses)
                    except Exception as e:  # noqa: F841
                        pass
                if not responses.choices:
                    self.logger.error(
                        'Response is empty, it is an internal server error '
                        'from the API provider.')
                    num_retries += 1
                    continue
                return index, responses.choices[0].message.content

            except (BadRequestError, APIStatusError) as e:
                # Handle BadRequest status.
                # You can specify self.status_code_mappings to bypass
                # API sensitivity blocks, for example:
                # status_code_mappings={400: 'Input data may contain
                # inappropriate content.'}
                status_code = e.status_code
                if (status_code is not None
                        and status_code in self.status_code_mappings):
                    error_message = self.status_code_mappings[status_code]
                    self.logger.info(f'Status Code: {status_code},\n'
                                     f'Original Error Message: {e},\n'
                                     f'Return Message: {error_message} ')
                    return index, error_message
                else:
                    self.logger.warning(
                        f'Failed to get response due to {e}, '
                        f'retry {num_retries}/{self.retry}')
            except Exception as e:
                self.logger.warning(
                    f'Failed to get response due to {e}, '
                    f'retry {num_retries}/{self.retry}')
            num_retries += 1
        raise RuntimeError('Calling OpenAI API failed after retrying for '
                           f'{self.retry} times. Check the logs for details.')

    def _get_state_key(self, api_base: str, model_name: str):
        return api_base + model_name

    def bin_trim(self, prompt: str, num_token: int,
                 mode: Literal['front', 'mid', 'rear']) -> str:
        """Trim the prompt so that it is no longer than num_token tokens.

        Args:
            prompt (str): Input string.
            num_token (int): The upper bound of token numbers.
            mode (str): Which part to keep: 'front' keeps the tail, 'rear'
                keeps the head, and 'mid' keeps both ends.

        Returns:
            str: The trimmed prompt.
        """
        token_len = self.get_token_len(prompt)
        if token_len <= num_token:
            return prompt
        # Chinese text has no whitespace word boundaries, so segment it with
        # jieba; otherwise split on spaces.
        pattern = re.compile(r'[\u4e00-\u9fa5]')
        if pattern.search(prompt):
            words = list(jieba.cut(prompt, cut_all=False))
            sep = ''
        else:
            words = prompt.split(' ')
            sep = ' '

        # Binary search for the largest number of words whose token count
        # still fits within num_token.
        l, r = 1, len(words)
        while l + 2 < r:
            mid = (l + r) // 2
            if mode == 'front':
                cur_prompt = sep.join(words[-mid:])
            elif mode == 'mid':
                cur_prompt = sep.join(words[:mid]) + sep.join(words[-mid:])
            elif mode == 'rear':
                cur_prompt = sep.join(words[:mid])

            if self.get_token_len(cur_prompt) <= num_token:
                l = mid  # noqa: E741
            else:
                r = mid

        if mode == 'front':
            prompt = sep.join(words[-l:])
        elif mode == 'mid':
            prompt = sep.join(words[:l]) + sep.join(words[-l:])
        elif mode == 'rear':
            prompt = sep.join(words[:l])
        return prompt
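
For reference, a minimal usage sketch of the async interface added in this file (not part of the diff; the model name below is a placeholder, and the OPENAI_API_KEY / OPENAI_BASE_URL environment variables are assumed to point at a reachable OpenAI-compatible endpoint):

import asyncio

from opencompass.models.async_openai_api import AsyncOpenAISDK


async def main():
    # Placeholder model; key='ENV' reads OPENAI_API_KEY from the environment.
    model = AsyncOpenAISDK(path='gpt-4o-mini',
                           key='ENV',
                           query_per_second=2,
                           concurrency=8)
    # generate() keeps up to `concurrency` requests in flight and returns
    # outputs in the same order as the inputs.
    outputs = await model.generate(
        ['Hello, world!', 'Name three prime numbers.'], max_out_len=128)
    print(outputs)


asyncio.run(main())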

