Commit

support bonus token id
AllentDan committed Dec 5, 2024
1 parent 9930e61 commit 37afbb6
Showing 2 changed files with 88 additions and 23 deletions.
24 changes: 13 additions & 11 deletions lmdeploy/pytorch/engine/engine.py
@@ -810,10 +810,12 @@ def __update_inputs(next_token_ids):
             num_ignore_eos = num_ignore_eos - 1
             if 'spec_logits' in output:
                 spec_logits = output['spec_logits']
-                cart_candidates, tree_candidates, medusa_attn_mask, medusa_position_ids, retrieve_indices = self.model_agent.generate_candidates(
-                    spec_logits, next_token_ids)
+                (cart_candidates, tree_candidates, medusa_attn_mask,
+                 medusa_position_ids,
+                 retrieve_indices) = self.model_agent.generate_candidates(
+                     spec_logits, next_token_ids)
                 bs, _, tree_decode_len = tree_candidates.shape
-                spec_inputs = copy.deepcopy(inputs)
+                spec_inputs = inputs
                 spec_inputs.input_ids = tree_candidates.flatten().unsqueeze(0)
                 spec_inputs.history_lengths += spec_inputs.seq_length
                 spec_inputs.seq_length = torch.ones_like(
@@ -826,22 +828,22 @@ def __update_inputs(next_token_ids):
                     swap_out_map=swap_out_map,
                     retrieve_indices=retrieve_indices)
                 # NOTE currently only greedy sampling supported
+                # besides, we used the bonus token id predicted during
+                # tree decoding while original Medusa did not
                 proposal_len = cart_candidates.shape[-1]
                 greedy_token_ids = logits.argmax(-1)
                 posterior_mask = cart_candidates[..., 1:] == greedy_token_ids[
                     ..., :-1]
                 accept_len, best_idx = torch.cumprod(posterior_mask,
                                                      dim=-1).sum(-1).max(-1)
-                # accept_len = torch.where(accept_len==proposal_len-1, proposal_len, accept_len)
-                next_token_ids = cart_candidates[torch.arange(bs), best_idx]
-                # bonus_token_ids = greedy_token_ids[torch.arange(bs),best_idx,-1:]
-                # next_token_ids = torch.cat([best_candidates, bonus_token_ids], -1)
+                greedy_token_ids = greedy_token_ids[torch.arange(bs), best_idx]
+                next_token_ids = torch.cat(
+                    [next_token_ids[:, None], greedy_token_ids], -1)
                 mask_idx = torch.arange(
-                    proposal_len,
+                    proposal_len + 1,
                     device=next_token_ids.device).expand_as(next_token_ids)
-                next_token_ids[mask_idx > accept_len[:, None]] = -1
-                # next_token_ids = next_token_ids[...,:-1] # to be removed
-                num_appendable_ids = num_appendable_ids - accept_len - 1
+                next_token_ids[mask_idx > (accept_len[:, None] + 1)] = -1
+                num_appendable_ids = num_appendable_ids - accept_len - 2

             # stopping criteria
             stopped, num_appendable_ids = self._batch_stopping_criteria(
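
Note for reviewers: to make the acceptance rule above easier to follow, here is a minimal, self-contained sketch of the greedy verification plus bonus-token step. All names and shapes are illustrative assumptions based on this diff, not the engine's actual API: cart_candidates is (bs, n_paths, proposal_len), path_logits are the verifier logits already gathered per candidate path, and base_token is the last token from the normal decode step.

import torch


def accept_with_bonus(cart_candidates, path_logits, base_token):
    """Standalone sketch of greedy verification with a bonus token."""
    bs, n_paths, proposal_len = cart_candidates.shape
    # Verifier's greedy prediction at every node of every candidate path.
    greedy = path_logits.argmax(-1)  # (bs, n_paths, proposal_len)
    # Draft token i+1 is accepted only if it matches the verifier's
    # prediction at position i; cumprod zeroes everything after the first
    # mismatch, so the sum counts leading accepted tokens per path.
    posterior_mask = (cart_candidates[..., 1:] == greedy[..., :-1]).int()
    accept_len, best_idx = torch.cumprod(posterior_mask, -1).sum(-1).max(-1)
    # Along the best path, the verifier's tokens equal the accepted draft
    # tokens shifted by one, and one step further they supply the bonus
    # token that the original Medusa recipe discarded.
    best_greedy = greedy[torch.arange(bs), best_idx]  # (bs, proposal_len)
    out = torch.cat([base_token[:, None], best_greedy], -1)
    # Invalidate everything past base + accepted + bonus with -1.
    pos = torch.arange(proposal_len + 1, device=out.device).expand_as(out)
    out[pos > accept_len[:, None] + 1] = -1
    return out, accept_len + 2  # tokens emitted this engine step

For a row where the first draft token matches and the second does not, accept_len is 1 and three tokens survive: the base token, one accepted draft token, and the bonus token, which is why the appendable budget now shrinks by accept_len + 2.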
87 changes: 75 additions & 12 deletions lmdeploy/pytorch/models/medusa.py
@@ -21,20 +21,71 @@
                      (0, 1, 2), (8, ), (0, 4, 0), (0, 2, 1), (1, 3), (0, 0, 7),
                      (0, 0, 0, 2), (0, 0, 8), (1, 1, 0), (0, 1, 0, 0), (6, 0),
                      (9, ), (0, 1, 3), (0, 0, 0, 3), (1, 0, 2), (0, 5, 0),
-                     (3, 1), (0, 0, 2, 0), (7, 0), (1, 4)]  # noqa
+                     (3, 1), (0, 0, 2, 0), (7, 0), (1, 4)]
+
+vicuna_13b_stage2 = [(0, ), (0, 0), (1, ), (0, 0, 0), (0, 1), (1, 0), (2, ),
+                     (0, 2), (0, 0, 1), (0, 1, 0), (3, ), (0, 3), (2, 0),
+                     (0, 0, 2), (0, 0, 0, 0), (0, 4), (1, 0, 0), (1, 1), (4, ),
+                     (0, 0, 3), (0, 5), (0, 2, 0), (5, ), (3, 0), (0, 1, 1),
+                     (0, 6), (0, 0, 4), (0, 0, 0, 1),
+                     (0, 7), (0, 0, 5), (1, 2), (0, 0, 1, 0), (0, 3, 0),
+                     (1, 0, 1), (4, 0), (0, 0, 6), (0, 8), (2, 0, 0), (0, 9),
+                     (6, ), (7, ), (2, 1), (5, 0), (0, 1, 2), (0, 0, 0, 2),
+                     (8, ), (0, 4, 0), (0, 1, 0, 0), (0, 2, 1), (0, 0, 7),
+                     (1, 1, 0), (1, 3), (0, 0, 2, 0), (9, ), (0, 0, 8),
+                     (0, 5, 0), (0, 0, 0, 3), (0, 0, 9), (0, 1, 3), (1, 0, 2),
+                     (0, 0, 1, 1), (3, 0, 0), (1, 0, 0, 0)]
+
+vicuna_33b_stage2 = [(0, ), (0, 0), (1, ), (0, 1), (0, 0, 0), (1, 0), (2, ),
+                     (0, 2), (0, 0, 1), (0, 3), (3, ),
+                     (0, 1, 0), (2, 0), (0, 4), (4, ), (0, 0, 2), (1, 1),
+                     (1, 0, 0), (0, 5), (5, ), (0, 0, 0, 0), (0, 0, 3), (3, 0),
+                     (0, 2, 0), (0, 6), (0, 1, 1), (6, ), (0, 0, 4), (0, 7),
+                     (7, ), (1, 2), (4, 0), (8, ), (0, 3, 0), (0, 0, 5),
+                     (0, 0, 0, 1), (0, 8), (2, 1), (0, 9), (1, 0, 1),
+                     (2, 0, 0), (0, 0, 6), (5, 0), (0, 0, 1, 0), (1, 3),
+                     (0, 1, 2), (0, 4, 0), (0, 0, 7), (0, 2, 1), (9, ),
+                     (1, 1, 0), (0, 0, 0, 2), (6, 0), (0, 0, 8), (0, 1, 0, 0),
+                     (7, 0), (0, 1, 3), (0, 5, 0), (1, 4), (0, 0, 9), (3, 1),
+                     (1, 0, 2), (2, 2)]
+
+zephyr_stage2 = [(0, ), (0, 0), (1, ), (0, 1), (2, ),
+                 (0, 0, 0), (1, 0), (0, 2), (3, ), (0, 3), (4, ), (2, 0),
+                 (0, 0, 1), (0, 4), (5, ), (0, 5), (0, 1, 0), (1, 1), (6, ),
+                 (0, 0, 2), (3, 0), (0, 6), (7, ), (0, 7), (0, 8), (0, 0, 3),
+                 (1, 0, 0), (0, 9), (0, 2, 0), (1, 2), (4, 0), (8, ), (9, ),
+                 (2, 1), (0, 1, 1), (0, 0, 4), (0, 0, 0, 0), (5, 0), (0, 3, 0),
+                 (1, 3), (0, 0, 5), (0, 0, 6), (6, 0), (2, 0, 0), (1, 0, 1),
+                 (0, 1, 2), (0, 4, 0), (1, 4), (3, 1), (2, 2), (0, 0, 7),
+                 (7, 0), (0, 2, 1), (0, 0, 8), (0, 1, 3), (0, 5, 0), (1, 5),
+                 (0, 0, 9), (1, 1, 0), (0, 0, 0, 1), (0, 0, 1, 0), (4, 1),
+                 (2, 3)]
+
+mc_sim_7b_63 = [[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3],
+                [0, 3], [4], [0, 4], [2, 0], [0, 5], [0, 0, 1], [5], [0, 6],
+                [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0],
+                [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3],
+                [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], [0, 0, 0, 0], [0, 1, 1],
+                [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8],
+                [0, 0, 9], [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2],
+                [2, 0, 0], [3, 1], [2, 2], [8, 0], [0, 5, 0], [1, 5],
+                [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [0, 0, 0, 1], [1, 6],
+                [0, 7, 0]]
 
 TOPK = 10
 
 
 def pad_path(path, length, pad_value=-2):
     """Pad the given path list with a specific value up to a specified length.
-    Parameters:
-    - path (list): The original list that needs padding.
-    - length (int): The desired length of the padded list.
-    - pad_value (optional, default=-2): The value to use for padding.
+    Args:
+        path (list): The original list that needs padding.
+        length (int): The desired length of the padded list.
+        pad_value (optional, default=-2): The value to use for padding.
     Returns:
-    - list: A new list based on the original path but padded to the desired length.
+        list: A new list based on the original path but padded to the desired
+            length.
     Example:
     >>> pad_path([1,2,3], 5)
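
Note for reviewers: the pad_path body is collapsed in this view. As a sketch consistent with the docstring and doctest above (an assumption, not a copy of the file), padding simply extends the list:

def pad_path(path, length, pad_value=-2):
    # Sketch only: pad_path([1, 2, 3], 5) -> [1, 2, 3, -2, -2]
    return path + [pad_value] * (length - len(path))

Padding the variable-length choice tuples to a common depth with a sentinel no top-K index can take (-2) is what allows the buffers built below to be dense rectangular tensors.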
@@ -127,6 +178,14 @@ def __init__(self,
         self.medusa_choices = None
         if 'vicuna-7b' in config.base_model_name_or_path:
             self.medusa_choices = vicuna_7b_stage2
+        elif 'vicuna-13b' in config.base_model_name_or_path:
+            self.medusa_choices = vicuna_13b_stage2
+        elif 'vicuna-33b' in config.base_model_name_or_path:
+            self.medusa_choices = vicuna_33b_stage2
+        elif 'zephyr' in config.base_model_name_or_path:
+            self.medusa_choices = zephyr_stage2
+        else:
+            self.medusa_choices = mc_sim_7b_63
         self.generate_medusa_buffers(device=device)
 
     def generate_medusa_buffers(self, device: torch.dtype = None):
@@ -155,7 +214,8 @@ def generate_medusa_buffers(self, device: torch.dtype = None):
                                         key=lambda x: (len(x), x))
         medusa_len = len(sorted_medusa_choices) + 1
 
-        # Initialize depth_counts to keep track of how many choices have a particular depth
+        # Initialize depth_counts to keep track of how many choices have a
+        # particular depth
         depth_counts = []
         prev_depth = 0
         for path in sorted_medusa_choices:
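
Note for reviewers: the body of this loop is collapsed here. As an illustration of what depth_counts holds (a sketch under the assumption that a choice's depth is its tuple length, as the sorting key suggests):

# Sketch only: count how many choices sit at each tree depth.
choices = [(0, ), (1, ), (0, 0), (0, 0, 0)]
sorted_choices = sorted(choices, key=lambda x: (len(x), x))
depth_counts = []
prev_depth = 0
for path in sorted_choices:
    depth = len(path)
    if depth != prev_depth:
        depth_counts.extend([0] * (depth - prev_depth))
    depth_counts[depth - 1] += 1
    prev_depth = depth
print(depth_counts)  # [2, 1, 1]: two depth-1 nodes, one depth-2, one depth-3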
@@ -248,20 +308,22 @@ def generate_candidates(self, medusa_logits: torch.Tensor,
         1. Cartesian candidates derived from the combined original and Medusa logits.
         2. Tree candidates mapped from the Cartesian candidates using tree indices.
         """ # noqa
-        # Greedy decoding: Select the most probable candidate from the original logits.
-        # here we only implement greedy decoding
+        # Greedy decoding: Select the most probable candidate from the original
+        # logits. here we only implement greedy decoding
         bs = medusa_logits.shape[0]
         candidates_logit = base_token_id.unsqueeze(-1)
         # Extract the TOPK candidates from the medusa logits.
         candidates_medusa_logits = torch.topk(medusa_logits, TOPK,
                                               dim=-1).indices
 
-        # Combine the selected candidate from the original logits with the topk medusa logits.
+        # Combine the selected candidate from the original logits with the
+        # topk medusa logits.
         candidates = torch.cat(
             [candidates_logit,
              candidates_medusa_logits.view(bs, -1)], dim=-1)
 
-        # Map the combined candidates to the tree indices to get tree candidates.
+        # Map the combined candidates to the tree indices to get tree
+        # candidates.
         tree_candidates = candidates[:, self.tree_indices]
 
         # Extend the tree candidates by appending a zero.
@@ -278,7 +340,8 @@ def generate_candidates(self, medusa_logits: torch.Tensor,
         # Unsqueeze the tree candidates for dimension consistency.
         tree_candidates = tree_candidates.unsqueeze(
             1)  # bs, 1, len(self.medusa_choices)
-        return cart_candidates, tree_candidates, self.medusa_attn_mask, self.medusa_position_ids, self.retrieve_indices
+        return (cart_candidates, tree_candidates, self.medusa_attn_mask,
+                self.medusa_position_ids, self.retrieve_indices)
 
     def support_cuda_graph(
         self,
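
Note for reviewers: to see the shapes flowing into the candidate pool, here is a hypothetical call; the sizes (4 Medusa heads, vocabulary 32000) are assumptions for illustration, not values from the repository.

import torch

TOPK = 10  # fixed per-head top-k, as in the module above
bs, num_heads, vocab_size = 2, 4, 32000
medusa_logits = torch.randn(bs, num_heads, vocab_size)
base_token_id = torch.randint(vocab_size, (bs, ))

# Flat candidate pool: the base token followed by every head's top-10 ids.
candidates_logit = base_token_id.unsqueeze(-1)              # (bs, 1)
topk_ids = torch.topk(medusa_logits, TOPK, dim=-1).indices  # (bs, num_heads, TOPK)
candidates = torch.cat([candidates_logit,
                        topk_ids.view(bs, -1)], dim=-1)     # (bs, 1 + num_heads * TOPK)

# tree_indices (precomputed from medusa_choices) then gathers one pool entry
# per tree node, so a choice such as (1, 3) names head 0's 2nd-best token
# followed by head 1's 4th-best token along one root-to-leaf path.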
