Skip to content

Commit b8be240

Browse files
committed
add roformer-sim的例子,并更新rotary的实现方式
1 parent 0eb4fc7 commit b8be240

8 files changed

+376
-85
lines changed

README.md

Lines changed: 86 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,94 @@
22
RoFormer模型和RoFormer-V2模型
33

44
## 更新
5-
- 2022/03/21 添加`roformer-v2`的权重, 注:必须使用本仓库的代码,不能使用transformers仓库的代码!!!
5+
- **2022/04/02**
6+
(1)修改RoFormerForCausalLM,支持`roformer-sim`并提供相关的例子,请见`examples/test_sim.py`
7+
(2)修改`apply_rotary`实现方式,看起来更简单。
8+
```python
9+
def apply_rotary(x, sinusoidal_pos):
10+
sin, cos = sinusoidal_pos
11+
x1, x2 = x[..., 0::2], x[..., 1::2]
12+
return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
13+
```
14+
- **2022/03/21** 添加`roformer-v2`的权重, 注:必须使用本仓库的代码,不能使用transformers仓库的代码!!!
615

7-
## v2版本安装
16+
17+
## 安装
818
```bash
19+
# v2版本
920
pip install roformer>=0.4.0
10-
# 如果安装不了,说明清华镜像源没有同步,过一会就可以安装。
21+
# v1版本(代码已经加入到huggingface仓库,请使用新版本的transformers)
22+
pip install -U transformers
1123
```
1224

13-
## v1版本安装(代码已经加入到huggingface仓库)
14-
transformers v4.7版本已经发布,可以直接安装使用
15-
```bash
16-
pip install -U transformers
25+
## roformer-sim测试例子
26+
```python
27+
import torch
28+
import numpy as np
29+
from roformer import RoFormerForCausalLM, RoFormerConfig
30+
from transformers import BertTokenizer
31+
32+
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
33+
# 可选以下几个。
34+
# junnyu/roformer_chinese_sim_char_small, junnyu/roformer_chinese_sim_char_base
35+
# junnyu/roformer_chinese_sim_char_ft_small, roformer_chinese_sim_char_ft_base
36+
pretrained_model = "junnyu/roformer_chinese_sim_char_base"
37+
tokenizer = BertTokenizer.from_pretrained(pretrained_model)
38+
config = RoFormerConfig.from_pretrained(pretrained_model)
39+
config.is_decoder = True
40+
config.eos_token_id = tokenizer.sep_token_id
41+
config.pooler_activation = "linear"
42+
model = RoFormerForCausalLM.from_pretrained(pretrained_model, config=config)
43+
model.to(device)
44+
model.eval()
45+
46+
def gen_synonyms(text, n=100, k=20):
47+
''''含义: 产生sent的n个相似句,然后返回最相似的k个。
48+
做法:用seq2seq生成,并用encoder算相似度并排序。
49+
'''
50+
# 寻找所有相似的句子
51+
r = []
52+
inputs1 = tokenizer(text, return_tensors="pt")
53+
for _ in range(n):
54+
inputs1.to(device)
55+
output = tokenizer.batch_decode(model.generate(**inputs1, top_p=0.95, do_sample=True, max_length=128), skip_special_tokens=True)[0].replace(" ","").replace(text, "") # 去除空格,去除原始text文本。
56+
r.append(output)
57+
58+
# 对相似的句子进行排序
59+
r = [i for i in set(r) if i != text and len(i) > 0]
60+
r = [text] + r
61+
inputs2 = tokenizer(r, padding=True, return_tensors="pt")
62+
with torch.no_grad():
63+
inputs2.to(device)
64+
outputs = model(**inputs2)
65+
Z = outputs.pooler_output.cpu().numpy()
66+
Z /= (Z**2).sum(axis=1, keepdims=True)**0.5
67+
argsort = np.dot(Z[1:], -Z[0]).argsort()
68+
69+
return [r[i + 1] for i in argsort[:k]]
70+
71+
out = gen_synonyms("广州和深圳哪个好?")
72+
print(out)
73+
# ['深圳和广州哪个好?',
74+
# '广州和深圳哪个好',
75+
# '深圳和广州哪个好',
76+
# '深圳和广州哪个比较好。',
77+
# '深圳和广州哪个最好?',
78+
# '深圳和广州哪个比较好',
79+
# '广州和深圳那个比较好',
80+
# '深圳和广州哪个更好?',
81+
# '深圳与广州哪个好',
82+
# '深圳和广州,哪个比较好',
83+
# '广州与深圳比较哪个好',
84+
# '深圳和广州哪里比较好',
85+
# '深圳还是广州比较好?',
86+
# '广州和深圳哪个地方好一些?',
87+
# '广州好还是深圳好?',
88+
# '广州好还是深圳好呢?',
89+
# '广州与深圳哪个地方好点?',
90+
# '深圳好还是广州好',
91+
# '广州好还是深圳好',
92+
# '广州和深圳哪个城市好?']
1793
```
1894

1995
## 模型权重对照表
@@ -39,6 +115,8 @@ pip install -U transformers
39115
| [roformer_chinese_sim_char_ft_small](https://huggingface.co/junnyu/roformer_chinese_sim_char_ft_small) | [chinese_roformer-sim-char-ft_L-6_H-384_A-6.zip](https://pan.baidu.com/s/1G36x7YQF1b6nzW0OzyJS_Q) (download code:gty5) |
40116

41117

118+
119+
42120
### 英文模型(使用electra的训练方法在openwebtext上训练的small模型(rotary value = True))
43121
| huggingface.co |
44122
| ---------------------------------- |
@@ -139,7 +217,7 @@ print(tf_outputs_sentence)
139217
# tf: 今天[天气||天||心情||阳光||空气]很好,我[想||要||打算||准备||喜欢]去公园玩。
140218

141219
```
142-
220+
143221
## 手动权重转换
144222
```bash
145223
python convert_roformer_original_tf_checkpoint_to_pytorch.py \

examples/test_sim.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import torch
2+
import numpy as np
3+
from roformer import RoFormerForCausalLM, RoFormerConfig
4+
from transformers import BertTokenizer
5+
6+
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
7+
pretrained_model = "junnyu/roformer_chinese_sim_char_base"
8+
tokenizer = BertTokenizer.from_pretrained(pretrained_model)
9+
config = RoFormerConfig.from_pretrained(pretrained_model)
10+
config.is_decoder = True
11+
config.eos_token_id = tokenizer.sep_token_id
12+
config.pooler_activation = "linear"
13+
model = RoFormerForCausalLM.from_pretrained(pretrained_model, config=config)
14+
model.to(device)
15+
model.eval()
16+
17+
def gen_synonyms(text, n=100, k=20):
18+
''''含义: 产生sent的n个相似句,然后返回最相似的k个。
19+
做法:用seq2seq生成,并用encoder算相似度并排序。
20+
'''
21+
# 寻找所有相似的句子
22+
r = []
23+
inputs1 = tokenizer(text, return_tensors="pt")
24+
for _ in range(n):
25+
inputs1.to(device)
26+
output = tokenizer.batch_decode(model.generate(**inputs1, top_p=0.95, do_sample=True, max_length=128), skip_special_tokens=True)[0].replace(" ","").replace(text, "") # 去除空格,去除原始text文本。
27+
r.append(output)
28+
29+
# 对相似的句子进行排序
30+
r = [i for i in set(r) if i != text and len(i) > 0]
31+
r = [text] + r
32+
inputs2 = tokenizer(r, padding=True, return_tensors="pt")
33+
with torch.no_grad():
34+
inputs2.to(device)
35+
outputs = model(**inputs2)
36+
Z = outputs.pooler_output.cpu().numpy()
37+
Z /= (Z**2).sum(axis=1, keepdims=True)**0.5
38+
argsort = np.dot(Z[1:], -Z[0]).argsort()
39+
40+
return [r[i + 1] for i in argsort[:k]]
41+
42+
out = gen_synonyms("广州和深圳哪个好?")
43+
print(out)
44+
# ['深圳和广州哪个好?',
45+
# '广州和深圳哪个好',
46+
# '深圳和广州哪个好',
47+
# '深圳和广州哪个比较好。',
48+
# '深圳和广州哪个最好?',
49+
# '深圳和广州哪个比较好',
50+
# '广州和深圳那个比较好',
51+
# '深圳和广州哪个更好?',
52+
# '深圳与广州哪个好',
53+
# '深圳和广州,哪个比较好',
54+
# '广州与深圳比较哪个好',
55+
# '深圳和广州哪里比较好',
56+
# '深圳还是广州比较好?',
57+
# '广州和深圳哪个地方好一些?',
58+
# '广州好还是深圳好?',
59+
# '广州好还是深圳好呢?',
60+
# '广州与深圳哪个地方好点?',
61+
# '深圳好还是广州好',
62+
# '广州好还是深圳好',
63+
# '广州和深圳哪个城市好?']

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
name="roformer",
55
package_dir={"": "src"},
66
packages=find_packages("src"),
7-
version="0.4.0",
7+
version="0.4.1",
88
license="Apache 2.0",
99
description="roformer_pytorch",
1010
author="Jun Yu",

src/roformer/configuration_roformer.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,16 @@
2424
"junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/config.json",
2525
"junnyu/roformer_chinese_char_small": "https://huggingface.co/junnyu/roformer_chinese_char_small/resolve/main/config.json",
2626
"junnyu/roformer_chinese_char_base": "https://huggingface.co/junnyu/roformer_chinese_char_base/resolve/main/config.json",
27+
"junnyu/roformer_chinese_sim_char_small": "https://huggingface.co/junnyu/roformer_chinese_sim_char_small/resolve/main/config.json",
28+
"junnyu/roformer_chinese_sim_char_base": "https://huggingface.co/junnyu/roformer_chinese_sim_char_base/resolve/main/config.json",
29+
"junnyu/roformer_chinese_sim_char_ft_base": "https://huggingface.co/junnyu/roformer_chinese_sim_char_ft_base/resolve/main/config.json",
30+
"junnyu/roformer_chinese_sim_char_ft_small": "https://huggingface.co/junnyu/roformer_chinese_sim_char_ft_small/resolve/main/config.json",
2731
"junnyu/roformer_small_discriminator": "https://huggingface.co/junnyu/roformer_small_discriminator/resolve/main/config.json",
2832
"junnyu/roformer_small_generator": "https://huggingface.co/junnyu/roformer_small_generator/resolve/main/config.json",
33+
"junnyu/roformer_base_wwm_cluecorpussmall": "https://huggingface.co/junnyu/roformer_base_wwm_cluecorpussmall/resolve/main/config.json",
34+
"junnyu/roformer_v2_chinese_char_small": "https://huggingface.co/junnyu/roformer_v2_chinese_char_small/resolve/main/config.json",
35+
"junnyu/roformer_v2_chinese_char_base": "https://huggingface.co/junnyu/roformer_v2_chinese_char_base/resolve/main/config.json",
36+
"junnyu/roformer_v2_chinese_char_large": "https://huggingface.co/junnyu/roformer_v2_chinese_char_large/resolve/main/config.json",
2937
# See all RoFormer models at https://huggingface.co/models?filter=roformer
3038
}
3139

@@ -107,6 +115,7 @@ def __init__(
107115
use_cache=True,
108116
use_bias=True,
109117
norm_type="layer_norm",
118+
pooler_activation="tanh",
110119
**kwargs
111120
):
112121
super().__init__(pad_token_id=pad_token_id, **kwargs)
@@ -128,3 +137,4 @@ def __init__(
128137
self.use_cache = use_cache
129138
self.use_bias = use_bias
130139
self.norm_type = norm_type
140+
self.pooler_activation = pooler_activation

0 commit comments

Comments
 (0)