* add BLEU score evaluator
* refine
* add
* add T5LayerNorm
* refine
* update tokenizer
* refine format
* refine
* add functions for generation
* test greedy search
* support model parallel
* test t5
* add logits warper
* add logits warper
* add multinomial_sample
* add beam_search utils
* reformat
* add _reorder_cache function
* add prepare_inputs_for_generation
* restructure MT5 for inference
* reformat
* add beam_search
* update text_generation.py
* update generation
* add some logits processors and generation functions
* reformat
* restructure mt5
* add mt5 model test
* refine stopping criteria
* add utils function
* reformat
* refine utils function
* update t5model, update t5loader, add t5 loader test
* refine
* replace multinomial
* replace nansum
* reformat
* update multinomial sampling
* update beam search
* add copyright
* reformat
* refine
* reformat
* add comments
* add comments
* finish MT5 single-node single-GPU generate
* simplify cfg
* reformat
* create generator dir
* update
* finish MT5 pipeline generate
* finish MT5 tensor + pipeline parallel generate
* fix multi-sentence input inference
* reformat
* refine
* refine mt5_loader

Co-authored-by: CPFLAME <[email protected]>
1 parent e511855 · commit 6de6064
Showing 17 changed files with 2,275 additions and 194 deletions.
@@ -0,0 +1,373 @@
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors and
# The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from abc import ABC, abstractmethod
from collections import UserDict
from typing import Optional, Tuple

import oneflow as flow

from libai.utils import distributed as dist


class BeamScorer(ABC):
    @abstractmethod
    def process(
        self,
        input_ids: flow.Tensor,
        next_scores: flow.Tensor,
        next_tokens: flow.Tensor,
        next_indices: flow.Tensor,
        **kwargs,
    ):
        raise NotImplementedError("This is an abstract method.")


class BeamHypotheses:
    def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool):
        """
        Initialize n-best list of hypotheses.
        """
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.num_beams = num_beams
        self.beams = []
        self.worst_score = 1e9

    def __len__(self) -> int:
        """
        Number of hypotheses in the list.
        """
        return len(self.beams)

    def add(
        self, hyp: flow.Tensor, sum_logprobs: float, beam_indices: Optional[flow.Tensor] = None
    ):
        """
        Add a new hypothesis to the list.
        """
        score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
        if len(self) < self.num_beams or score > self.worst_score:
            self.beams.append((score, hyp, beam_indices))
            if len(self) > self.num_beams:
                sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)])
                del self.beams[sorted_next_scores[0][1]]
                self.worst_score = sorted_next_scores[1][0]
            else:
                self.worst_score = min(score, self.worst_score)

    def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
        """
        If there are enough hypotheses and none of the hypotheses being generated
        can become better than the worst one in the heap, we are done with this sentence.
        """
        if len(self) < self.num_beams:
            return False
        elif self.early_stopping:
            return True
        else:
            cur_score = best_sum_logprobs / cur_len ** self.length_penalty
            return self.worst_score >= cur_score


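# A minimal sketch of how `BeamHypotheses` maintains the running n-best list
# (hypothetical values, for illustration only):
#
#   hyps = BeamHypotheses(num_beams=2, length_penalty=1.0, early_stopping=False)
#   hyps.add(flow.tensor([0, 17, 35, 2]), sum_logprobs=-3.2)  # score = -3.2 / 4 = -0.8
#   hyps.add(flow.tensor([0, 17, 42, 2]), sum_logprobs=-4.8)  # score = -4.8 / 4 = -1.2
#   hyps.is_done(best_sum_logprobs=-6.0, cur_len=4)  # True: -6.0 / 4 < worst score (-1.2)

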
class BeamSearchScorer(BeamScorer):
    def __init__(
        self,
        batch_size: int,
        num_beams: int,
        length_penalty: Optional[float] = 1.0,
        do_early_stopping: Optional[bool] = False,
        num_beam_hyps_to_keep: Optional[int] = 1,
        num_beam_groups: Optional[int] = 1,
        **kwargs,
    ):
        self.num_beams = num_beams
        self.length_penalty = length_penalty
        self.do_early_stopping = do_early_stopping
        self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
        self.num_beam_groups = num_beam_groups
        # each group of `group_size` beams is scored independently (group/diverse beam search)
        self.group_size = self.num_beams // self.num_beam_groups

        self._is_init = False
        self._beam_hyps = [
            BeamHypotheses(
                num_beams=self.num_beams,
                length_penalty=self.length_penalty,
                early_stopping=self.do_early_stopping,
            )
            for _ in range(batch_size)
        ]

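        # `_done` is a OneFlow global tensor with broadcast sbp, so every
        # tensor/pipeline-parallel rank agrees on which sentences are finished.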
        self._done = flow.tensor(
            [False for _ in range(batch_size)],
            dtype=flow.bool,
            sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
            placement=flow.placement("cuda", list(range(dist.get_world_size()))),
        )

        if not isinstance(num_beams, int) or num_beams <= 1:
            raise ValueError(
                f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. "
                "For `num_beams` == 1, one should make use of `greedy_search` instead."
            )

        if (
            not isinstance(num_beam_groups, int)
            or (num_beam_groups > num_beams)
            or (num_beams % num_beam_groups != 0)
        ):
            raise ValueError(
                "`num_beam_groups` has to be an integer smaller than or equal to `num_beams` and "
                f"`num_beams` has to be divisible by `num_beam_groups`, but is {num_beam_groups} "
                f"with `num_beams` being {num_beams}."
            )

if "max_length" in kwargs: | ||
warnings.warn( | ||
"Passing `max_length` to BeamSearchScorer is deprecated and has no effect. " | ||
"`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`" | ||
", or `group_beam_search(...)`." | ||
) | ||
|
||
@property | ||
def is_done(self) -> bool: | ||
return self._done.all() | ||
|
||
    def process(
        self,
        input_ids: flow.Tensor,
        next_scores: flow.Tensor,
        next_tokens: flow.Tensor,
        next_indices: flow.Tensor,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        beam_indices: Optional[flow.Tensor] = None,
    ) -> Tuple[flow.Tensor]:
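        # Expected shapes, following the Hugging Face convention this port is based on:
        # `input_ids` is (batch_size * group_size, cur_len); `next_scores`, `next_tokens`,
        # and `next_indices` are (batch_size, 2 * group_size), i.e. the top-2k candidates
        # per sentence, so finished (EOS) beams can be skipped without starving a beam.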
        cur_len = input_ids.shape[-1]
        batch_size = len(self._beam_hyps)
        if not (batch_size == (input_ids.shape[0] // self.group_size)):
            if self.num_beam_groups > 1:
                raise ValueError(
                    f"A group beam size of {input_ids.shape[0]} is used as the input, but a group "
                    f"beam size of {self.group_size} is expected by the beam scorer."
                )
            else:
                raise ValueError(
                    f"A beam size of {input_ids.shape[0]} is used as the input, but a beam size of "
                    f"{self.group_size} is expected by the beam scorer."
                )
        next_beam_scores = flow.zeros(
            (batch_size, self.group_size),
            dtype=next_scores.dtype,
            sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
            placement=flow.placement("cuda", list(range(dist.get_world_size()))),
        )
        next_beam_tokens = flow.zeros(
            (batch_size, self.group_size),
            dtype=next_tokens.dtype,
            sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
            placement=flow.placement("cuda", list(range(dist.get_world_size()))),
        )
        next_beam_indices = flow.zeros(
            (batch_size, self.group_size),
            dtype=next_indices.dtype,
            sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
            placement=flow.placement("cuda", list(range(dist.get_world_size()))),
        )

        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
            if self._done[batch_idx]:
                if self.num_beams < len(beam_hyp):
                    raise ValueError(
                        f"Batch can only be done if at least {self.num_beams} beams have "
                        "been generated"
                    )
                if eos_token_id is None or pad_token_id is None:
                    raise ValueError(
                        "Generated beams >= num_beams -> `eos_token_id` and `pad_token_id` have "
                        "to be defined"
                    )
                # pad the batch
                next_beam_scores[batch_idx, :] = 0
                next_beam_tokens[batch_idx, :] = pad_token_id
                next_beam_indices[batch_idx, :] = 0
                continue

            # next tokens for this sentence
            beam_idx = 0
            for beam_token_rank, (next_token, next_score, next_index) in enumerate(
                zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
            ):
                batch_beam_idx = batch_idx * self.group_size + next_index
                # add to generated hypotheses if end of sentence
                if (eos_token_id is not None) and (next_token.item() == eos_token_id):
                    # if beam_token does not belong to top num_beams tokens,
                    # it should not be added
                    is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
                    if is_beam_token_worse_than_top_num_beams:
                        continue
                    if beam_indices is not None:
                        beam_index = beam_indices[batch_beam_idx]
                        beam_index = beam_index + (next_index,)
                    else:
                        beam_index = None

                    beam_hyp.add(
                        input_ids[batch_beam_idx].clone(),
                        next_score.item(),
                        beam_indices=beam_index,
                    )
                else:
                    # add next predicted token since it is not eos_token
                    next_beam_scores[batch_idx, beam_idx] = next_score
                    next_beam_tokens[batch_idx, beam_idx] = next_token
                    next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
                    beam_idx += 1

                # once the beam for next step is full, don't add more tokens to it.
                if beam_idx == self.group_size:
                    break

            if beam_idx < self.group_size:
                raise ValueError(
                    f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal "
                    f"to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} "
                    "are corrected."
                )

            # Check if we are done so that we can save a pad step if all(done)
            self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
                next_scores[batch_idx].max().item(), cur_len
            )

        return UserDict(
            {
                "next_beam_scores": next_beam_scores.view(-1),
                "next_beam_tokens": next_beam_tokens.view(-1),
                "next_beam_indices": next_beam_indices.view(-1),
            }
        )

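    # A minimal sketch of the per-step driver loop that consumes `process(...)`
    # (hypothetical names; the real loop lives in the generation utilities):
    #
    #   next_scores, next_tokens = flow.topk(scores, 2 * num_beams, dim=1)
    #   next_indices = next_tokens // vocab_size  # which beam each candidate came from
    #   next_tokens = next_tokens % vocab_size  # token id within the vocabulary
    #   out = scorer.process(input_ids, next_scores, next_tokens, next_indices,
    #                        pad_token_id=pad_id, eos_token_id=eos_id)
    #   input_ids = flow.cat(
    #       [input_ids[out["next_beam_indices"], :], out["next_beam_tokens"].unsqueeze(-1)],
    #       dim=-1,
    #   )
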
    def finalize(
        self,
        input_ids: flow.Tensor,
        final_beam_scores: flow.Tensor,
        final_beam_tokens: flow.Tensor,
        final_beam_indices: flow.Tensor,
        max_length: int,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        beam_indices: Optional[flow.Tensor] = None,
    ):
        batch_size = len(self._beam_hyps)
        # finalize all open beam hypotheses and add to generated hypotheses
        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
            if self._done[batch_idx]:
                continue

            # all open beam hypotheses are added to the beam hypothesis list;
            # the BeamHypotheses class automatically keeps the best beams
            for beam_id in range(self.num_beams):
                batch_beam_idx = batch_idx * self.num_beams + beam_id
                final_score = final_beam_scores[batch_beam_idx].item()
                final_tokens = input_ids[batch_beam_idx]
                beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
                beam_hyp.add(final_tokens, final_score, beam_indices=beam_index)

        # select the best hypotheses
        sent_lengths = flow.zeros(
            batch_size * self.num_beam_hyps_to_keep,
            dtype=flow.long,
            sbp=input_ids.sbp,
            placement=input_ids.placement,
        )
        best = []
        best_indices = []
        best_scores = flow.zeros(
            batch_size * self.num_beam_hyps_to_keep,
            dtype=flow.float32,
            sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
            placement=flow.placement("cuda", list(range(dist.get_world_size()))),
        )

        # retrieve best hypotheses
        for i, beam_hyp in enumerate(self._beam_hyps):
            sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
            for j in range(self.num_beam_hyps_to_keep):
                best_hyp_tuple = sorted_hyps.pop()
                best_score = best_hyp_tuple[0]
                best_hyp = best_hyp_tuple[1]
                best_index = best_hyp_tuple[2]
                sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)

                # append hypothesis to lists
                best.append(best_hyp)

                # append indices to list
                best_indices.append(best_index)

                best_scores[i * self.num_beam_hyps_to_keep + j] = best_score

        # prepare for adding eos
        sent_lengths_max = sent_lengths.max().item() + 1
        sent_max_len = (
            min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
        )
        decoded = flow.zeros(
            (batch_size * self.num_beam_hyps_to_keep, sent_max_len),
            dtype=flow.long,
            sbp=input_ids.sbp,
            placement=input_ids.placement,
        )

        if len(best_indices) > 0 and best_indices[0] is not None:
            indices = flow.zeros(
                (batch_size * self.num_beam_hyps_to_keep, sent_max_len),
                dtype=flow.long,
                sbp=input_ids.sbp,
                placement=input_ids.placement,
            )
        else:
            indices = None

        # shorter batches are padded if needed
        if sent_lengths.min().item() != sent_lengths.max().item():
            assert pad_token_id is not None, "`pad_token_id` has to be defined"
            decoded.fill_(pad_token_id)

        if indices is not None:
            indices.fill_(-1)

        # fill with hypotheses and eos_token_id if the latter fits in
        for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
            decoded[i, : sent_lengths[i]] = hypo

            if indices is not None:
                indices[i, : len(best_idx)] = flow.tensor(best_idx)

            if sent_lengths[i] < sent_max_len:
                decoded[i, sent_lengths[i]] = eos_token_id

        return UserDict(
            {
                "sequences": decoded,
                "sequence_scores": best_scores,
                "beam_indices": indices,
            }
        )
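
# End to end, a caller typically builds one scorer per batch, runs the step loop
# sketched above until `scorer.is_done` or `max_length` is reached, then finalizes
# (a hedged sketch with hypothetical names, not a verbatim API from this repo):
#
#   scorer = BeamSearchScorer(batch_size=2, num_beams=4)
#   out = scorer.finalize(input_ids, beam_scores, next_tokens, next_indices,
#                         max_length=64, pad_token_id=pad_id, eos_token_id=eos_id)
#   sequences = out["sequences"]  # (batch_size * num_beam_hyps_to_keep, <= 64)
#   scores = out["sequence_scores"]  # length-normalized log-prob per sequence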