From 37afbb638671afb095cc24531da632f91bda9cff Mon Sep 17 00:00:00 2001
From: AllentDan
Date: Thu, 5 Dec 2024 11:35:43 +0800
Subject: [PATCH] support bonus token id

---
 lmdeploy/pytorch/engine/engine.py | 24 +++++----
 lmdeploy/pytorch/models/medusa.py | 87 ++++++++++++++++++++++++++-----
 2 files changed, 88 insertions(+), 23 deletions(-)

diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py
index 64381b8748..5bcff97442 100644
--- a/lmdeploy/pytorch/engine/engine.py
+++ b/lmdeploy/pytorch/engine/engine.py
@@ -810,10 +810,12 @@ def __update_inputs(next_token_ids):
             num_ignore_eos = num_ignore_eos - 1
             if 'spec_logits' in output:
                 spec_logits = output['spec_logits']
-                cart_candidates, tree_candidates, medusa_attn_mask, medusa_position_ids, retrieve_indices = self.model_agent.generate_candidates(
-                    spec_logits, next_token_ids)
+                (cart_candidates, tree_candidates, medusa_attn_mask,
+                 medusa_position_ids,
+                 retrieve_indices) = self.model_agent.generate_candidates(
+                     spec_logits, next_token_ids)
                 bs, _, tree_decode_len = tree_candidates.shape
-                spec_inputs = copy.deepcopy(inputs)
+                spec_inputs = inputs
                 spec_inputs.input_ids = tree_candidates.flatten().unsqueeze(0)
                 spec_inputs.history_lengths += spec_inputs.seq_length
                 spec_inputs.seq_length = torch.ones_like(
@@ -826,22 +828,22 @@ def __update_inputs(next_token_ids):
                     swap_out_map=swap_out_map,
                     retrieve_indices=retrieve_indices)
                 # NOTE currently only greedy sampling supported
+                # Besides, we use the bonus token id predicted during
+                # tree decoding, which the original Medusa does not.
                 proposal_len = cart_candidates.shape[-1]
                 greedy_token_ids = logits.argmax(-1)
                 posterior_mask = cart_candidates[..., 1:] == greedy_token_ids[
                     ..., :-1]
                 accept_len, best_idx = torch.cumprod(posterior_mask,
                                                      dim=-1).sum(-1).max(-1)
-                # accept_len = torch.where(accept_len==proposal_len-1, proposal_len, accept_len)
-                next_token_ids = cart_candidates[torch.arange(bs), best_idx]
-                # bonus_token_ids = greedy_token_ids[torch.arange(bs),best_idx,-1:]
-                # next_token_ids = torch.cat([best_candidates, bonus_token_ids], -1)
+                greedy_token_ids = greedy_token_ids[torch.arange(bs), best_idx]
+                next_token_ids = torch.cat(
+                    [next_token_ids[:, None], greedy_token_ids], -1)
                 mask_idx = torch.arange(
-                    proposal_len,
+                    proposal_len + 1,
                     device=next_token_ids.device).expand_as(next_token_ids)
-                next_token_ids[mask_idx > accept_len[:, None]] = -1
-                # next_token_ids = next_token_ids[...,:-1]  # to be removed
-                num_appendable_ids = num_appendable_ids - accept_len - 1
+                next_token_ids[mask_idx > (accept_len[:, None] + 1)] = -1
+                num_appendable_ids = num_appendable_ids - accept_len - 2
 
                 # stopping criteria
                 stopped, num_appendable_ids = self._batch_stopping_criteria(
diff --git a/lmdeploy/pytorch/models/medusa.py b/lmdeploy/pytorch/models/medusa.py
index e7ed5c8aae..bc9d086dc9 100644
--- a/lmdeploy/pytorch/models/medusa.py
+++ b/lmdeploy/pytorch/models/medusa.py
@@ -21,20 +21,71 @@
                     (0, 1, 2), (8, ), (0, 4, 0), (0, 2, 1), (1, 3), (0, 0, 7),
                     (0, 0, 0, 2), (0, 0, 8), (1, 1, 0), (0, 1, 0, 0), (6, 0),
                     (9, ), (0, 1, 3), (0, 0, 0, 3), (1, 0, 2), (0, 5, 0),
-                    (3, 1), (0, 0, 2, 0), (7, 0), (1, 4)]  # noqa
+                    (3, 1), (0, 0, 2, 0), (7, 0), (1, 4)]
+
+vicuna_13b_stage2 = [(0, ), (0, 0), (1, ), (0, 0, 0), (0, 1), (1, 0), (2, ),
+                     (0, 2), (0, 0, 1), (0, 1, 0), (3, ), (0, 3), (2, 0),
+                     (0, 0, 2), (0, 0, 0, 0), (0, 4), (1, 0, 0), (1, 1), (4, ),
+                     (0, 0, 3), (0, 5), (0, 2, 0), (5, ), (3, 0), (0, 1, 1),
+                     (0, 6), (0, 0, 4), (0, 0, 0, 1),
+                     (0, 7), (0, 0, 5), (1, 2), (0, 0, 1, 0), (0, 3, 0),
+                     (1, 0, 1), (4, 0), (0, 0, 6), (0, 8), (2, 0, 0), (0, 9),
+                     (6, ), (7, ), (2, 1), (5, 0), (0, 1, 2), (0, 0, 0, 2),
+                     (8, ), (0, 4, 0), (0, 1, 0, 0), (0, 2, 1), (0, 0, 7),
+                     (1, 1, 0), (1, 3), (0, 0, 2, 0), (9, ), (0, 0, 8),
+                     (0, 5, 0), (0, 0, 0, 3), (0, 0, 9), (0, 1, 3), (1, 0, 2),
+                     (0, 0, 1, 1), (3, 0, 0), (1, 0, 0, 0)]
+
+vicuna_33b_stage2 = [(0, ), (0, 0), (1, ), (0, 1), (0, 0, 0), (1, 0), (2, ),
+                     (0, 2), (0, 0, 1), (0, 3), (3, ),
+                     (0, 1, 0), (2, 0), (0, 4), (4, ), (0, 0, 2), (1, 1),
+                     (1, 0, 0), (0, 5), (5, ), (0, 0, 0, 0), (0, 0, 3), (3, 0),
+                     (0, 2, 0), (0, 6), (0, 1, 1), (6, ), (0, 0, 4), (0, 7),
+                     (7, ), (1, 2), (4, 0), (8, ), (0, 3, 0), (0, 0, 5),
+                     (0, 0, 0, 1), (0, 8), (2, 1), (0, 9), (1, 0, 1),
+                     (2, 0, 0), (0, 0, 6), (5, 0), (0, 0, 1, 0), (1, 3),
+                     (0, 1, 2), (0, 4, 0), (0, 0, 7), (0, 2, 1), (9, ),
+                     (1, 1, 0), (0, 0, 0, 2), (6, 0), (0, 0, 8), (0, 1, 0, 0),
+                     (7, 0), (0, 1, 3), (0, 5, 0), (1, 4), (0, 0, 9), (3, 1),
+                     (1, 0, 2), (2, 2)]
+
+zephyr_stage2 = [(0, ), (0, 0), (1, ), (0, 1), (2, ),
+                 (0, 0, 0), (1, 0), (0, 2), (3, ), (0, 3), (4, ), (2, 0),
+                 (0, 0, 1), (0, 4), (5, ), (0, 5), (0, 1, 0), (1, 1), (6, ),
+                 (0, 0, 2), (3, 0), (0, 6), (7, ), (0, 7), (0, 8), (0, 0, 3),
+                 (1, 0, 0), (0, 9), (0, 2, 0), (1, 2), (4, 0), (8, ), (9, ),
+                 (2, 1), (0, 1, 1), (0, 0, 4), (0, 0, 0, 0), (5, 0), (0, 3, 0),
+                 (1, 3), (0, 0, 5), (0, 0, 6), (6, 0), (2, 0, 0), (1, 0, 1),
+                 (0, 1, 2), (0, 4, 0), (1, 4), (3, 1), (2, 2), (0, 0, 7),
+                 (7, 0), (0, 2, 1), (0, 0, 8), (0, 1, 3), (0, 5, 0), (1, 5),
+                 (0, 0, 9), (1, 1, 0), (0, 0, 0, 1), (0, 0, 1, 0), (4, 1),
+                 (2, 3)]
+mc_sim_7b_63 = [[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3],
+                [0, 3], [4], [0, 4], [2, 0], [0, 5], [0, 0, 1], [5],
+                [0, 6], [6],
+                [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0],
+                [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3],
+                [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], [0, 0, 0, 0], [0, 1, 1],
+                [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8],
+                [0, 0, 9], [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2],
+                [2, 0, 0], [3, 1], [2, 2], [8, 0], [0, 5, 0], [1, 5],
+                [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [0, 0, 0, 1], [1, 6],
+                [0, 7, 0]]
+
 TOPK = 10
 
 
 def pad_path(path, length, pad_value=-2):
     """Pad the given path list with a specific value up to a specified length.
 
-    Parameters:
-    - path (list): The original list that needs padding.
-    - length (int): The desired length of the padded list.
-    - pad_value (optional, default=-2): The value to use for padding.
+    Args:
+        path (list): The original list that needs padding.
+        length (int): The desired length of the padded list.
+        pad_value (optional): The value to use for padding. Defaults to -2.
 
     Returns:
-    - list: A new list based on the original path but padded to the desired length.
+        list: A new list based on the original path but padded to the desired
+            length.
 
     Example:
     >>> pad_path([1,2,3], 5)
@@ -127,6 +178,14 @@ def __init__(self,
         self.medusa_choices = None
         if 'vicuna-7b' in config.base_model_name_or_path:
             self.medusa_choices = vicuna_7b_stage2
+        elif 'vicuna-13b' in config.base_model_name_or_path:
+            self.medusa_choices = vicuna_13b_stage2
+        elif 'vicuna-33b' in config.base_model_name_or_path:
+            self.medusa_choices = vicuna_33b_stage2
+        elif 'zephyr' in config.base_model_name_or_path:
+            self.medusa_choices = zephyr_stage2
+        else:
+            self.medusa_choices = mc_sim_7b_63
         self.generate_medusa_buffers(device=device)
 
     def generate_medusa_buffers(self, device: torch.dtype = None):
@@ -155,7 +214,8 @@ def generate_medusa_buffers(self, device: torch.dtype = None):
                                        key=lambda x: (len(x), x))
         medusa_len = len(sorted_medusa_choices) + 1
 
-        # Initialize depth_counts to keep track of how many choices have a particular depth
+        # Initialize depth_counts to keep track of how many choices have a
+        # particular depth
         depth_counts = []
         prev_depth = 0
         for path in sorted_medusa_choices:
@@ -248,20 +308,22 @@ def generate_candidates(self, medusa_logits: torch.Tensor,
             1. Cartesian candidates derived from the combined original and Medusa logits.
             2. Tree candidates mapped from the Cartesian candidates using tree indices.
         """  # noqa
-        # Greedy decoding: Select the most probable candidate from the original logits.
-        # here we only implement greedy decoding
+        # Greedy decoding: Select the most probable candidate from the
+        # original logits. Here we only implement greedy decoding.
         bs = medusa_logits.shape[0]
         candidates_logit = base_token_id.unsqueeze(-1)
 
         # Extract the TOPK candidates from the medusa logits.
        candidates_medusa_logits = torch.topk(medusa_logits, TOPK,
                                               dim=-1).indices
 
-        # Combine the selected candidate from the original logits with the topk medusa logits.
+        # Combine the selected candidate from the original logits with the
+        # topk medusa logits.
         candidates = torch.cat(
             [candidates_logit, candidates_medusa_logits.view(bs, -1)], dim=-1)
 
-        # Map the combined candidates to the tree indices to get tree candidates.
+        # Map the combined candidates to the tree indices to get tree
+        # candidates.
         tree_candidates = candidates[:, self.tree_indices]
 
         # Extend the tree candidates by appending a zero.
@@ -278,7 +340,8 @@ def generate_candidates(self, medusa_logits: torch.Tensor,
         # Unsqueeze the tree candidates for dimension consistency.
         tree_candidates = tree_candidates.unsqueeze(
             1)  # bs, 1, len(self.medusa_choices)
-        return cart_candidates, tree_candidates, self.medusa_attn_mask, self.medusa_position_ids, self.retrieve_indices
+        return (cart_candidates, tree_candidates, self.medusa_attn_mask,
+                self.medusa_position_ids, self.retrieve_indices)
 
     def support_cuda_graph(
         self,
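--
Reviewer note, not part of the patch: the new verification step in engine.py
is easiest to see with concrete numbers. The following is a minimal,
self-contained sketch of what the patched logic computes under greedy
sampling. All token ids and shapes are made up for illustration, `prev_token`
stands in for the `next_token_ids` produced by the preceding regular decoding
step, and the comparison mask is cast to int explicitly for clarity (the
patch itself relies on PyTorch's implicit bool-to-long promotion in
`torch.cumprod`).

import torch

bs = 1
# cart_candidates[b, c]: the base token followed by the speculated tokens of
# candidate path c, one row per root-to-leaf path of the Medusa tree.
cart_candidates = torch.tensor([[[5, 11, 17, 23],
                                 [5, 11, 29, 31],
                                 [5, 13, 19, 37]]])  # (bs, 3 paths, 4)
# greedy_token_ids[b, c, i]: the base model's argmax after the first i + 1
# tokens of path c, recovered from the tree-decoding logits.
greedy_token_ids = torch.tensor([[[11, 17, 41, 43],
                                  [11, 29, 31, 47],
                                  [13, 53, 59, 61]]])  # same shape

# A speculated token is accepted while it keeps matching the base model.
posterior_mask = (cart_candidates[..., 1:] == greedy_token_ids[..., :-1]).int()
accept_len, best_idx = torch.cumprod(posterior_mask,
                                     dim=-1).sum(-1).max(-1)
assert accept_len.item() == 3 and best_idx.item() == 1  # path [11, 29, 31]

# Along the accepted path, the base model's own predictions reproduce every
# accepted token and supply one extra token for free: the bonus token.
prev_token = torch.tensor([5])  # token emitted by the last regular step
greedy_best = greedy_token_ids[torch.arange(bs), best_idx]  # (bs, 4)
next_token_ids = torch.cat([prev_token[:, None], greedy_best], -1)  # (bs, 5)

# Keep the base token, the accepted tokens and the bonus token; mark the
# rejected tail with -1, mirroring the masking in engine.py.
proposal_len = cart_candidates.shape[-1]
mask_idx = torch.arange(proposal_len + 1).expand_as(next_token_ids)
next_token_ids[mask_idx > (accept_len[:, None] + 1)] = -1
print(next_token_ids)  # tensor([[ 5, 11, 29, 31, 47]]) -> 47 is the bonus

This also shows why num_appendable_ids is now reduced by accept_len + 2
rather than accept_len + 1: a single verifier pass emits the base token, the
accept_len accepted tokens, and the bonus token — here five tokens at once.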