-
Notifications
You must be signed in to change notification settings - Fork 55
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add belu score evaluator * refine * add * add t5layernorm * refine * updata tokenizer * refine format * refine * add functions for generation * test greedy serch * support model parallel * test t5 * add export onnx script * refine cpu to cuda * fix bug * add logits warper * add logits warper * add multinomial_sample * add beam_search utils * reformat * add _reorder_cache function * add def prepare_inputs_for_generation( * restructure MT5 for inference * reformat * add beam_serch * update text_generation.py * update generation * add some logits processor and generation function * reformat * restructure mt5 * add mt5 model test * refine stopping criteria * add utils function * reformat * refine utils function * update t5model, update t5loader, add t5 loader test * refine * replace multinomial * replace nansum * reformat * update multinomial sampling * update beam search * add copyright * reformat * refine * reformat * add comments * add comments * finish Mt5 single node single gpu generate * simplify cfg * reformat * create generator dir * update * finish Mt5 pipeline generate * finish Mt5 tensor_pipeline parallel generate * change convert_to_onnx to convert_to_onnx_and_check * fix multi sentence input inference * modify return type * reformat * merge mt5 branch * succed run onnx export and check * add onnx inference script * reformat code Co-authored-by: xiezipeng-ML <[email protected]> Co-authored-by: Zipeng Xie <[email protected]>
- Loading branch information
1 parent
6de6064
commit e9ca408
Showing
2 changed files
with
196 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# coding=utf-8 | ||
# Copyright 2021 The OneFlow Authors. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
from collections import OrderedDict | ||
from typing import List | ||
|
||
import numpy as np | ||
import onnxruntime as ort | ||
|
||
|
||
class OnnxModel: | ||
def __init__( | ||
self, | ||
onnx_filename, | ||
providers: List[str] = None, | ||
ort_optimize: bool = True, | ||
): | ||
ort_sess_opt = ort.SessionOptions() | ||
ort_sess_opt.graph_optimization_level = ( | ||
ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED | ||
if ort_optimize | ||
else ort.GraphOptimizationLevel.ORT_DISABLE_ALL | ||
) | ||
if providers is None: | ||
if ort.__version__ > "1.9.0": | ||
providers = [ | ||
"TensorrtExecutionProvider", | ||
"CUDAExecutionProvider", | ||
"CPUExecutionProvider", | ||
] | ||
else: | ||
providers = ["CPUExecutionProvider"] | ||
self.sess = ort.InferenceSession( | ||
onnx_filename, sess_options=ort_sess_opt, providers=providers | ||
) | ||
|
||
def forward(self, input_list): | ||
ipt_dict = OrderedDict() | ||
for idx, ipt in enumerate(self.sess.get_inputs()): | ||
ipt_dict[ipt.name] = input_list[idx] | ||
onnx_res = self.sess.run([], ipt_dict) | ||
return onnx_res | ||
|
||
|
||
if __name__ == "__main__": | ||
onnx_model = OnnxModel("model.onnx") | ||
input_list = [ | ||
np.ones((1, 5)).astype(np.int64), | ||
np.ones((1, 3)).astype(np.int64), | ||
np.ones((1, 5, 5)).astype(bool), | ||
np.ones((1, 3, 3)).astype(bool), | ||
np.ones((1, 3, 5)).astype(bool), | ||
] | ||
|
||
print(onnx_model.forward(input_list)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
# coding=utf-8 | ||
# Copyright 2021 The OneFlow Authors. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
import oneflow as flow | ||
from oneflow import nn | ||
from oneflow_onnx.oneflow2onnx.util import convert_to_onnx_and_check | ||
|
||
from libai.config import LazyConfig | ||
from libai.engine import DefaultTrainer | ||
|
||
|
||
def get_model(config_file): | ||
cfg = LazyConfig.load(config_file) | ||
|
||
cfg.model.cfg.mlp_type = "t5" | ||
cfg.model.cfg.pretrained_model_path = None | ||
cfg.dataloader = None | ||
cfg.tokenization = None | ||
|
||
print("Building model....") | ||
model = DefaultTrainer.build_model(cfg) | ||
print("Build model finished.") | ||
|
||
return model | ||
|
||
|
||
class t5Graph(nn.Graph): | ||
def __init__(self, eager_model): | ||
super().__init__() | ||
self.model = eager_model | ||
|
||
def build( | ||
self, | ||
encoder_input_ids, | ||
encoder_attn_mask, | ||
decoder_input_ids, | ||
decoder_attn_mask, | ||
encoder_decoder_attn_mask, | ||
): | ||
out = self.model( | ||
encoder_input_ids, | ||
encoder_attn_mask, | ||
decoder_input_ids, | ||
decoder_attn_mask, | ||
encoder_decoder_attn_mask, | ||
) | ||
return out["prediction_scores"] | ||
|
||
|
||
if __name__ == "__main__": | ||
model = get_model("projects/MT5/configs/mt5_pretrain.py") | ||
model.eval() | ||
|
||
t5_graph = t5Graph(model) | ||
# Build the static graph model | ||
encoder_input_ids = flow.ones( | ||
1, 5, dtype=flow.int64, sbp=flow.sbp.broadcast, placement=flow.placement("cuda", ranks=[0]) | ||
) | ||
encoder_attn_mask = flow.ones( | ||
1, 3, dtype=flow.int64, sbp=flow.sbp.broadcast, placement=flow.placement("cuda", ranks=[0]) | ||
) | ||
decoder_input_ids = flow.ones( | ||
1, | ||
5, | ||
5, | ||
dtype=flow.bool, | ||
sbp=flow.sbp.broadcast, | ||
placement=flow.placement("cuda", ranks=[0]), | ||
) | ||
decoder_attn_mask = flow.ones( | ||
1, | ||
3, | ||
3, | ||
dtype=flow.bool, | ||
sbp=flow.sbp.broadcast, | ||
placement=flow.placement("cuda", ranks=[0]), | ||
) | ||
encoder_decoder_attn_mask = flow.ones( | ||
1, | ||
3, | ||
5, | ||
dtype=flow.bool, | ||
sbp=flow.sbp.broadcast, | ||
placement=flow.placement("cuda", ranks=[0]), | ||
) | ||
|
||
# check your model.forward is valid | ||
# output = t5_graph( | ||
# encoder_input_ids, | ||
# encoder_attn_mask, | ||
# decoder_input_ids, | ||
# decoder_attn_mask, | ||
# encoder_decoder_attn_mask | ||
# ) | ||
# print(output) | ||
|
||
print("Compiling the graph which may make some time, please wait for a moment....") | ||
t5_graph._compile( | ||
encoder_input_ids, | ||
encoder_attn_mask, | ||
decoder_input_ids, | ||
decoder_attn_mask, | ||
encoder_decoder_attn_mask, | ||
) | ||
|
||
convert_to_onnx_and_check( | ||
t5_graph, | ||
external_data=False, | ||
opset=11, | ||
flow_weight_dir=None, | ||
onnx_model_path="./", | ||
dynamic_batch_size=False, | ||
device="gpu_global", | ||
input_tensor_range=[0, 10], | ||
) |