
Commit 414566a
Commit message: update
1 parent 9d5528d

660 files changed: 339330 additions (+), 206 deletions (-)


requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 git+https://github.com/lonePatient/TorchBlocks.git
-transformers>=4.5.0
+transformers>=4.12.5
 bert4keras
 rjieba
 jieba

setup.py

Lines changed: 2 additions & 2 deletions
@@ -4,12 +4,12 @@
     name="roformer",
     package_dir={"": "src"},
     packages=find_packages("src"),
-    version="0.2.2",
+    version="0.3.0",
     license="Apache 2.0",
     description="roformer_pytorch",
     author="Jun Yu",
     author_email="[email protected]",
     url="https://github.com/JunnYu/RoFormer_pytorch",
     keywords=["roformer", "pytorch", "tf2.0"],
-    install_requires=["transformers==4.9.1", "jieba"],
+    install_requires=["transformers>=4.12.5", "rjieba"],
 )
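Note (not part of the diff): the pins above move from transformers==4.9.1 with jieba to transformers>=4.12.5 with rjieba. A minimal sketch for checking an existing environment against the new pins, assuming the packaging helper and both dependencies are installed:

# Sketch only: verify an installed environment satisfies the new pins.
from packaging.version import Version  # assumes `packaging` is available

import rjieba        # replaces the old jieba dependency
import transformers

assert Version(transformers.__version__) >= Version("4.12.5"), transformers.__version__
print("transformers", transformers.__version__, "and rjieba import fine")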

src/roformer/__init__.py

Lines changed: 8 additions & 2 deletions
@@ -17,13 +17,14 @@
 # limitations under the License.
 from typing import TYPE_CHECKING

-from transformers.file_utils import (
+from .transformers.file_utils import (
     _LazyModule,
     is_tf_available,
     is_tokenizers_available,
     is_torch_available,
 )

+
 _import_structure = {
     "configuration_roformer": [
         "ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
@@ -50,6 +51,7 @@
         "load_tf_weights_in_roformer",
     ]

+
 if is_tf_available():
     _import_structure["modeling_tf_roformer"] = [
         "TF_ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -64,6 +66,7 @@
         "TFRoFormerPreTrainedModel",
     ]

+
 if TYPE_CHECKING:
     from .configuration_roformer import (
         ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
@@ -103,7 +106,10 @@
         TFRoFormerPreTrainedModel,
     )

+
 else:
     import sys

-    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
+    sys.modules[__name__] = _LazyModule(
+        __name__, globals()["__file__"], _import_structure
+    )
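For context, the _LazyModule wiring above keeps the package import cheap: submodules only load when an attribute is first accessed. A rough usage sketch (not part of this commit; the exports are assumed from _import_structure and need a torch install):

# Sketch: lazy attribute access on the package (names assumed to be in _import_structure).
import roformer

config = roformer.RoFormerConfig()      # loads configuration_roformer on first access
model = roformer.RoFormerModel(config)  # loads modeling_roformer (requires torch)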

src/roformer/configuration_roformer.py

Lines changed: 4 additions & 15 deletions
@@ -14,8 +14,8 @@
 # limitations under the License.
 """ RoFormer model configuration """

-from transformers.configuration_utils import PretrainedConfig
-from transformers.utils import logging
+from .transformers.configuration_utils import PretrainedConfig
+from .transformers.utils import logging

 logger = logging.get_logger(__name__)

@@ -36,18 +36,16 @@ class RoFormerConfig(PretrainedConfig):
     instantiate an RoFormer model according to the specified arguments, defining the model architecture. Instantiating
     a configuration with the defaults will yield a similar configuration to that of the RoFormer
     `junnyu/roformer_chinese_base <https://huggingface.co/junnyu/roformer_chinese_base>`__ architecture.
-
     Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
     outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
-
-
     Args:
         vocab_size (:obj:`int`, `optional`, defaults to 50000):
             Vocabulary size of the RoFormer model. Defines the number of different tokens that can be represented by
             the :obj:`inputs_ids` passed when calling :class:`~transformers.RoFormerModel` or
             :class:`~transformers.TFRoFormerModel`.
         embedding_size (:obj:`int`, `optional`, defaults to None):
-            Dimensionality of the encoder layers and the pooler layer.
+            Dimensionality of the encoder layers and the pooler layer. Defaults to the :obj:`hidden_size` if not
+            provided.
         hidden_size (:obj:`int`, `optional`, defaults to 768):
             Dimension of the encoder layers and the pooler layer.
         num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
@@ -78,19 +76,12 @@ class RoFormerConfig(PretrainedConfig):
             relevant if ``config.is_decoder=True``.
         rotary_value (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether or not apply rotary position embeddings on value layer.
-        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass.
-
     Example::
-
         >>> from transformers import RoFormerModel, RoFormerConfig
-
         >>> # Initializing a RoFormer junnyu/roformer_chinese_base style configuration
         >>> configuration = RoFormerConfig()
-
         >>> # Initializing a model from the junnyu/roformer_chinese_base style configuration
         >>> model = RoFormerModel(configuration)
-
         >>> # Accessing the model configuration
         >>> configuration = model.config
     """
@@ -112,7 +103,6 @@ def __init__(
         initializer_range=0.02,
         layer_norm_eps=1e-12,
         pad_token_id=0,
-        gradient_checkpointing=False,
         rotary_value=False,
         use_cache=True,
         **kwargs
@@ -132,6 +122,5 @@
         self.type_vocab_size = type_vocab_size
         self.initializer_range = initializer_range
         self.layer_norm_eps = layer_norm_eps
-        self.gradient_checkpointing = gradient_checkpointing
         self.rotary_value = rotary_value
         self.use_cache = use_cache
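The removed gradient_checkpointing flag follows the direction of newer transformers releases (>= 4.11), where checkpointing is toggled on the model rather than stored in the config. A hedged sketch of the replacement pattern, with the model and config names assumed from this package:

# Sketch: enable gradient checkpointing on the model instead of via a config flag.
from roformer import RoFormerConfig, RoFormerModel  # assumed exports

model = RoFormerModel(RoFormerConfig())
model.gradient_checkpointing_enable()  # standard PreTrainedModel method in transformers >= 4.11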

src/roformer/convert_roformer_original_tf_checkpoint_to_pytorch.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
 import argparse

 import torch
-from transformers.utils import logging
+from .transformers.utils import logging

 from roformer import RoFormerConfig, RoFormerForMaskedLM, load_tf_weights_in_roformer

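For reference, the names imported above can also be driven programmatically, roughly as below. This is a hypothetical sketch mirroring the conversion script, not code from this commit; all paths are placeholders.

# Hypothetical sketch of a TF-to-PyTorch conversion using the imports shown above.
from roformer import RoFormerConfig, RoFormerForMaskedLM, load_tf_weights_in_roformer

config = RoFormerConfig.from_json_file("roformer_config.json")    # placeholder config path
model = RoFormerForMaskedLM(config)
load_tf_weights_in_roformer(model, config, "model.ckpt")          # placeholder TF checkpoint
model.save_pretrained("./roformer_pytorch_dump")                  # placeholder output dir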
