Skip to content

Commit

Permalink
tp
Browse files Browse the repository at this point in the history
  • Loading branch information
AllentDan committed Dec 5, 2024
1 parent 37afbb6 commit dcc6e85
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions lmdeploy/pytorch/engine/model_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,6 +875,7 @@ async def score_proposal(self, inputs: ModelInputs, swap_in_map: SwapMap,
hidden_states = hidden_states.reshape(
[-1, num_speculative_tokens + 1, hidden_states.shape[-1]])
logits = self.get_logits(hidden_states)
self.stream.synchronize()
return logits

def forward(self, inputs: ModelInputs, swap_in_map: SwapMap,
Expand Down Expand Up @@ -908,6 +909,36 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap,
self.stream.synchronize)
return output

async def tree_decoding(self, inputs: ModelInputs, swap_in_map: SwapMap,
swap_out_map: SwapMap,
retrieve_indices: torch.Tensor):
bs = inputs.history_lengths.shape[0]
inputs.medusa_position_ids = inputs.medusa_position_ids.repeat(
inputs.history_lengths.shape[0], 1)
inputs.medusa_position_ids = inputs.medusa_position_ids.to(
inputs.history_lengths.device) + inputs.history_lengths[:, None]
with get_dist_manager().context(self._dist_ctx):
self.mp_bar.wait()
rank = 0
_broadcast_inputs(rank, [inputs, swap_in_map, swap_out_map],
self.stream)
cache_swapping(self.cache_engine,
swap_in_map=swap_in_map,
swap_out_map=swap_out_map)
spec_outputs = model_forward(
self.patched_model,
inputs,
self.cache_engine,
world_size=1,
stream=self.stream,
)
hidden_states = spec_outputs['hidden_states']
hidden_states = hidden_states.reshape(bs, -1,
hidden_states.shape[-1])
logits = self.get_logits(hidden_states)[:, retrieve_indices]
self.stream.synchronize()
return logits

def generate_candidates(self, draft_logits: torch.Tensor,
base_token_id: torch.Tensor):
return self.speculative_model.generate_candidates(
Expand Down

0 comments on commit dcc6e85

Please sign in to comment.