optimizations for process output step #104

Open · wants to merge 1 commit into base: MLPerf_4.1
1 change: 1 addition & 0 deletions setup_cython.py
@@ -24,6 +24,7 @@
"vllm/model_executor/layers/sampler.py",
"vllm/sampling_params.py",
"vllm/utils.py",
"vllm/block.py",
]


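The single change here adds `vllm/block.py` to the list of modules compiled with Cython. A minimal sketch of how such a module list typically feeds `cythonize()` in a build script is shown below; only the list entries come from the diff, while the surrounding `setup()` scaffolding is assumed rather than copied from the actual setup_cython.py.

```python
# Minimal sketch of a Cython build script (assumed layout, not the actual
# setup_cython.py); the module list is the one shown in the diff above.
from setuptools import setup
from Cython.Build import cythonize

compiled_modules = [
    "vllm/model_executor/layers/sampler.py",
    "vllm/sampling_params.py",
    "vllm/utils.py",
    "vllm/block.py",  # newly added to the compiled set by this PR
]

setup(
    name="vllm-cython-modules",
    ext_modules=cythonize(compiled_modules, language_level=3),
)
```

Such a script is typically invoked with `python setup_cython.py build_ext --inplace` to produce compiled extension modules next to the Python sources.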
12 changes: 7 additions & 5 deletions vllm/core/block_manager_v1.py
@@ -612,12 +612,14 @@ def _free_block_table(self, block_table: BlockTable) -> None:
self.cpu_allocator.free(block)

def free(self, seq: Sequence) -> None:
if seq.seq_id not in self.block_tables:
# Already freed or haven't been scheduled yet.
return
block_table = self.block_tables[seq.seq_id]
seq_id = seq.seq_id
block_table = self.block_tables.pop(seq_id,[])
#if seq.seq_id not in self.block_tables:
# # Already freed or haven't been scheduled yet.
# return
#block_table = self.block_tables[seq.seq_id]
self._free_block_table(block_table)
del self.block_tables[seq.seq_id]
#del self.block_tables[seq.seq_id]

def free_cross(self, seq_group: SequenceGroup) -> None:
if seq_group.request_id not in self.cross_block_tables:
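The rewritten `free()` replaces a membership test, a lookup, and a `del` (three hash-table operations) with a single `dict.pop()` call that supplies an empty default when the sequence was already freed or never scheduled. A small self-contained sketch of that pattern, using illustrative names rather than the real block-manager types:

```python
# Sketch of the dict.pop() pattern behind the new free(); block_tables and
# free_fn are illustrative names, not the real BlockSpaceManager API.
from typing import Callable, Dict, List

def free_blocks(block_tables: Dict[int, List[str]], seq_id: int,
                free_fn: Callable[[List[str]], None]) -> None:
    # pop() returns the default ([]) when seq_id is absent, so one dictionary
    # operation replaces the old "if key in d: ... d[key] ... del d[key]" trio.
    block_table = block_tables.pop(seq_id, [])
    free_fn(block_table)

tables = {7: ["block-a", "block-b"]}
free_blocks(tables, 7, lambda bt: print("freeing", len(bt), "blocks"))  # 2 blocks
free_blocks(tables, 7, lambda bt: print("freeing", len(bt), "blocks"))  # 0 blocks: already freed
```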
22 changes: 13 additions & 9 deletions vllm/engine/llm_engine.py
@@ -681,12 +681,13 @@ def _process_model_outputs(
"""

now = time.time()

# Organize outputs by [sequence group][step] instead of
# [step][sequence group].
output_by_sequence_group = create_output_by_sequence_group(
output, num_seq_groups=len(scheduled_seq_groups))

seq_groups = [scheduled_seq_group.seq_group for scheduled_seq_group in scheduled_seq_groups]

# Update the scheduled sequence groups with the model outputs.
for scheduled_seq_group, outputs, seq_group_meta in zip(
scheduled_seq_groups, output_by_sequence_group,
@@ -708,14 +709,17 @@
# Create the outputs.
request_outputs: List[Union[RequestOutput,
EmbeddingRequestOutput]] = []
for scheduled_seq_group in scheduled_seq_groups:
seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now)
request_output = RequestOutputFactory.create(seq_group)
request_outputs.append(request_output)
for seq_group in ignored_seq_groups:
request_output = RequestOutputFactory.create(seq_group)
request_outputs.append(request_output)
[seq_group.maybe_set_first_token_time(now) for seq_group in seq_groups]
request_outputs = [RequestOutputFactory.create(seq_group) for seq_group in seq_groups]
#for scheduled_seq_group in scheduled_seq_groups:
# seq_group = scheduled_seq_group.seq_group
# seq_group.maybe_set_first_token_time(now)
# request_output = RequestOutputFactory.create(seq_group)
# request_outputs.append(request_output)
request_outputs.extend([RequestOutputFactory.create(seq_group) for seq_group in ignored_seq_groups])
#for seq_group in ignored_seq_groups:
# request_output = RequestOutputFactory.create(seq_group)
# request_outputs.append(request_output)
return request_outputs

def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
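This hunk hoists `scheduled_seq_group.seq_group` into a `seq_groups` list once, then builds `request_outputs` with a list comprehension plus `extend()` instead of per-item `append()` calls. A small sketch of the loop-versus-comprehension pattern, with `create` standing in for `RequestOutputFactory.create`:

```python
# Sketch of the append-loop vs. comprehension pattern used in
# _process_model_outputs; "create" stands in for RequestOutputFactory.create.
def build_loop(seq_groups, ignored, create):
    outputs = []
    for g in seq_groups:
        outputs.append(create(g))   # repeated attribute lookup + call per item
    for g in ignored:
        outputs.append(create(g))
    return outputs

def build_comprehension(seq_groups, ignored, create):
    # The comprehension avoids the per-item outputs.append lookup, and
    # extend() adds the ignored groups in one call.
    outputs = [create(g) for g in seq_groups]
    outputs.extend(create(g) for g in ignored)
    return outputs

assert build_loop([1, 2], [3], str) == build_comprehension([1, 2], [3], str)
```

The comprehension that calls `maybe_set_first_token_time(now)` purely for its side effect mirrors the PR as written; a plain for-loop over `seq_groups` would be equivalent there.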
24 changes: 24 additions & 0 deletions vllm/engine/output_processor/single_step.py
@@ -74,6 +74,30 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
outputs: SequenceGroupOutput) -> None:
# Process samples
samples = outputs.samples
if len(samples)==1:
#if there's only 1 sample, it has to be from 1 running seq in seq group
parent_seq = next(iter(seq_group.seqs_dict.values()))
child_sample = samples[0]
if not seq_group.sampling_params.use_beam_search:
#fastpath
parent_seq.append_token_id(child_sample.output_token,
child_sample.logprobs)
if self.detokenizer and seq_group.sampling_params.detokenize:
new_char_count = self.detokenizer.decode_sequence_inplace(
parent_seq, seq_group.sampling_params)
else:
new_char_count = 0

stopped = self.stop_checker.maybe_stop_sequence(
parent_seq,
new_char_count,
seq_group.sampling_params,
lora_req=seq_group.lora_request,
)
#if parent_seq.is_finished():
if stopped:
self.scheduler.free_seq(parent_seq)
return
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
existing_finished_seqs = seq_group.get_finished_seqs()
parent_child_dict: Dict[int, List[SequenceOutput]] = {
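The added block is a fast path for the common case: exactly one sample, which implies exactly one running sequence in the group, and no beam search. It appends the token, detokenizes in place, checks the stop conditions, and returns without building the parent/child dictionaries used by the general path. A condensed sketch of that control flow follows; every class and method name is a stand-in for the corresponding vLLM object, not a verified API:

```python
# Condensed sketch of the single-sample fast path; all names are stand-ins
# for the real vLLM objects (seq_group, detokenizer, stop_checker, scheduler).
def process_single_step(seq_group, samples, detokenizer, stop_checker, scheduler):
    if len(samples) == 1 and not seq_group.sampling_params.use_beam_search:
        # One sample means one running sequence, so skip the general
        # parent/child bookkeeping entirely.
        seq = next(iter(seq_group.seqs_dict.values()))
        sample = samples[0]
        seq.append_token_id(sample.output_token, sample.logprobs)
        if detokenizer and seq_group.sampling_params.detokenize:
            new_chars = detokenizer.decode_sequence_inplace(
                seq, seq_group.sampling_params)
        else:
            new_chars = 0
        # Relies on maybe_stop_sequence() now returning a bool
        # (see the stop_checker.py change below).
        if stop_checker.maybe_stop_sequence(seq, new_chars,
                                            seq_group.sampling_params):
            scheduler.free_seq(seq)
        return
    # Otherwise fall through to the original multi-sample / beam-search path.
```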
24 changes: 13 additions & 11 deletions vllm/engine/output_processor/stop_checker.py
@@ -33,7 +33,7 @@ def maybe_stop_sequence(
new_char_count: int,
sampling_params: SamplingParams,
lora_req: Optional[LoRARequest] = None,
) -> None:
) -> bool:
"""Stop the finished sequences.

new_char_count is the number of chars added to the
@@ -42,49 +42,51 @@

# Check if the minimum number of tokens has been generated yet;
# skip the stop string/token checks if not
if seq.get_output_len() < sampling_params.min_tokens:
return
outlen = seq.get_output_len()
if outlen < sampling_params.min_tokens:
return False

last_token_id = seq.get_last_token_id()
# Check if the sequence has generated the EOS token.
if ((not sampling_params.ignore_eos)
and seq.get_last_token_id() == seq.eos_token_id):
and last_token_id == seq.eos_token_id):
# Remove the last EOS token unless explicitly specified
# This prevents unintended exposure of the EOS token
if new_char_count and (
not sampling_params.include_stop_str_in_output):
seq.output_text = seq.output_text[:-new_char_count]
seq.status = SequenceStatus.FINISHED_STOPPED
return
return True

# Check if a stop token was encountered.
# This assumes a single token produced per step.
last_token_id = seq.get_last_token_id()
if last_token_id in sampling_params.stop_token_ids:
if new_char_count and (
not sampling_params.include_stop_str_in_output):
# Remove last token
seq.output_text = seq.output_text[:-new_char_count]
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = last_token_id
return
return True

# Check if any stop strings are matched.
stop_str = self._check_stop_strings(seq, new_char_count,
sampling_params)
if stop_str is not None:
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = stop_str
return
return True

# Check if the sequence has reached max_model_len.
if seq.get_len() > self._get_max_model_len(lora_req):
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
return True

# Check if the sequence has reached max_tokens.
if seq.get_output_len() == sampling_params.max_tokens:
if outlen == sampling_params.max_tokens:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
return True
return False

@staticmethod
def _check_stop_strings(seq: Sequence, new_char_count: int,
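`maybe_stop_sequence()` now returns `True` when it marked the sequence finished and `False` otherwise, so callers such as the fast path above no longer need to re-check `seq.is_finished()`; it also caches `get_output_len()` and `get_last_token_id()` in locals instead of calling them twice. A minimal self-contained sketch of that contract, using simplified stand-ins for the real `Sequence` and `SamplingParams`:

```python
# Minimal sketch of the bool-returning stop check; Seq and Params are
# simplified stand-ins, not the real vLLM Sequence / SamplingParams.
from dataclasses import dataclass, field
from typing import List

@dataclass
class Params:
    min_tokens: int = 0
    max_tokens: int = 16
    ignore_eos: bool = False
    stop_token_ids: List[int] = field(default_factory=list)

@dataclass
class Seq:
    output_tokens: List[int]
    eos_token_id: int = 2
    status: str = "RUNNING"

def maybe_stop(seq: Seq, params: Params) -> bool:
    outlen = len(seq.output_tokens)        # fetched once, reused below
    if outlen < params.min_tokens:
        return False
    last = seq.output_tokens[-1]           # fetched once, reused below
    if not params.ignore_eos and last == seq.eos_token_id:
        seq.status = "FINISHED_STOPPED"
        return True
    if last in params.stop_token_ids:
        seq.status = "FINISHED_STOPPED"
        return True
    if outlen >= params.max_tokens:
        seq.status = "FINISHED_LENGTH_CAPPED"
        return True
    return False

assert maybe_stop(Seq([5, 2]), Params()) is True    # hit EOS -> finished
assert maybe_stop(Seq([5, 7]), Params()) is False   # still running
```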
2 changes: 2 additions & 0 deletions vllm/sequence.py
@@ -166,6 +166,8 @@ def update_num_computed_tokens(self, num_new_computed_tokens: int):
self._num_computed_tokens += num_new_computed_tokens
assert self._num_computed_tokens <= self.get_len(), (
self._num_computed_tokens, self.get_len())
if self._stage == SequenceStage.DECODE:
return
# If all tokens are computed, it means it is in decoding phase.
if self.get_num_uncomputed_tokens() == 0:
self._stage = SequenceStage.DECODE
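The early return skips the `get_num_uncomputed_tokens()` recount once a sequence is already decoding, which is the steady state for every token generated after prefill; the prefill-to-decode transition below it is unchanged. A simplified self-contained sketch of that stage check; the counters are stand-ins for the real `SequenceData` bookkeeping:

```python
# Simplified sketch of the prefill -> decode stage check with the early
# return; this mirrors the shape of the change, not vllm.sequence itself.
from enum import Enum

class SequenceStage(Enum):
    PREFILL = 0
    DECODE = 1

class SeqData:
    def __init__(self, prompt_len: int):
        self.prompt_len = prompt_len
        self.num_computed = 0
        self.stage = SequenceStage.PREFILL

    def update_num_computed_tokens(self, num_new: int) -> None:
        self.num_computed += num_new
        if self.stage == SequenceStage.DECODE:
            # Already decoding: skip the uncomputed-token recount that the
            # prefill path below needs.
            return
        if self.prompt_len - self.num_computed <= 0:
            self.stage = SequenceStage.DECODE

d = SeqData(prompt_len=8)
d.update_num_computed_tokens(8)   # prefill done -> switches to DECODE
d.update_num_computed_tokens(1)   # decode steps take the early return
print(d.stage)                    # SequenceStage.DECODE
```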