From b9013696b23dde372cccecdbaf69f0c852008844 Mon Sep 17 00:00:00 2001
From: Shomy
Date: Thu, 25 Jul 2024 20:54:15 +0000
Subject: [PATCH] optimizations for process output step

---
 setup_cython.py                              |  1 +
 vllm/core/block_manager_v1.py                | 12 ++++++----
 vllm/engine/llm_engine.py                    | 22 ++++++++++--------
 vllm/engine/output_processor/single_step.py  | 24 ++++++++++++++++++++
 vllm/engine/output_processor/stop_checker.py | 24 +++++++++++---------
 vllm/sequence.py                             |  2 ++
 6 files changed, 60 insertions(+), 25 deletions(-)

diff --git a/setup_cython.py b/setup_cython.py
index 5ea5c39b4e2c5..8ae3e9787749e 100644
--- a/setup_cython.py
+++ b/setup_cython.py
@@ -24,6 +24,7 @@
     "vllm/model_executor/layers/sampler.py",
     "vllm/sampling_params.py",
     "vllm/utils.py",
+    "vllm/block.py",
 ]
 
 
diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py
index 201cba309f6ef..f838446681601 100644
--- a/vllm/core/block_manager_v1.py
+++ b/vllm/core/block_manager_v1.py
@@ -612,12 +612,14 @@ def _free_block_table(self, block_table: BlockTable) -> None:
             self.cpu_allocator.free(block)
 
     def free(self, seq: Sequence) -> None:
-        if seq.seq_id not in self.block_tables:
-            # Already freed or haven't been scheduled yet.
-            return
-        block_table = self.block_tables[seq.seq_id]
+        seq_id = seq.seq_id
+        block_table = self.block_tables.pop(seq_id, [])
+        #if seq.seq_id not in self.block_tables:
+        #    # Already freed or haven't been scheduled yet.
+        #    return
+        #block_table = self.block_tables[seq.seq_id]
         self._free_block_table(block_table)
-        del self.block_tables[seq.seq_id]
+        #del self.block_tables[seq.seq_id]
 
     def free_cross(self, seq_group: SequenceGroup) -> None:
         if seq_group.request_id not in self.cross_block_tables:
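
Note on the block manager change: dict.pop(key, default) folds the
membership test, the lookup, and the delete into a single hash-table
operation, and a sequence that was already freed (or never scheduled)
simply pops the empty default instead of raising KeyError. A minimal
standalone sketch of the pattern (the names below are illustrative,
not the vLLM classes):

    from typing import Dict, List

    BlockTable = List[int]  # stand-in for a table of physical block ids

    block_tables: Dict[int, BlockTable] = {7: [0, 3, 4]}

    def free(seq_id: int) -> None:
        # One dict operation replaces "in" + "[]" + "del".
        block_table = block_tables.pop(seq_id, [])
        for block in block_table:
            pass  # return each block to the allocator here

    free(7)  # frees the table for sequence 7
    free(7)  # repeat call is a harmless no-op: pop() returns []
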
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index affd5ae9d8023..e67f1d71bc809 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -681,12 +681,13 @@ def _process_model_outputs(
         """
         now = time.time()
-
         # Organize outputs by [sequence group][step] instead of
         # [step][sequence group].
         output_by_sequence_group = create_output_by_sequence_group(
             output, num_seq_groups=len(scheduled_seq_groups))
 
+        seq_groups = [scheduled_seq_group.seq_group for scheduled_seq_group in scheduled_seq_groups]
+
         # Update the scheduled sequence groups with the model outputs.
         for scheduled_seq_group, outputs, seq_group_meta in zip(
                 scheduled_seq_groups, output_by_sequence_group,
                 seq_group_metadata_list):
@@ -708,14 +709,17 @@ def _process_model_outputs(
         # Create the outputs.
         request_outputs: List[Union[RequestOutput,
                                     EmbeddingRequestOutput]] = []
-        for scheduled_seq_group in scheduled_seq_groups:
-            seq_group = scheduled_seq_group.seq_group
-            seq_group.maybe_set_first_token_time(now)
-            request_output = RequestOutputFactory.create(seq_group)
-            request_outputs.append(request_output)
-        for seq_group in ignored_seq_groups:
-            request_output = RequestOutputFactory.create(seq_group)
-            request_outputs.append(request_output)
+        [seq_group.maybe_set_first_token_time(now) for seq_group in seq_groups]
+        request_outputs = [RequestOutputFactory.create(seq_group) for seq_group in seq_groups]
+        #for scheduled_seq_group in scheduled_seq_groups:
+        #    seq_group = scheduled_seq_group.seq_group
+        #    seq_group.maybe_set_first_token_time(now)
+        #    request_output = RequestOutputFactory.create(seq_group)
+        #    request_outputs.append(request_output)
+        request_outputs.extend([RequestOutputFactory.create(seq_group) for seq_group in ignored_seq_groups])
+        #for seq_group in ignored_seq_groups:
+        #    request_output = RequestOutputFactory.create(seq_group)
+        #    request_outputs.append(request_output)
         return request_outputs
 
     def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py
index 44de1d7ec5607..fc8885a69652c 100644
--- a/vllm/engine/output_processor/single_step.py
+++ b/vllm/engine/output_processor/single_step.py
@@ -74,6 +74,30 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
                                         outputs: SequenceGroupOutput) -> None:
         # Process samples
         samples = outputs.samples
+        if len(samples) == 1:
+            # A single sample must come from the only running seq in the group.
+            parent_seq = next(iter(seq_group.seqs_dict.values()))
+            child_sample = samples[0]
+            if not seq_group.sampling_params.use_beam_search:
+                # Fast path: extend the running seq in place (no beam search).
+                parent_seq.append_token_id(child_sample.output_token,
+                                           child_sample.logprobs)
+                if self.detokenizer and seq_group.sampling_params.detokenize:
+                    new_char_count = self.detokenizer.decode_sequence_inplace(
+                        parent_seq, seq_group.sampling_params)
+                else:
+                    new_char_count = 0
+
+                stopped = self.stop_checker.maybe_stop_sequence(
+                    parent_seq,
+                    new_char_count,
+                    seq_group.sampling_params,
+                    lora_req=seq_group.lora_request,
+                )
+                #if parent_seq.is_finished():
+                if stopped:
+                    self.scheduler.free_seq(parent_seq)
+                return
         parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
         existing_finished_seqs = seq_group.get_finished_seqs()
         parent_child_dict: Dict[int, List[SequenceOutput]] = {
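
The single_step change above adds a fast path: when a step produced
exactly one sample and beam search is off, the parent sequence is
extended in place, skipping the parent/child bookkeeping, the fork
handling, and the extra sequence scans of the general path. A
simplified sketch of the idea (stand-in dataclasses, not the vLLM
types):

    from dataclasses import dataclass, field
    from typing import Dict, List

    @dataclass
    class Seq:
        seq_id: int
        tokens: List[int] = field(default_factory=list)

    @dataclass
    class SeqGroup:
        seqs_dict: Dict[int, Seq]
        use_beam_search: bool = False

    def process_outputs(group: SeqGroup, samples: List[int]) -> None:
        if len(samples) == 1 and not group.use_beam_search:
            # Fast path: one sample implies a single running sequence,
            # so append directly instead of building a parent->children
            # mapping and diffing finished/running sequence lists.
            parent = next(iter(group.seqs_dict.values()))
            parent.tokens.append(samples[0])
            return
        # General path (beam search, n > 1) would fork children here.
        raise NotImplementedError

    group = SeqGroup(seqs_dict={0: Seq(seq_id=0)})
    process_outputs(group, samples=[42])
    assert group.seqs_dict[0].tokens == [42]
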
diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py
index 96f0d1142611b..eda6ca16ac5af 100644
--- a/vllm/engine/output_processor/stop_checker.py
+++ b/vllm/engine/output_processor/stop_checker.py
@@ -33,7 +33,7 @@ def maybe_stop_sequence(
         new_char_count: int,
         sampling_params: SamplingParams,
         lora_req: Optional[LoRARequest] = None,
-    ) -> None:
+    ) -> bool:
         """Stop the finished sequences.
 
         new_char_count is the number of chars added to the
@@ -42,23 +42,24 @@ def maybe_stop_sequence(
 
         # Check if the minimum number of tokens has been generated yet;
        # skip the stop string/token checks if not
-        if seq.get_output_len() < sampling_params.min_tokens:
-            return
+        outlen = seq.get_output_len()
+        if outlen < sampling_params.min_tokens:
+            return False
 
+        last_token_id = seq.get_last_token_id()
         # Check if the sequence has generated the EOS token.
         if ((not sampling_params.ignore_eos)
-                and seq.get_last_token_id() == seq.eos_token_id):
+                and last_token_id == seq.eos_token_id):
             # Remove the last EOS token unless explicitly specified
             # This prevents unintended exposure of the EOS token
             if new_char_count and (
                     not sampling_params.include_stop_str_in_output):
                 seq.output_text = seq.output_text[:-new_char_count]
             seq.status = SequenceStatus.FINISHED_STOPPED
-            return
+            return True
 
         # Check if a stop token was encountered.
         # This assumes a single token produced per step.
-        last_token_id = seq.get_last_token_id()
         if last_token_id in sampling_params.stop_token_ids:
             if new_char_count and (
                     not sampling_params.include_stop_str_in_output):
@@ -66,7 +67,7 @@ def maybe_stop_sequence(
                 seq.output_text = seq.output_text[:-new_char_count]
             seq.status = SequenceStatus.FINISHED_STOPPED
             seq.stop_reason = last_token_id
-            return
+            return True
 
         # Check if any stop strings are matched.
         stop_str = self._check_stop_strings(seq, new_char_count,
@@ -74,17 +75,18 @@ def maybe_stop_sequence(
         if stop_str is not None:
             seq.status = SequenceStatus.FINISHED_STOPPED
             seq.stop_reason = stop_str
-            return
+            return True
 
         # Check if the sequence has reached max_model_len.
         if seq.get_len() > self._get_max_model_len(lora_req):
             seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
-            return
+            return True
 
         # Check if the sequence has reached max_tokens.
-        if seq.get_output_len() == sampling_params.max_tokens:
+        if outlen == sampling_params.max_tokens:
             seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
-            return
+            return True
+        return False
 
     @staticmethod
     def _check_stop_strings(seq: Sequence, new_char_count: int,
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 9dcef7f804157..522355adc8f64 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -166,6 +166,8 @@ def update_num_computed_tokens(self, num_new_computed_tokens: int):
         self._num_computed_tokens += num_new_computed_tokens
         assert self._num_computed_tokens <= self.get_len(), (
             self._num_computed_tokens, self.get_len())
+        if self._stage == SequenceStage.DECODE:
+            return
         # If all tokens are computed, it means it is in decoding phase.
         if self.get_num_uncomputed_tokens() == 0:
             self._stage = SequenceStage.DECODE
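
The stop_checker and sequence changes share one idea: return or cache
the fact that was just computed instead of re-deriving it. The checker
now reports whether it finished the sequence, so the caller frees it
without a second is_finished() walk, and update_num_computed_tokens()
returns early once a sequence is already in the decode stage. A toy
version of the boolean contract (names are illustrative, not the vLLM
API):

    from enum import Enum
    from typing import List

    class Status(Enum):
        RUNNING = 0
        FINISHED = 1

    class Seq:
        def __init__(self) -> None:
            self.status = Status.RUNNING
            self.output_tokens: List[int] = []

    def maybe_stop(seq: Seq, max_tokens: int) -> bool:
        # Return True exactly when this call transitions the sequence
        # to FINISHED, so callers need no follow-up status check.
        if len(seq.output_tokens) >= max_tokens:
            seq.status = Status.FINISHED
            return True
        return False

    seq = Seq()
    seq.output_tokens.append(11)
    if maybe_stop(seq, max_tokens=1):
        # Free the sequence's resources immediately; no re-scan of the
        # sequence group is needed to discover that it stopped.
        print("sequence finished, freeing blocks")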