Add generations (#375)
* add bleu score evaluator

* refine

* add

* add t5layernorm

* refine

* update tokenizer

* refine format

* refine

* add functions for generation

* test greedy search

* support model parallel

* test t5

* add logits warper

* add logits warper

* add multinomial_sample

* add beam_search utils

* reformat

* add _reorder_cache function

* add prepare_inputs_for_generation

* restructure MT5 for inference

* reformat

* add beam_search

* update text_generation.py

* update generation

* add some logits processor and generation function

* reformat

* restructure mt5

* add mt5 model test

* refine stopping criteria

* add utils function

* reformat

* refine utils function

* update t5model, update t5loader, add t5 loader test

* refine

* replace multinomial

* replace nansum

* reformat

* update multinomial sampling

* update beam search

* add copyright

* reformat

* refine

* reformat

* add comments

* add comments

* finish Mt5 single node single gpu generate

* simplify cfg

* reformat

* create generator dir

* update

* finish Mt5 pipeline generate

* finish Mt5 tensor_pipeline parallel generate

* fix multi-sentence input inference

* reformat

* refine

* refine mt5_loader

Co-authored-by: CPFLAME <[email protected]>
xiezipeng-ML and CPFLAME authored Oct 13, 2022
1 parent e511855 commit 6de6064
Showing 17 changed files with 2,275 additions and 194 deletions.
2 changes: 2 additions & 0 deletions dev/model_loader_test.sh
@@ -15,6 +15,8 @@ python3 -m oneflow.distributed.launch --nproc_per_node 4 -m pytest -s --disable-

python3 -m oneflow.distributed.launch --nproc_per_node 4 -m pytest -s --disable-warnings tests/model_utils/test_mt5_loader.py

python3 -m oneflow.distributed.launch --nproc_per_node 4 -m pytest -s --disable-warnings tests/model_utils/test_t5_loader.py

python3 -m oneflow.distributed.launch --nproc_per_node 4 -m pytest -s --disable-warnings tests/model_utils/test_swin_loader.py

python3 -m oneflow.distributed.launch --nproc_per_node 4 -m pytest -s --disable-warnings tests/model_utils/test_swinv2_loader.py
2 changes: 2 additions & 0 deletions dev/model_test.sh
@@ -14,6 +14,8 @@ python3 -m oneflow.distributed.launch --nproc_per_node 4 -m pytest -s --disable-

python3 -m oneflow.distributed.launch --nproc_per_node 4 -m pytest -s --disable-warnings tests/models/test_t5.py

python3 -m oneflow.distributed.launch --nproc_per_node 4 -m pytest -s --disable-warnings tests/models/test_mt5.py

python3 -m oneflow.distributed.launch --nproc_per_node 4 -m pytest -s --disable-warnings tests/models/test_vit.py

python3 -m oneflow.distributed.launch --nproc_per_node 4 -m pytest -s --disable-warnings tests/models/test_swin.py
373 changes: 373 additions & 0 deletions libai/inference/generator/generation_beam_search.py
@@ -0,0 +1,373 @@
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors and
# The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from abc import ABC, abstractmethod
from collections import UserDict
from typing import Optional, Tuple

import oneflow as flow

from libai.utils import distributed as dist


class BeamScorer(ABC):
    @abstractmethod
    def process(
        self,
        input_ids: flow.Tensor,
        next_scores: flow.Tensor,
        next_tokens: flow.Tensor,
        next_indices: flow.Tensor,
        **kwargs,
    ):
        raise NotImplementedError("This is an abstract method.")


class BeamHypotheses:
    def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool):
        """
        Initialize n-best list of hypotheses.
        """
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.num_beams = num_beams
        self.beams = []
        self.worst_score = 1e9

    def __len__(self) -> int:
        """
        Number of hypotheses in the list.
        """
        return len(self.beams)

    def add(
        self, hyp: flow.Tensor, sum_logprobs: float, beam_indices: Optional[flow.Tensor] = None
    ):
        """
        Add a new hypothesis to the list.
        """
        score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
        if len(self) < self.num_beams or score > self.worst_score:
            self.beams.append((score, hyp, beam_indices))
            if len(self) > self.num_beams:
                sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)])
                del self.beams[sorted_next_scores[0][1]]
                self.worst_score = sorted_next_scores[1][0]
            else:
                self.worst_score = min(score, self.worst_score)

    def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
        """
        If there are enough hypotheses and none of the hypotheses being generated
        can become better than the worst one in the heap, then we are done with this sentence.
        """

        if len(self) < self.num_beams:
            return False
        elif self.early_stopping:
            return True
        else:
            cur_score = best_sum_logprobs / cur_len ** self.length_penalty
            ret = self.worst_score >= cur_score
            return ret


class BeamSearchScorer(BeamScorer):
    def __init__(
        self,
        batch_size: int,
        num_beams: int,
        length_penalty: Optional[float] = 1.0,
        do_early_stopping: Optional[bool] = False,
        num_beam_hyps_to_keep: Optional[int] = 1,
        num_beam_groups: Optional[int] = 1,
        **kwargs,
    ):
        self.num_beams = num_beams
        self.length_penalty = length_penalty
        self.do_early_stopping = do_early_stopping
        self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
        self.num_beam_groups = num_beam_groups
        self.group_size = self.num_beams // self.num_beam_groups

        self._is_init = False
        self._beam_hyps = [
            BeamHypotheses(
                num_beams=self.num_beams,
                length_penalty=self.length_penalty,
                early_stopping=self.do_early_stopping,
            )
            for _ in range(batch_size)
        ]

        self._done = flow.tensor(
            [False for _ in range(batch_size)],
            dtype=flow.bool,
            sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
            placement=flow.placement("cuda", list(range(dist.get_world_size()))),
        )

        if not isinstance(num_beams, int) or num_beams <= 1:
            raise ValueError(
                f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}."
                "For `num_beams` == 1, one should make use of `greedy_search` instead."
            )

        if (
            not isinstance(num_beam_groups, int)
            or (num_beam_groups > num_beams)
            or (num_beams % num_beam_groups != 0)
        ):
            raise ValueError(
                "`num_beam_groups` has to be an integer smaller or equal than `num_beams` and "
                f"`num_beams` has to be divisible by `num_beam_groups`, but is {num_beam_groups}"
                f"with `num_beams` being {num_beams}."
            )

        if "max_length" in kwargs:
            warnings.warn(
                "Passing `max_length` to BeamSearchScorer is deprecated and has no effect. "
                "`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`"
                ", or `group_beam_search(...)`."
            )

    @property
    def is_done(self) -> bool:
        return self._done.all()

    def process(
        self,
        input_ids: flow.Tensor,
        next_scores: flow.Tensor,
        next_tokens: flow.Tensor,
        next_indices: flow.Tensor,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        beam_indices: Optional[flow.Tensor] = None,
    ) -> Tuple[flow.Tensor]:
        cur_len = input_ids.shape[-1]
        batch_size = len(self._beam_hyps)
        if not (batch_size == (input_ids.shape[0] // self.group_size)):
            if self.num_beam_groups > 1:
                raise ValueError(
                    f"A group beam size of {input_ids.shape[0]} is used as the input, but a group "
                    f"beam size of {self.group_size} is expected by the beam scorer."
                )
            else:
                raise ValueError(
                    f"A beam size of {input_ids.shape[0]} is used as the input, but a beam size of "
                    f"{self.group_size} is expected by the beam scorer."
                )
        next_beam_scores = flow.zeros(
            (batch_size, self.group_size),
            dtype=next_scores.dtype,
            sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
            placement=flow.placement("cuda", list(range(dist.get_world_size()))),
        )
        next_beam_tokens = flow.zeros(
            (batch_size, self.group_size),
            dtype=next_tokens.dtype,
            sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
            placement=flow.placement("cuda", list(range(dist.get_world_size()))),
        )
        next_beam_indices = flow.zeros(
            (batch_size, self.group_size),
            dtype=next_indices.dtype,
            sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
            placement=flow.placement("cuda", list(range(dist.get_world_size()))),
        )

        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
            if self._done[batch_idx]:
                if self.num_beams < len(beam_hyp):
                    raise ValueError(
                        f"Batch can only be done if at least {self.num_beams} beams have "
                        "been generated"
                    )
                if eos_token_id is None or pad_token_id is None:
                    raise ValueError(
                        "Generated beams >= num_beams -> eos_token_id and pad_token have "
                        "to be defined"
                    )
                # pad the batch
                next_beam_scores[batch_idx, :] = 0
                next_beam_tokens[batch_idx, :] = pad_token_id
                next_beam_indices[batch_idx, :] = 0
                continue

            # next tokens for this sentence
            beam_idx = 0
            for beam_token_rank, (next_token, next_score, next_index) in enumerate(
                zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
            ):
                batch_beam_idx = batch_idx * self.group_size + next_index
                # add to generated hypotheses if end of sentence
                if (eos_token_id is not None) and (next_token.item() == eos_token_id):
                    # if beam_token does not belong to top num_beams tokens, it should not be added
                    is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
                    if is_beam_token_worse_than_top_num_beams:
                        continue
                    if beam_indices is not None:
                        beam_index = beam_indices[batch_beam_idx]
                        beam_index = beam_index + (next_index,)
                    else:
                        beam_index = None

                    beam_hyp.add(
                        input_ids[batch_beam_idx].clone(),
                        next_score.item(),
                        beam_indices=beam_index,
                    )
                else:
                    # add next predicted token since it is not eos_token
                    next_beam_scores[batch_idx, beam_idx] = next_score
                    next_beam_tokens[batch_idx, beam_idx] = next_token
                    next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
                    beam_idx += 1

                # once the beam for next step is full, don't add more tokens to it.
                if beam_idx == self.group_size:
                    break

            if beam_idx < self.group_size:
                raise ValueError(
                    f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal "
                    f"to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} "
                    "are corrected."
                )

            # Check if we are done so that we can save a pad step if all(done)
            self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
                next_scores[batch_idx].max().item(), cur_len
            )

        return UserDict(
            {
                "next_beam_scores": next_beam_scores.view(-1),
                "next_beam_tokens": next_beam_tokens.view(-1),
                "next_beam_indices": next_beam_indices.view(-1),
            }
        )

    def finalize(
        self,
        input_ids: flow.Tensor,
        final_beam_scores: flow.Tensor,
        final_beam_tokens: flow.Tensor,
        final_beam_indices: flow.Tensor,
        max_length: int,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        beam_indices: Optional[flow.Tensor] = None,
    ):
        batch_size = len(self._beam_hyps)
        # finalize all open beam hypotheses and add to generated hypotheses
        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
            if self._done[batch_idx]:
                continue

            # all open beam hypotheses are added to the beam hypothesis
            # beam hypothesis class automatically keeps the best beams
            for beam_id in range(self.num_beams):
                batch_beam_idx = batch_idx * self.num_beams + beam_id
                final_score = final_beam_scores[batch_beam_idx].item()
                final_tokens = input_ids[batch_beam_idx]
                beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
                beam_hyp.add(final_tokens, final_score, beam_indices=beam_index)

        # select the best hypotheses
        sent_lengths = flow.zeros(
            batch_size * self.num_beam_hyps_to_keep,
            dtype=flow.long,
            sbp=input_ids.sbp,
            placement=input_ids.placement,
        )
        best = []
        best_indices = []
        best_scores = flow.zeros(
            batch_size * self.num_beam_hyps_to_keep,
            dtype=flow.float32,
            sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
            placement=flow.placement("cuda", list(range(dist.get_world_size()))),
        )

        # retrieve best hypotheses
        for i, beam_hyp in enumerate(self._beam_hyps):
            sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
            for j in range(self.num_beam_hyps_to_keep):
                best_hyp_tuple = sorted_hyps.pop()
                best_score = best_hyp_tuple[0]
                best_hyp = best_hyp_tuple[1]
                best_index = best_hyp_tuple[2]
                sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)

                # append hyp to lists
                best.append(best_hyp)

                # append indices to list
                best_indices.append(best_index)

                best_scores[i * self.num_beam_hyps_to_keep + j] = best_score

        # prepare for adding eos
        sent_lengths_max = sent_lengths.max().item() + 1
        sent_max_len = (
            min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
        )
        decoded = flow.zeros(
            (batch_size * self.num_beam_hyps_to_keep, sent_max_len),
            dtype=flow.long,
            sbp=input_ids.sbp,
            placement=input_ids.placement,
        )

        if len(best_indices) > 0 and best_indices[0] is not None:
            indices = flow.zeros(
                (batch_size * self.num_beam_hyps_to_keep, sent_max_len),
                dtype=flow.long,
                sbp=input_ids.sbp,
                placement=input_ids.placement,
            )
        else:
            indices = None

        # shorter batches are padded if needed
        if sent_lengths.min().item() != sent_lengths.max().item():
            assert pad_token_id is not None, "`pad_token_id` has to be defined"
            decoded.fill_(pad_token_id)

            if indices is not None:
                indices.fill_(-1)

        # fill with hypotheses and eos_token_id if the latter fits in
        for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
            decoded[i, : sent_lengths[i]] = hypo

            if indices is not None:
                indices[i, : len(best_idx)] = flow.tensor(best_idx)

            if sent_lengths[i] < sent_max_len:
                decoded[i, sent_lengths[i]] = eos_token_id

        return UserDict(
            {
                "sequences": decoded,
                "sequence_scores": best_scores,
                "beam_indices": indices,
            }
        )
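
The scorer above only reranks beam candidates and collects finished hypotheses; the loop that feeds it lives in the other generation utilities added by this commit. Below is a minimal sketch (not part of the commit) of how such a loop drives process() and finalize(). The toy_decoder, the vocabulary size, and the token ids are hypothetical, and the sketch assumes libai's distributed utilities are already configured (e.g. via dist.setup_dist_util) with a CUDA device available, so global tensors can share the scorer's sbp/placement.

import oneflow as flow
import oneflow.nn.functional as F

from libai.inference.generator.generation_beam_search import BeamSearchScorer
from libai.utils import distributed as dist

# Toy sizes and token ids -- hypothetical, chosen only for this sketch.
batch_size, num_beams, vocab_size = 2, 4, 100
pad_token_id, eos_token_id, max_length = 0, 1, 16

sbp = dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast])
placement = flow.placement("cuda", list(range(dist.get_world_size())))


def toy_decoder(ids):
    # Stand-in for a real decoder step: random scores over the vocabulary.
    return flow.randn(ids.shape[0], vocab_size, sbp=sbp, placement=placement)


scorer = BeamSearchScorer(batch_size=batch_size, num_beams=num_beams)

# One row per (batch, beam) pair; only beam 0 starts "alive" so the first step
# does not expand num_beams identical prefixes.
input_ids = flow.zeros((batch_size * num_beams, 1), dtype=flow.long, sbp=sbp, placement=placement)
beam_scores = flow.zeros((batch_size, num_beams), dtype=flow.float32, sbp=sbp, placement=placement)
beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view(-1)

while not scorer.is_done and input_ids.shape[-1] < max_length:
    next_scores = F.log_softmax(toy_decoder(input_ids), dim=-1) + beam_scores.unsqueeze(-1)
    next_scores = next_scores.view(batch_size, num_beams * vocab_size)
    # Keep 2 * num_beams candidates so beams that emit eos do not starve the open ones.
    next_scores, next_tokens = flow.topk(next_scores, 2 * num_beams, dim=1)
    next_indices = (next_tokens / vocab_size).to(flow.long)  # source beam of each candidate
    next_tokens = next_tokens - next_indices * vocab_size  # token id within the vocabulary

    out = scorer.process(
        input_ids,
        next_scores,
        next_tokens,
        next_indices,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
    )
    beam_scores = out["next_beam_scores"]
    # Reorder the running sequences to follow the surviving beams, then append the new tokens.
    input_ids = flow.cat(
        [input_ids[out["next_beam_indices"], :], out["next_beam_tokens"].unsqueeze(-1)], dim=-1
    )

result = scorer.finalize(
    input_ids,
    beam_scores,
    out["next_beam_tokens"],
    out["next_beam_indices"],
    max_length,
    pad_token_id=pad_token_id,
    eos_token_id=eos_token_id,
)
print(result["sequences"].shape)  # (batch_size * num_beam_hyps_to_keep, <= max_length)

process() hands back the re-ranked scores, tokens, and source-beam indices for the next step, and finalize() pads the best surviving hypotheses to a common length, appending eos_token_id where it still fits.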