Commit e306c5d
Merge branch 'main' of github.com:InternLM/lmdeploy into internlm3-dense
AllentDan committed Jan 14, 2025
2 parents be8fc0f + aa40f0b commit e306c5d
Showing 17 changed files with 563 additions and 158 deletions.
4 changes: 3 additions & 1 deletion benchmark/profile_pipeline_api.py
@@ -61,7 +61,9 @@ def sample_requests(dataset_path: str, num_requests: int,
class Engine:

def __init__(self, model_path: str, engine_config, csv: str):
self.pipe = pipeline(model_path, backend_config=engine_config)
self.pipe = pipeline(model_path,
backend_config=engine_config,
log_level='ERROR')
self.tokenizer = AutoTokenizer.from_pretrained(model_path,
trust_remote_code=True)

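The benchmark now silences per-request logging when it builds the pipeline. A minimal sketch of the same call, assuming lmdeploy is installed; the model path and config values below are illustrative, not taken from the diff:

```python
# Sketch only: model path and engine config are hypothetical placeholders.
from lmdeploy import pipeline, PytorchEngineConfig

engine_config = PytorchEngineConfig(max_batch_size=128)
pipe = pipeline('internlm/internlm3-8b-instruct',   # hypothetical model path
                backend_config=engine_config,
                log_level='ERROR')                   # suppress per-request INFO logs during profiling
print(pipe(['Hello, world']))
```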
16 changes: 11 additions & 5 deletions lmdeploy/logger.py
@@ -21,25 +21,31 @@ def __init__(self, max_log_len: Optional[int]) -> None:
self.max_log_len = max_log_len

def log_prompt(self, session_id: int, prompt: str) -> None:
if not isinstance(prompt, str):
# Prompt may be a GPT4V message with base64 images;
# logging might be impractical due to length
return
if self.max_log_len is not None:
if prompt is not None:
prompt = prompt[:self.max_log_len]
logger.info(f'session_id={session_id}, '
logger.info(f'session={session_id}, '
f'prompt={prompt!r}')

def log_inputs(self, session_id: int, prompt: Optional[str],
prompt_token_ids: Optional[List[int]],
gen_config: GenerationConfig, adapter_name: str) -> None:
max_log_len = self.max_log_len
input_tokens = len(prompt_token_ids)
if max_log_len is not None:
if prompt is not None:
prompt = prompt[:max_log_len]

if prompt_token_ids is not None:
prompt_token_ids = prompt_token_ids[:max_log_len]

logger.info(f'session_id={session_id}, '
f'prompt={prompt!r}, '
logger.info(f'session={session_id}, '
f'adapter_name={adapter_name}, '
f'input_tokens={input_tokens}, '
f'gen_config={gen_config}, '
f'prompt_token_id={prompt_token_ids}, '
f'adapter_name={adapter_name}.')
f'prompt={prompt!r}, '
f'prompt_token_id={prompt_token_ids}')
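The logger changes skip non-string prompts (multimodal messages carrying base64 images) and add an `input_tokens` count to `log_inputs`. A small self-contained sketch of the same truncate-or-skip idea; the helper name and default length are illustrative, not from lmdeploy:

```python
from typing import Optional

def safe_prompt_for_log(prompt, max_log_len: Optional[int] = 256):
    """Return a log-friendly prompt, or None when logging is impractical."""
    if not isinstance(prompt, str):
        # e.g. a multimodal message carrying base64-encoded images
        return None
    if max_log_len is not None:
        prompt = prompt[:max_log_len]
    return prompt

print(safe_prompt_for_log('hello ' * 100, max_log_len=16))  # 'hello hello hell'
print(safe_prompt_for_log([{'type': 'image_url'}]))         # None
```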
6 changes: 5 additions & 1 deletion lmdeploy/model.py
@@ -443,11 +443,12 @@ def match(cls, model_path: str) -> Optional[str]:
model_path (str): the model path used for matching.
"""
path = model_path.lower()
if all([c not in path for c in ['internlm2', '8k']]) and \
if all([c not in path for c in ['internlm3', 'internlm2', '8k']]) and \
all([c in path for c in ['internlm', 'chat']]):
return 'internlm'


@MODELS.register_module(name='internlm3')
@MODELS.register_module(name='internlm2')
class InternLM2Chat7B(InternLMChat7B):
"""Chat template and generation parameters of InternLM2-Chat-7B."""
@@ -491,6 +492,9 @@ def match(cls, model_path: str) -> Optional[str]:
if 'internlm2' in path and ('chat' in path or 'math' in path):
return 'internlm2'

if 'internlm3' in path and ('instruct' in path):
return 'internlm3'

def messages2prompt(self,
messages,
sequence_start=True,
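lmdeploy/model.py wires the InternLM3 chat template into the existing path-based matching. A standalone sketch of that matching order, with the registry decorators omitted and the helper name invented for illustration:

```python
# Sketch only: mirrors the match() logic added in the diff, outside the MODELS registry.
def match_chat_template(model_path: str):
    path = model_path.lower()
    if 'internlm3' in path and 'instruct' in path:
        return 'internlm3'
    if 'internlm2' in path and ('chat' in path or 'math' in path):
        return 'internlm2'
    if all(c not in path for c in ['internlm3', 'internlm2', '8k']) and \
            all(c in path for c in ['internlm', 'chat']):
        return 'internlm'
    return None

print(match_chat_template('internlm/internlm3-8b-instruct'))  # internlm3
print(match_chat_template('internlm/internlm2-chat-7b'))      # internlm2
```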
4 changes: 3 additions & 1 deletion lmdeploy/pytorch/configurations/default.py
@@ -14,7 +14,9 @@ def condition(cls, hf_config):
@classmethod
def build(cls, hf_config, model_path: str = None, **kwargs):
"""build."""
head_dim = hf_config.hidden_size // hf_config.num_attention_heads
head_dim = getattr(
hf_config, 'head_dim',
hf_config.hidden_size // hf_config.num_attention_heads)
num_attention_heads = hf_config.num_attention_heads
num_key_value_heads = getattr(hf_config, 'num_key_value_heads',
num_attention_heads)
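The default configuration builder now respects an explicit `head_dim` in the HF config (needed for models whose head dimension is not `hidden_size / num_attention_heads`) and only derives it as a fallback. A minimal sketch of that fallback, using a stand-in config object rather than the real transformers class:

```python
from types import SimpleNamespace

def resolve_head_dim(hf_config):
    # prefer an explicit head_dim from the config, otherwise derive it
    return getattr(hf_config, 'head_dim',
                   hf_config.hidden_size // hf_config.num_attention_heads)

cfg_a = SimpleNamespace(hidden_size=4096, num_attention_heads=32)
cfg_b = SimpleNamespace(hidden_size=4096, num_attention_heads=32, head_dim=192)
print(resolve_head_dim(cfg_a))  # 128, derived
print(resolve_head_dim(cfg_b))  # 192, taken from the config
```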
29 changes: 15 additions & 14 deletions lmdeploy/pytorch/engine/engine.py
@@ -172,7 +172,6 @@ def __init__(self,

# create main thread
self._start_loop()
self._create_buffers()
self._output_stream = torch.cuda.Stream()

@classmethod
@@ -228,12 +227,6 @@ def _download_adapters(self, adapters: Dict[str, str],

return new_adapters

def _create_buffers(self):
max_batches = self.scheduler_config.max_batches

# buffers to create inputs
self._seq_length_buf = torch.ones(max_batches, dtype=torch.long)

def _build_adapter_manager(self, adapters):
return AdapterManager(adapters)

@@ -368,14 +361,16 @@ def __update_max_new_tokens(msg):
session_id = req.data['session_id']
sess = self.scheduler.sessions[session_id]
# TODO: support 1 session n sequence
sampling_param = req.data['sampling_param']
return_logits = sampling_param.out_logits
if len(sess.sequences) == 0:
assert len(
req.data['token_ids']) > 0, ('Empty input is not allowed.')
sess.add_sequence(
req.data['token_ids'],
sampling_param=req.data['sampling_param'],
sampling_param=sampling_param,
adapter_name=req.data['adapter_name'],
return_logits=req.data.get('return_logits', False),
return_logits=return_logits,
multimodals=req.data.get('input_multimodals'),
input_embeddings=req.data.get('input_embeddings'),
)
@@ -391,8 +386,8 @@ def __update_max_new_tokens(msg):
embeddings=req.data.get('input_embeddings'),
)
msg.num_new_tokens = 0
msg.sampling_param = req.data['sampling_param']
msg.return_logits = req.data.get('return_logits', False)
msg.sampling_param = sampling_param
msg.return_logits = return_logits
msg.status = MessageStatus.WAITING
__update_bad_words(msg)
__update_max_new_tokens(msg)
@@ -431,7 +426,7 @@ def create_model_inputs(self, messages: SeqList, is_prefill: bool):
seq_length = [len(tokens) for tokens in token_ids]
seq_length = torch.tensor(seq_length, dtype=torch.long)
else:
seq_length = self._seq_length_buf[:batch_size]
seq_length = torch.ones(batch_size, dtype=torch.long)
max_q_seq_length = seq_length.max().item()

block_offsets = self.scheduler.get_block_tables(messages)
@@ -685,6 +680,8 @@ async def __long_context_single_forward(inputs):
if not return_logits and not inputs.is_decoding:
last_token_loc = [-1]
ret['hidden_states'] = ret['hidden_states'][:, last_token_loc]
else:
ret['hidden_states'] = ret['hidden_states'].to('cuda')

hidden_states = ret.pop('hidden_states')
logits = self.model_agent.get_logits(hidden_states)
@@ -808,7 +805,11 @@ def __update_inputs(next_token_ids):
finish = finish or _check_finish(self.scheduler, idx)
event = torch.cuda.Event()
event.record()
output = (next_token_ids, logits, stopped, model_metas, event)
output = dict(next_token_ids=next_token_ids,
logits=logits,
stopped=stopped,
model_metas=model_metas,
event=event)
output_que.put_nowait((finish, output))

inputs.model_metas = model_metas
@@ -1053,7 +1054,7 @@ async def __step():
finish = False
while not finish:
finish, out = await out_que.get()
step_outputs = await self._make_infer_outputs(*out)
step_outputs = await self._make_infer_outputs(**out)
self._set_has_runable_event(has_runable_event)
resp_que.put_nowait(step_outputs)

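Two engine-level changes stand out: the preallocated `_seq_length_buf` is dropped in favour of building `seq_length` per step, and the per-step output now travels through the queue as a keyword dict, so `_make_infer_outputs(**out)` receives fields by name instead of by position. A self-contained sketch of that second pattern; the queue, field names, and consumer are simplified stand-ins:

```python
import asyncio

def make_infer_outputs(next_token_ids, logits, stopped, model_metas, event):
    # simplified consumer: in the engine this builds the real InferOutput objects
    return {'tokens': next_token_ids, 'stopped': stopped}

async def main():
    out_que: asyncio.Queue = asyncio.Queue()
    output = dict(next_token_ids=[42], logits=None, stopped=[True],
                  model_metas=None, event=None)
    out_que.put_nowait((True, output))          # producer side
    finish, out = await out_que.get()           # consumer side
    # keyword unpacking tolerates reordering or extending the output fields
    print(finish, make_infer_outputs(**out))

asyncio.run(main())
```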
126 changes: 17 additions & 109 deletions lmdeploy/pytorch/engine/engine_instance.py
@@ -41,12 +41,6 @@ async def async_try_add_session(req_sender: RequestSender, session_id: int):
f'with error: {resp.type}'))


async def async_end(req_sender: RequestSender, session_id: int):
"""End the given session."""
req_sender.send_async(RequestType.END_SESSION,
dict(session_id=session_id, response=False))


async def async_cancel(req_sender: RequestSender, session_id: int):
"""Stop current streaming inference."""
resp = await req_sender.async_send(RequestType.STOP_SESSION,
@@ -158,8 +152,13 @@ async def async_stream_infer(self,
token_ids = resp.data['token_ids'].tolist()
yield EngineOutput(resp.type, token_ids, len(token_ids))
elif resp.type == ResponseType.FINISH:
token_ids = resp.data['token_ids'].tolist()
yield EngineOutput(resp.type, token_ids, len(token_ids))
resp_data = resp.data
token_ids = resp_data['token_ids'].tolist()
logits = resp_data['logits']
yield EngineOutput(resp.type,
token_ids,
len(token_ids),
logits=logits)
break
else:
yield EngineOutput(resp.type, [], 0)
@@ -183,18 +182,16 @@ async def async_infer(self,
List[int]: The streaming output tokens.
int: The number of the output tokens.
"""
token_ids = []
async for outputs in self.async_stream_infer(session_id,
input_ids,
multimodal=multimodal,
gen_config=gen_config,
**kwargs):
status, tmp_ids = outputs.status, outputs.token_ids
status = outputs.status
if status not in [ResponseType.SUCCESS, ResponseType.FINISH]:
return EngineOutput(status, token_ids, len(token_ids))
token_ids = tmp_ids
return outputs

return EngineOutput(0, token_ids, len(token_ids))
return outputs

def stream_infer(self,
session_id: int,
@@ -216,9 +213,6 @@ def stream_infer(self,
List[int]: The streaming output tokens.
int: The number of the output tokens.
"""
if len(input_ids) > self.max_input_len:
yield EngineOutput(ResponseType.INPUT_LENGTH_ERROR, [], 0)
return

def __call_async():
"""call async."""
@@ -255,22 +249,16 @@ def infer(self,
List[int]: The streaming output tokens.
int: The number of the output tokens.
"""
token_ids = []
for outputs in self.stream_infer(session_id,
input_ids,
multimodal=multimodal,
gen_config=gen_config,
**kwargs):
status, tmp_ids = outputs.status, outputs.token_ids
if status not in [ResponseType.SUCCESS, ResponseType.FINISH]:
return EngineOutput(status, token_ids, len(token_ids))
token_ids = tmp_ids

return EngineOutput(0, token_ids, len(token_ids))
return self.req_sender.run_until_complete(
self.async_infer(session_id,
input_ids,
multimodal=multimodal,
gen_config=gen_config,
**kwargs))

async def async_end(self, session_id: int):
"""End the given session."""
return await async_end(self.req_sender, session_id)
return end(self.req_sender, session_id)

def end(self, session_id: int):
"""End the given session."""
@@ -283,83 +271,3 @@ async def async_cancel(self, session_id: int):
def cancel(self, session_id: int):
"""Stop current streaming inference."""
return cancel(self.req_sender, session_id)

def decode(self,
input_ids,
multimodal: List[InputMultiModalType] = None,
steps: List[int] = None,
sequence_start: bool = True,
sequence_end: bool = True,
adapter_names: List[str] = None):
"""Perform context decode on input tokens.
Args:
input_ids (numpy.ndarray): the batch of input token ids
steps (List[int]): the offset of the k/v cache
multimodal (List[InputMultiModalType]):
multimodals inputs.
sequence_start (bool): indicator for starting a sequence
sequence_end (bool): indicator for ending a sequence
adapter_names (List[str]): The name of the adapters.
"""
from torch.nn.utils.rnn import pad_sequence
logger.debug('Decoding logits.')
batch_size = len(input_ids)

def __add_messages(session_ids, input_ids, adapter_names,
input_multimodals):
add_msgs = []
sampling_param = SamplingParam(max_new_tokens=0)
batch_size = len(input_ids)
if input_multimodals is None:
input_multimodals = [None] * batch_size
for (session_id, token_id, adapter_name,
in_mm) in zip(session_ids, input_ids, adapter_names,
input_multimodals):
if len(token_id) > self.max_input_len:
raise RuntimeError(
f'Expect input length<={self.max_input_len} '
f'but get {len(token_id)}')
msg = dict(token_ids=token_id,
session_id=session_id,
sampling_param=sampling_param,
adapter_name=adapter_name,
input_multimodals=in_mm,
return_logits=True)
add_msgs.append(msg)
req_types = [RequestType.ADD_MESSAGE] * batch_size
resps = self.req_sender.batched_send_async(req_types,
data=add_msgs)
return resps

if steps is not None:
assert batch_size == len(steps)

if adapter_names is not None:
assert len(adapter_names) == batch_size
else:
adapter_names = [None] * batch_size

session_ids = tuple(range(batch_size))
if sequence_start:
for sid in session_ids:
self.req_sender.send(RequestType.END_SESSION,
dict(session_id=sid))
self._try_add_session(sid)

resps = __add_messages(session_ids, input_ids, adapter_names,
multimodal)

ret = []
for resp in resps:
resp = self.req_sender.recv(resp)
assert resp.type == ResponseType.FINISH
ret.append(resp.data['logits'])

ret = pad_sequence(ret, True)

if sequence_end:
for sid in session_ids:
self.end(sid)

return ret
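engine_instance.py removes the duplicated blocking loops: `infer` now drives `async_infer` to completion through the request sender, and the standalone `decode` helper is gone. A sketch of the wrapper pattern, using plain `asyncio.run` in place of the sender's `run_until_complete`; `EngineOutput` and `Instance` below are simplified stand-ins:

```python
import asyncio
from dataclasses import dataclass, field
from typing import List

@dataclass
class EngineOutput:
    status: int
    token_ids: List[int] = field(default_factory=list)
    num_token: int = 0

class Instance:
    async def async_infer(self, session_id: int, input_ids: List[int]) -> EngineOutput:
        await asyncio.sleep(0)  # pretend to await streamed responses
        return EngineOutput(0, input_ids[::-1], len(input_ids))

    def infer(self, session_id: int, input_ids: List[int]) -> EngineOutput:
        # single code path: the blocking API reuses the async implementation
        return asyncio.run(self.async_infer(session_id, input_ids))

print(Instance().infer(0, [1, 2, 3]))
```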
15 changes: 14 additions & 1 deletion lmdeploy/pytorch/messages.py
@@ -50,6 +50,8 @@ class SamplingParam:
min_new_tokens: int = 0
response_format: Optional[str] = None
logits_processors: Optional[List[LogitsProcessor]] = None
out_logits: bool = False
out_last_hidden_states: bool = False

@classmethod
def from_gen_config(self, gen_config: GenerationConfig):
@@ -70,6 +72,16 @@ def from_gen_config(self, gen_config: GenerationConfig):
max_new_tokens = gen_config.max_new_tokens
response_format = gen_config.response_format

output_logits = gen_config.output_logits
if output_logits:
if (output_logits != 'all' or gen_config.max_new_tokens > 0):
output_logits = None
logger.warning(
'Pytorch Engine only support output_logits="all"'
' with max_new_tokens=0')
if gen_config.output_last_hidden_state is not None:
logger.warning(
'Pytorch Engine does not support output last hidden states.')
if top_p < 0 or top_p > 1.0:
logger.warning('`top_p` has to be a float > 0 and < 1'
f' but is {top_p}')
@@ -110,7 +122,8 @@ def from_gen_config(self, gen_config: GenerationConfig):
response_format=response_format,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
logits_processors=gen_config.logits_processors)
logits_processors=gen_config.logits_processors,
out_logits=(output_logits is not None))


class MessageStatus(enum.Enum):
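messages.py threads the new `out_logits` flag through `SamplingParam.from_gen_config`, accepting `output_logits` only as `"all"` with `max_new_tokens=0`. A reduced sketch of that gating, with field names simplified and logging done through the standard library:

```python
import logging

logger = logging.getLogger('sketch')

def resolve_out_logits(output_logits, max_new_tokens: int) -> bool:
    """Anything other than output_logits='all' with 0 new tokens is dropped."""
    if output_logits:
        if output_logits != 'all' or max_new_tokens > 0:
            logger.warning('PyTorch engine only supports output_logits="all" '
                           'with max_new_tokens=0')
            return False
        return True
    return False

print(resolve_out_logits('all', 0))    # True:  SamplingParam gets out_logits=True
print(resolve_out_logits('all', 16))   # False: dropped with a warning
```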