Commit 568a27f

Sync latest code (同步最新代码)

1 parent: d0dfe71

13 files changed: +1029 −503 lines

README.md

Lines changed: 4 additions & 4 deletions
````diff
@@ -104,11 +104,11 @@ max diff : tensor(3.5763e-06)
 base版本
 python compare_model.py
 bert4keras vs pytorch
-mean diff : tensor(4.3925e-07)
-max diff : tensor(7.6294e-06)
+mean diff : tensor(4.3340e-07)
+max diff : tensor(5.7220e-06)
 bert4keras vs tf2.0
-mean diff : tensor(3.4151e-07)
-max diff : tensor(3.8147e-06)
+mean diff : tensor(3.4319e-07)
+max diff : tensor(5.2452e-06)
 ```
 
 
````
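The updated figures are elementwise absolute differences between model outputs. A minimal sketch of the kind of check involved; the helper name and tensors are illustrative, not taken from compare_model.py:

```python
import torch

# Illustrative helper: given outputs produced by two implementations on the
# same input, report agreement the way the README tables do.
def report_diff(out_a: torch.Tensor, out_b: torch.Tensor) -> None:
    diff = (out_a - out_b).abs()
    print("mean diff :", diff.mean())
    print("max diff :", diff.max())
```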

convert_roformer_original_tf_checkpoint_to_pytorch.py

Lines changed: 20 additions & 12 deletions
````diff
@@ -14,19 +14,18 @@
 # limitations under the License.
 """Convert RoFormer checkpoint."""
 
-
 import argparse
 
 import torch
 
 from roformer import RoFormerConfig, RoFormerForMaskedLM, load_tf_weights_in_roformer
 from transformers.utils import logging
 
-
 logging.set_verbosity_info()
 
 
-def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
+def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file,
+                                     pytorch_dump_path):
     # Initialise PyTorch model
     config = RoFormerConfig.from_json_file(bert_config_file)
     print(f"Building PyTorch model from configuration: {config}")
@@ -37,25 +36,34 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor
 
     # Save pytorch-model
     print(f"Save PyTorch model to {pytorch_dump_path}")
-    torch.save(model.state_dict(), pytorch_dump_path, _use_new_zipfile_serialization=False)
+    torch.save(model.state_dict(),
+               pytorch_dump_path,
+               _use_new_zipfile_serialization=False)
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     # Required parameters
-    parser.add_argument(
-        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
-    )
+    parser.add_argument("--tf_checkpoint_path",
+                        default=None,
+                        type=str,
+                        required=True,
+                        help="Path to the TensorFlow checkpoint path.")
     parser.add_argument(
         "--bert_config_file",
         default=None,
         type=str,
         required=True,
-        help="The config json file corresponding to the pre-trained BERT model. \n"
+        help=
+        "The config json file corresponding to the pre-trained BERT model. \n"
        "This specifies the model architecture.",
     )
-    parser.add_argument(
-        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
-    )
+    parser.add_argument("--pytorch_dump_path",
+                        default=None,
+                        type=str,
+                        required=True,
+                        help="Path to the output PyTorch model.")
    args = parser.parse_args()
-    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
+    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
+                                     args.bert_config_file,
+                                     args.pytorch_dump_path)
````
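The reflow is formatting-only (yapf-style argument wrapping); behavior is unchanged. An illustrative direct invocation of the function, where the checkpoint and config paths are placeholders rather than anything shipped in this commit:

```python
# Hypothetical paths; substitute a real TF checkpoint and its config file.
from convert_roformer_original_tf_checkpoint_to_pytorch import (
    convert_tf_checkpoint_to_pytorch,
)

convert_tf_checkpoint_to_pytorch(
    tf_checkpoint_path="chinese_roformer_L-12_H-768_A-12/bert_model.ckpt",
    bert_config_file="chinese_roformer_L-12_H-768_A-12/bert_config.json",
    pytorch_dump_path="pytorch_model.bin",
)
```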

setup.py

Lines changed: 2 additions & 2 deletions
````diff
@@ -4,12 +4,12 @@
     name='roformer',
     package_dir={"": "src"},
     packages=find_packages("src"),
-    version='0.0.8',
+    version='0.1.0',
     license='Apache 2.0',
     description='roformer_pytorch',
     author='Jun Yu',
     author_email='[email protected]',
     url='https://github.com/JunnYu/RoFormer_pytorch',
     keywords=['roformer', 'pytorch', 'tf2.0'],
-    install_requires=['transformers>=4.5.0', 'jieba', 'rjieba'],
+    install_requires=['transformers>=4.5.0', 'jieba'],
 )
````

src/roformer/__init__.py

Lines changed: 8 additions & 5 deletions
````diff
@@ -19,12 +19,15 @@
 
 from transformers.file_utils import _BaseLazyModule, is_tf_available, is_tokenizers_available, is_torch_available
 
-
 _import_structure = {
-    "configuration_roformer": ["ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "RoFormerConfig"],
+    "configuration_roformer":
+    ["ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "RoFormerConfig"],
     "tokenization_roformer": ["RoFormerTokenizer"],
 }
 
+if is_tokenizers_available():
+    _import_structure["tokenization_roformer_fast"] = ["RoFormerTokenizerFast"]
+
 if is_torch_available():
     _import_structure["modeling_roformer"] = [
         "ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -40,7 +43,6 @@
         "load_tf_weights_in_roformer",
     ]
 
-
 if is_tf_available():
     _import_structure["modeling_tf_roformer"] = [
         "TF_ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -55,11 +57,13 @@
         "TFRoFormerPreTrainedModel",
     ]
 
-
 if TYPE_CHECKING:
     from .configuration_roformer import ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, RoFormerConfig
     from .tokenization_roformer import RoFormerTokenizer
 
+    if is_tokenizers_available():
+        from .tokenization_roformer_fast import RoFormerTokenizerFast
+
     if is_torch_available():
         from .modeling_roformer import (
             ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
@@ -89,7 +93,6 @@
             TFRoFormerPreTrainedModel,
         )
 
-
 else:
     import importlib
     import os
````
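The new guard registers RoFormerTokenizerFast in the lazy import structure only when the optional `tokenizers` package is installed. A hedged sketch of how downstream code might mirror that check; the fallback pattern here is an assumption, not part of this commit:

```python
from transformers.file_utils import is_tokenizers_available

# Prefer the fast (Rust-backed) tokenizer when `tokenizers` is installed,
# otherwise fall back to the pure-Python RoFormerTokenizer.
if is_tokenizers_available():
    from roformer import RoFormerTokenizerFast as RoFormerTok
else:
    from roformer import RoFormerTokenizer as RoFormerTok

tokenizer = RoFormerTok.from_pretrained("junnyu/roformer_chinese_base")
```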

src/roformer/configuration_roformer.py

Lines changed: 32 additions & 25 deletions
````diff
@@ -17,12 +17,21 @@
 from transformers.configuration_utils import PretrainedConfig
 from transformers.utils import logging
 
-
 logger = logging.get_logger(__name__)
 
 ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
-    "junnyu/roformer_chinese_small": "https://huggingface.co/junnyu/roformer_chinese_small/resolve/main/config.json",
-    "junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/config.json"
+    "junnyu/roformer_chinese_small":
+    "https://huggingface.co/junnyu/roformer_chinese_small/resolve/main/config.json",
+    "junnyu/roformer_chinese_base":
+    "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/config.json",
+    "junnyu/roformer_chinese_char_small":
+    "https://huggingface.co/junnyu/roformer_chinese_char_small/resolve/main/config.json",
+    "junnyu/roformer_chinese_char_base":
+    "https://huggingface.co/junnyu/roformer_chinese_char_base/resolve/main/config.json",
+    "junnyu/roformer_small_discriminator":
+    "https://huggingface.co/junnyu/roformer_small_discriminator/resolve/main/config.json",
+    "junnyu/roformer_small_generator":
+    "https://huggingface.co/junnyu/roformer_small_generator/resolve/main/config.json",
     # See all RoFormer models at https://huggingface.co/models?filter=roformer
 }
 
@@ -43,7 +52,7 @@ class RoFormerConfig(PretrainedConfig):
             Vocabulary size of the RoFormer model. Defines the number of different tokens that can be represented by
             the :obj:`inputs_ids` passed when calling :class:`~transformers.RoFormerModel` or
             :class:`~transformers.TFRoFormerModel`.
-        embedding_size (:obj:`int`, `optional`, defaults to 768):
+        embedding_size (:obj:`int`, `optional`, defaults to None):
             Dimensionality of the encoder layers and the pooler layer.
         hidden_size (:obj:`int`, `optional`, defaults to 768):
             Dimension of the encoder layers and the pooler layer.
@@ -93,27 +102,25 @@ class RoFormerConfig(PretrainedConfig):
     """
     model_type = "roformer"
 
-    def __init__(
-        self,
-        vocab_size=50000,
-        embedding_size=None,
-        hidden_size=768,
-        num_hidden_layers=12,
-        num_attention_heads=12,
-        intermediate_size=3072,
-        hidden_act="gelu",
-        hidden_dropout_prob=0.1,
-        attention_probs_dropout_prob=0.1,
-        max_position_embeddings=1536,
-        type_vocab_size=2,
-        initializer_range=0.02,
-        layer_norm_eps=1e-12,
-        pad_token_id=0,
-        gradient_checkpointing=False,
-        rotary_value=False,
-        use_cache=True,
-        **kwargs
-    ):
+    def __init__(self,
+                 vocab_size=50000,
+                 embedding_size=None,
+                 hidden_size=768,
+                 num_hidden_layers=12,
+                 num_attention_heads=12,
+                 intermediate_size=3072,
+                 hidden_act="gelu",
+                 hidden_dropout_prob=0.1,
+                 attention_probs_dropout_prob=0.1,
+                 max_position_embeddings=1536,
+                 type_vocab_size=2,
+                 initializer_range=0.02,
+                 layer_norm_eps=1e-12,
+                 pad_token_id=0,
+                 gradient_checkpointing=False,
+                 rotary_value=False,
+                 use_cache=True,
+                 **kwargs):
         super().__init__(pad_token_id=pad_token_id, **kwargs)
 
         self.vocab_size = vocab_size
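The docstring now matches the actual default (embedding_size is None), and four new checkpoints join the archive map. An illustrative instantiation under the defaults declared in this signature; the printed values simply restate those defaults:

```python
from roformer import RoFormerConfig

# Defaults as declared in this commit's __init__ signature.
config = RoFormerConfig()
print(config.vocab_size)               # 50000
print(config.max_position_embeddings)  # 1536

# Configs for the newly listed checkpoints can be fetched the usual way:
config = RoFormerConfig.from_pretrained("junnyu/roformer_chinese_char_base")
```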

src/roformer/convert_roformer_original_tf_checkpoint_to_pytorch.py

Lines changed: 20 additions & 12 deletions
````diff
@@ -14,19 +14,18 @@
 # limitations under the License.
 """Convert RoFormer checkpoint."""
 
-
 import argparse
 
 import torch
 
 from roformer import RoFormerConfig, RoFormerForMaskedLM, load_tf_weights_in_roformer
 from transformers.utils import logging
 
-
 logging.set_verbosity_info()
 
 
-def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
+def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file,
+                                     pytorch_dump_path):
     # Initialise PyTorch model
     config = RoFormerConfig.from_json_file(bert_config_file)
     print(f"Building PyTorch model from configuration: {config}")
@@ -37,25 +36,34 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor
 
     # Save pytorch-model
     print(f"Save PyTorch model to {pytorch_dump_path}")
-    torch.save(model.state_dict(), pytorch_dump_path, _use_new_zipfile_serialization=False)
+    torch.save(model.state_dict(),
+               pytorch_dump_path,
+               _use_new_zipfile_serialization=False)
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     # Required parameters
-    parser.add_argument(
-        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
-    )
+    parser.add_argument("--tf_checkpoint_path",
+                        default=None,
+                        type=str,
+                        required=True,
+                        help="Path to the TensorFlow checkpoint path.")
     parser.add_argument(
         "--bert_config_file",
         default=None,
         type=str,
         required=True,
-        help="The config json file corresponding to the pre-trained BERT model. \n"
+        help=
+        "The config json file corresponding to the pre-trained BERT model. \n"
         "This specifies the model architecture.",
     )
-    parser.add_argument(
-        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
-    )
+    parser.add_argument("--pytorch_dump_path",
+                        default=None,
+                        type=str,
+                        required=True,
+                        help="Path to the output PyTorch model.")
     args = parser.parse_args()
-    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
+    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
+                                     args.bert_config_file,
+                                     args.pytorch_dump_path)
````
